Skip to content

Commit a686768

Browse files
Round trip data test for enums
1 parent b32070a commit a686768

File tree

1 file changed

+102
-8
lines changed

1 file changed

+102
-8
lines changed

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.apache.arrow.vector.TimeStampMilliVector;
5353
import org.apache.arrow.vector.TimeStampNanoTZVector;
5454
import org.apache.arrow.vector.TimeStampNanoVector;
55+
import org.apache.arrow.vector.TinyIntVector;
5556
import org.apache.arrow.vector.VarBinaryVector;
5657
import org.apache.arrow.vector.VarCharVector;
5758
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -60,10 +61,14 @@
6061
import org.apache.arrow.vector.complex.StructVector;
6162
import org.apache.arrow.vector.complex.writer.BaseWriter;
6263
import org.apache.arrow.vector.complex.writer.FieldWriter;
64+
import org.apache.arrow.vector.dictionary.Dictionary;
65+
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
66+
import org.apache.arrow.vector.dictionary.DictionaryProvider;
6367
import org.apache.arrow.vector.types.DateUnit;
6468
import org.apache.arrow.vector.types.FloatingPointPrecision;
6569
import org.apache.arrow.vector.types.TimeUnit;
6670
import org.apache.arrow.vector.types.pojo.ArrowType;
71+
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
6772
import org.apache.arrow.vector.types.pojo.Field;
6873
import org.apache.arrow.vector.types.pojo.FieldType;
6974
import org.apache.avro.Schema;
@@ -78,39 +83,59 @@ public class RoundTripDataTest {
7883

7984
@TempDir public static File TMP;
8085

81-
private static AvroToArrowConfig basicConfig(BufferAllocator allocator) {
82-
return new AvroToArrowConfig(allocator, 1000, null, Collections.emptySet(), false);
86+
private static AvroToArrowConfig basicConfig(
87+
BufferAllocator allocator, DictionaryProvider.MapDictionaryProvider dictionaries) {
88+
return new AvroToArrowConfig(allocator, 1000, dictionaries, Collections.emptySet(), false);
8389
}
8490

8591
private static VectorSchemaRoot readDataFile(
86-
Schema schema, File dataFile, BufferAllocator allocator) throws Exception {
92+
Schema schema,
93+
File dataFile,
94+
BufferAllocator allocator,
95+
DictionaryProvider.MapDictionaryProvider dictionaries)
96+
throws Exception {
8797

8898
try (FileInputStream fis = new FileInputStream(dataFile)) {
8999
BinaryDecoder decoder = new DecoderFactory().directBinaryDecoder(fis, null);
90-
return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator));
100+
return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator, dictionaries));
91101
}
92102
}
93103

94104
private static void roundTripTest(
95105
VectorSchemaRoot root, BufferAllocator allocator, File dataFile, int rowCount)
96106
throws Exception {
97107

108+
roundTripTest(root, allocator, dataFile, rowCount, null);
109+
}
110+
111+
private static void roundTripTest(
112+
VectorSchemaRoot root,
113+
BufferAllocator allocator,
114+
File dataFile,
115+
int rowCount,
116+
DictionaryProvider dictionaries)
117+
throws Exception {
118+
98119
// Write an AVRO block using the producer classes
99120
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
100121
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
101122
CompositeAvroProducer producer =
102-
ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors());
123+
ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors(), dictionaries);
103124
for (int row = 0; row < rowCount; row++) {
104125
producer.produce(encoder);
105126
}
106127
encoder.flush();
107128
}
108129

109130
// Generate AVRO schema
110-
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields());
131+
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
132+
133+
DictionaryProvider.MapDictionaryProvider roundTripDictionaries =
134+
new DictionaryProvider.MapDictionaryProvider();
111135

