5252import org .apache .arrow .vector .TimeStampMilliVector ;
5353import org .apache .arrow .vector .TimeStampNanoTZVector ;
5454import org .apache .arrow .vector .TimeStampNanoVector ;
55+ import org .apache .arrow .vector .TinyIntVector ;
5556import org .apache .arrow .vector .VarBinaryVector ;
5657import org .apache .arrow .vector .VarCharVector ;
5758import org .apache .arrow .vector .VectorSchemaRoot ;
6061import org .apache .arrow .vector .complex .StructVector ;
6162import org .apache .arrow .vector .complex .writer .BaseWriter ;
6263import org .apache .arrow .vector .complex .writer .FieldWriter ;
64+ import org .apache .arrow .vector .dictionary .Dictionary ;
65+ import org .apache .arrow .vector .dictionary .DictionaryEncoder ;
66+ import org .apache .arrow .vector .dictionary .DictionaryProvider ;
6367import org .apache .arrow .vector .types .DateUnit ;
6468import org .apache .arrow .vector .types .FloatingPointPrecision ;
6569import org .apache .arrow .vector .types .TimeUnit ;
6670import org .apache .arrow .vector .types .pojo .ArrowType ;
71+ import org .apache .arrow .vector .types .pojo .DictionaryEncoding ;
6772import org .apache .arrow .vector .types .pojo .Field ;
6873import org .apache .arrow .vector .types .pojo .FieldType ;
6974import org .apache .avro .Schema ;
@@ -78,39 +83,59 @@ public class RoundTripDataTest {
7883
7984 @ TempDir public static File TMP ;
8085
81- private static AvroToArrowConfig basicConfig (BufferAllocator allocator ) {
82- return new AvroToArrowConfig (allocator , 1000 , null , Collections .emptySet (), false );
86+ private static AvroToArrowConfig basicConfig (
87+ BufferAllocator allocator , DictionaryProvider .MapDictionaryProvider dictionaries ) {
88+ return new AvroToArrowConfig (allocator , 1000 , dictionaries , Collections .emptySet (), false );
8389 }
8490
8591 private static VectorSchemaRoot readDataFile (
86- Schema schema , File dataFile , BufferAllocator allocator ) throws Exception {
92+ Schema schema ,
93+ File dataFile ,
94+ BufferAllocator allocator ,
95+ DictionaryProvider .MapDictionaryProvider dictionaries )
96+ throws Exception {
8797
8898 try (FileInputStream fis = new FileInputStream (dataFile )) {
8999 BinaryDecoder decoder = new DecoderFactory ().directBinaryDecoder (fis , null );
90- return AvroToArrow .avroToArrow (schema , decoder , basicConfig (allocator ));
100+ return AvroToArrow .avroToArrow (schema , decoder , basicConfig (allocator , dictionaries ));
91101 }
92102 }
93103
94104 private static void roundTripTest (
95105 VectorSchemaRoot root , BufferAllocator allocator , File dataFile , int rowCount )
96106 throws Exception {
97107
108+ roundTripTest (root , allocator , dataFile , rowCount , null );
109+ }
110+
111+ private static void roundTripTest (
112+ VectorSchemaRoot root ,
113+ BufferAllocator allocator ,
114+ File dataFile ,
115+ int rowCount ,
116+ DictionaryProvider dictionaries )
117+ throws Exception {
118+
98119 // Write an AVRO block using the producer classes
99120 try (FileOutputStream fos = new FileOutputStream (dataFile )) {
100121 BinaryEncoder encoder = new EncoderFactory ().directBinaryEncoder (fos , null );
101122 CompositeAvroProducer producer =
102- ArrowToAvroUtils .createCompositeProducer (root .getFieldVectors ());
123+ ArrowToAvroUtils .createCompositeProducer (root .getFieldVectors (), dictionaries );
103124 for (int row = 0 ; row < rowCount ; row ++) {
104125 producer .produce (encoder );
105126 }
106127 encoder .flush ();
107128 }
108129
109130 // Generate AVRO schema
110- Schema schema = ArrowToAvroUtils .createAvroSchema (root .getSchema ().getFields ());
131+ Schema schema = ArrowToAvroUtils .createAvroSchema (root .getSchema ().getFields (), dictionaries );
132+
133+ DictionaryProvider .MapDictionaryProvider roundTripDictionaries =
134+ new DictionaryProvider .MapDictionaryProvider ();
111135
112136 // Read back in and compare
113- try (VectorSchemaRoot roundTrip = readDataFile (schema , dataFile , allocator )) {
137+ try (VectorSchemaRoot roundTrip =
138+ readDataFile (schema , dataFile , allocator , roundTripDictionaries )) {
114139
115140 assertEquals (root .getSchema (), roundTrip .getSchema ());
116141 assertEquals (rowCount , roundTrip .getRowCount ());
@@ -119,6 +144,21 @@ private static void roundTripTest(
119144 for (int row = 0 ; row < rowCount ; row ++) {
120145 assertEquals (root .getVector (0 ).getObject (row ), roundTrip .getVector (0 ).getObject (row ));
121146 }
147+
148+ if (dictionaries != null ) {
149+ for (long id : dictionaries .getDictionaryIds ()) {
150+ Dictionary originalDictionary = dictionaries .lookup (id );
151+ Dictionary roundTripDictionary = roundTripDictionaries .lookup (id );
152+ assertEquals (
153+ originalDictionary .getVector ().getValueCount (),
154+ roundTripDictionary .getVector ().getValueCount ());
155+ for (int j = 0 ; j < originalDictionary .getVector ().getValueCount (); j ++) {
156+ assertEquals (
157+ originalDictionary .getVector ().getObject (j ),
158+ roundTripDictionary .getVector ().getObject (j ));
159+ }
160+ }
161+ }
122162 }
123163 }
124164
@@ -141,7 +181,7 @@ private static void roundTripByteArrayTest(
141181 Schema schema = ArrowToAvroUtils .createAvroSchema (root .getSchema ().getFields ());
142182
143183 // Read back in and compare
144- try (VectorSchemaRoot roundTrip = readDataFile (schema , dataFile , allocator )) {
184+ try (VectorSchemaRoot roundTrip = readDataFile (schema , dataFile , allocator , null )) {
145185
146186 assertEquals (root .getSchema (), roundTrip .getSchema ());
147187 assertEquals (rowCount , roundTrip .getRowCount ());
@@ -1603,4 +1643,58 @@ public void testRoundTripNullableStructs() throws Exception {
16031643 roundTripTest (root , allocator , dataFile , rowCount );
16041644 }
16051645 }
1646+
1647+ @ Test
1648+ public void testRoundTripEnum () throws Exception {
1649+
1650+ BufferAllocator allocator = new RootAllocator ();
1651+
1652+ // Create a dictionary
1653+ FieldType dictionaryField = new FieldType (false , new ArrowType .Utf8 (), null );
1654+ VarCharVector dictionaryVector =
1655+ new VarCharVector (new Field ("dictionary" , dictionaryField , null ), allocator );
1656+
1657+ dictionaryVector .allocateNew (3 );
1658+ dictionaryVector .set (0 , "apple" .getBytes ());
1659+ dictionaryVector .set (1 , "banana" .getBytes ());
1660+ dictionaryVector .set (2 , "cherry" .getBytes ());
1661+ dictionaryVector .setValueCount (3 );
1662+
1663+ // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding
1664+ Dictionary dictionary =
1665+ new Dictionary (
1666+ dictionaryVector , new DictionaryEncoding (0L , false , new ArrowType .Int (8 , true )));
1667+ DictionaryProvider dictionaries = new DictionaryProvider .MapDictionaryProvider (dictionary );
1668+
1669+ // Field definition
1670+ FieldType stringField = new FieldType (false , new ArrowType .Utf8 (), null );
1671+ VarCharVector stringVector =
1672+ new VarCharVector (new Field ("enumField" , stringField , null ), allocator );
1673+ stringVector .allocateNew (10 );
1674+ stringVector .setSafe (0 , "apple" .getBytes ());
1675+ stringVector .setSafe (1 , "banana" .getBytes ());
1676+ stringVector .setSafe (2 , "cherry" .getBytes ());
1677+ stringVector .setSafe (3 , "cherry" .getBytes ());
1678+ stringVector .setSafe (4 , "apple" .getBytes ());
1679+ stringVector .setSafe (5 , "banana" .getBytes ());
1680+ stringVector .setSafe (6 , "apple" .getBytes ());
1681+ stringVector .setSafe (7 , "cherry" .getBytes ());
1682+ stringVector .setSafe (8 , "banana" .getBytes ());
1683+ stringVector .setSafe (9 , "apple" .getBytes ());
1684+ stringVector .setValueCount (10 );
1685+
1686+ TinyIntVector encodedVector =
1687+ (TinyIntVector ) DictionaryEncoder .encode (stringVector , dictionary );
1688+
1689+ // Set up VSR
1690+ List <FieldVector > vectors = Arrays .asList (encodedVector );
1691+ int rowCount = 10 ;
1692+
1693+ try (VectorSchemaRoot root = new VectorSchemaRoot (vectors )) {
1694+
1695+ File dataFile = new File (TMP , "testRoundTripEnums.avro" );
1696+
1697+ roundTripTest (root , allocator , dataFile , rowCount , dictionaries );
1698+ }
1699+ }
16061700}
0 commit comments