|
21 | 21 | import java.util.Arrays; |
22 | 22 | import java.util.Collections; |
23 | 23 | import java.util.List; |
| 24 | +import org.apache.arrow.memory.BufferAllocator; |
| 25 | +import org.apache.arrow.memory.RootAllocator; |
| 26 | +import org.apache.arrow.vector.VarCharVector; |
| 27 | +import org.apache.arrow.vector.dictionary.Dictionary; |
| 28 | +import org.apache.arrow.vector.dictionary.DictionaryProvider; |
24 | 29 | import org.apache.arrow.vector.types.DateUnit; |
25 | 30 | import org.apache.arrow.vector.types.FloatingPointPrecision; |
26 | 31 | import org.apache.arrow.vector.types.TimeUnit; |
27 | 32 | import org.apache.arrow.vector.types.pojo.ArrowType; |
| 33 | +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; |
28 | 34 | import org.apache.arrow.vector.types.pojo.Field; |
29 | 35 | import org.apache.arrow.vector.types.pojo.FieldType; |
30 | 36 | import org.apache.avro.Schema; |
| 37 | +import org.junit.jupiter.api.Assertions; |
31 | 38 | import org.junit.jupiter.api.Test; |
32 | 39 |
|
33 | 40 | public class RoundTripSchemaTest { |
34 | 41 |
|
35 | 42 | private void doRoundTripTest(List<Field> fields) { |
| 43 | + doRoundTripTest(fields, null); |
| 44 | + } |
36 | 45 |
|
37 | | - AvroToArrowConfig config = new AvroToArrowConfig(null, 1, null, Collections.emptySet(), false); |
| 46 | + private void doRoundTripTest(List<Field> fields, DictionaryProvider dictionaries) { |
38 | 47 |
|
39 | | - Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord"); |
| 48 | + DictionaryProvider.MapDictionaryProvider decodeDictionaries = |
| 49 | + new DictionaryProvider.MapDictionaryProvider(); |
| 50 | + AvroToArrowConfig decodeConfig = |
| 51 | + new AvroToArrowConfig(null, 1, decodeDictionaries, Collections.emptySet(), false); |
| 52 | + |
| 53 | + Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); |
40 | 54 | org.apache.arrow.vector.types.pojo.Schema arrowSchema = |
41 | | - AvroToArrowUtils.createArrowSchema(avroSchema, config); |
| 55 | + AvroToArrowUtils.createArrowSchema(avroSchema, decodeConfig); |
42 | 56 |
|
43 | 57 | // Compare string representations - equality not defined for logical types |
44 | 58 | assertEquals(fields, arrowSchema.getFields()); |
| 59 | + |
| 60 | + for (int i = 0; i < fields.size(); i++) { |
| 61 | + Field field = fields.get(i); |
| 62 | + Field rtField = arrowSchema.getFields().get(i); |
| 63 | + if (field.getDictionary() != null) { |
| 64 | + // Dictionary content is not decoded until the data is consumed |
| 65 | + Assertions.assertNotNull(rtField.getDictionary()); |
| 66 | + } |
| 67 | + } |
45 | 68 | } |
46 | 69 |
|
47 | 70 | // Schema round trip for primitive types, nullable and non-nullable |
@@ -440,4 +463,38 @@ public void testRoundTripStructType() { |
440 | 463 |
|
441 | 464 | doRoundTripTest(fields); |
442 | 465 | } |
| 466 | + |
| 467 | + @Test |
| 468 | + public void testRoundTripEnumType() { |
| 469 | + |
| 470 | + BufferAllocator allocator = new RootAllocator(); |
| 471 | + |
| 472 | + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); |
| 473 | + VarCharVector dictionaryVector = |
| 474 | + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); |
| 475 | + |
| 476 | + dictionaryVector.allocateNew(3); |
| 477 | + dictionaryVector.set(0, "apple".getBytes()); |
| 478 | + dictionaryVector.set(1, "banana".getBytes()); |
| 479 | + dictionaryVector.set(2, "cherry".getBytes()); |
| 480 | + dictionaryVector.setValueCount(3); |
| 481 | + |
| 482 | + // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding |
| 483 | + Dictionary dictionary = |
| 484 | + new Dictionary( |
| 485 | + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); |
| 486 | + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); |
| 487 | + |
| 488 | + List<Field> fields = |
| 489 | + Arrays.asList( |
| 490 | + new Field( |
| 491 | + "enumField", |
| 492 | + new FieldType( |
| 493 | + true, |
| 494 | + new ArrowType.Int(8, true), |
| 495 | + new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))), |
| 496 | + null)); |
| 497 | + |
| 498 | + doRoundTripTest(fields, dictionaries); |
| 499 | + } |
443 | 500 | } |
0 commit comments