112136
// Read back in and compare
113-
try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) {
137+
try (VectorSchemaRoot roundTrip =
138+
readDataFile(schema, dataFile, allocator, roundTripDictionaries)) {
114139

115140
assertEquals(root.getSchema(), roundTrip.getSchema());
116141
assertEquals(rowCount, roundTrip.getRowCount());
@@ -119,6 +144,21 @@ private static void roundTripTest(
119144
for (int row = 0; row < rowCount; row++) {
120145
assertEquals(root.getVector(0).getObject(row), roundTrip.getVector(0).getObject(row));
121146
}
147+
148+
if (dictionaries != null) {
149+
for (long id : dictionaries.getDictionaryIds()) {
150+
Dictionary originalDictionary = dictionaries.lookup(id);
151+
Dictionary roundTripDictionary = roundTripDictionaries.lookup(id);
152+
assertEquals(
153+
originalDictionary.getVector().getValueCount(),
154+
roundTripDictionary.getVector().getValueCount());
155+
for (int j = 0; j < originalDictionary.getVector().getValueCount(); j++) {
156+
assertEquals(
157+
originalDictionary.getVector().getObject(j),
158+
roundTripDictionary.getVector().getObject(j));
159+
}
160+
}
161+
}
122162
}
123163
}
124164

@@ -141,7 +181,7 @@ private static void roundTripByteArrayTest(
141181
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields());
142182

143183
// Read back in and compare
144-
try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) {
184+
try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator, null)) {
145185

146186
assertEquals(root.getSchema(), roundTrip.getSchema());
147187
assertEquals(rowCount, roundTrip.getRowCount());
@@ -1603,4 +1643,58 @@ public void testRoundTripNullableStructs() throws Exception {
16031643
roundTripTest(root, allocator, dataFile, rowCount);
16041644
}
16051645
}
1646+
1647+
@Test
1648+
public void testRoundTripEnum() throws Exception {
1649+
1650+
BufferAllocator allocator = new RootAllocator();
1651+
1652+
// Create a dictionary
1653+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
1654+
VarCharVector dictionaryVector =
1655+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
1656+
1657+
dictionaryVector.allocateNew(3);
1658+
dictionaryVector.set(0, "apple".getBytes());
1659+
dictionaryVector.set(1, "banana".getBytes());
1660+
dictionaryVector.set(2, "cherry".getBytes());
1661+
dictionaryVector.setValueCount(3);
1662+
1663+
// For simplicity, ensure the index type matches what will be decoded during Avro enum decoding
1664+
Dictionary dictionary =
1665+
new Dictionary(
1666+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1667+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1668+
1669+
// Field definition
1670+
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
1671+
VarCharVector stringVector =
1672+
new VarCharVector(new Field("enumField", stringField, null), allocator);
1673+
stringVector.allocateNew(10);
1674+
stringVector.setSafe(0, "apple".getBytes());
1675+
stringVector.setSafe(1, "banana".getBytes());
1676+
stringVector.setSafe(2, "cherry".getBytes());
1677+
stringVector.setSafe(3, "cherry".getBytes());
1678+
stringVector.setSafe(4, "apple".getBytes());
1679+
stringVector.setSafe(5, "banana".getBytes());
1680+
stringVector.setSafe(6, "apple".getBytes());
1681+
stringVector.setSafe(7, "cherry".getBytes());
1682+
stringVector.setSafe(8, "banana".getBytes());
1683+
stringVector.setSafe(9, "apple".getBytes());
1684+
stringVector.setValueCount(10);
1685+
1686+
TinyIntVector encodedVector =
1687+
(TinyIntVector) DictionaryEncoder.encode(stringVector, dictionary);
1688+
1689+
// Set up VSR
1690+
List<FieldVector> vectors = Arrays.asList(encodedVector);
1691+
int rowCount = 10;
1692+
1693+
try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
1694+
1695+
File dataFile = new File(TMP, "testRoundTripEnums.avro");
1696+
1697+
roundTripTest(root, allocator, dataFile, rowCount, dictionaries);
1698+
}
1699+
}
16061700
}

0 commit comments

Comments
 (0)