Skip to content

Commit 64d5b7a

Browse files
Avro write test for enums
1 parent a686768 commit 64d5b7a

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,14 @@
7676
import org.apache.arrow.vector.complex.StructVector;
7777
import org.apache.arrow.vector.complex.writer.BaseWriter;
7878
import org.apache.arrow.vector.complex.writer.FieldWriter;
79+
import org.apache.arrow.vector.dictionary.Dictionary;
80+
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
81+
import org.apache.arrow.vector.dictionary.DictionaryProvider;
7982
import org.apache.arrow.vector.types.DateUnit;
8083
import org.apache.arrow.vector.types.FloatingPointPrecision;
8184
import org.apache.arrow.vector.types.TimeUnit;
8285
import org.apache.arrow.vector.types.pojo.ArrowType;
86+
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
8387
import org.apache.arrow.vector.types.pojo.Field;
8488
import org.apache.arrow.vector.types.pojo.FieldType;
8589
import org.apache.arrow.vector.util.JsonStringArrayList;
@@ -2817,4 +2821,184 @@ record = datumReader.read(record, decoder);
28172821
}
28182822
}
28192823
}
2824+
2825+
@Test
2826+
public void testWriteDictEnumEncoded() throws Exception {
2827+
2828+
BufferAllocator allocator = new RootAllocator();
2829+
2830+
// Create a dictionary
2831+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
2832+
VarCharVector dictionaryVector =
2833+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
2834+
2835+
dictionaryVector.allocateNew(3);
2836+
dictionaryVector.set(0, "apple".getBytes());
2837+
dictionaryVector.set(1, "banana".getBytes());
2838+
dictionaryVector.set(2, "cherry".getBytes());
2839+
dictionaryVector.setValueCount(3);
2840+
2841+
Dictionary dictionary =
2842+
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
2843+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
2844+
2845+
// Field definition
2846+
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
2847+
VarCharVector stringVector =
2848+
new VarCharVector(new Field("enumField", stringField, null), allocator);
2849+
stringVector.allocateNew(10);
2850+
stringVector.setSafe(0, "apple".getBytes());
2851+
stringVector.setSafe(1, "banana".getBytes());
2852+
stringVector.setSafe(2, "cherry".getBytes());
2853+
stringVector.setSafe(3, "cherry".getBytes());
2854+
stringVector.setSafe(4, "apple".getBytes());
2855+
stringVector.setSafe(5, "banana".getBytes());
2856+
stringVector.setSafe(6, "apple".getBytes());
2857+
stringVector.setSafe(7, "cherry".getBytes());
2858+
stringVector.setSafe(8, "banana".getBytes());
2859+
stringVector.setSafe(9, "apple".getBytes());
2860+
stringVector.setValueCount(10);
2861+
2862+
IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);
2863+
2864+
// Set up VSR
2865+
List<FieldVector> vectors = Arrays.asList(encodedVector);
2866+
int rowCount = 10;
2867+
2868+
try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
2869+
2870+
File dataFile = new File(TMP, "testWriteEnumEncoded.avro");
2871+
2872+
// Write an AVRO block using the producer classes
2873+
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
2874+
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
2875+
CompositeAvroProducer producer =
2876+
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
2877+
for (int row = 0; row < rowCount; row++) {
2878+
producer.produce(encoder);
2879+
}
2880+
encoder.flush();
2881+
}
2882+
2883+
// Set up reading the AVRO block as a GenericRecord
2884+
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
2885+
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);
2886+
2887+
try (InputStream inputStream = new FileInputStream(dataFile)) {
2888+
2889+
BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
2890+
GenericRecord record = null;
2891+
2892+
// Read and check values
2893+
for (int row = 0; row < rowCount; row++) {
2894+
record = datumReader.read(record, decoder);
2895+
// Values read from Avro should be the decoded enum values
2896+
assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
2897+
}
2898+
}
2899+
}
2900+
}
2901+
2902+
@Test
2903+
public void testWriteEnumDecoded() throws Exception {
2904+
2905+
// Dict encoded fields that are not valid Avro enums should be decoded on write
2906+
2907+
BufferAllocator allocator = new RootAllocator();
2908+
2909+
// Create a dictionary
2910+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
2911+
VarCharVector dictionaryVector =
2912+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
2913+
2914+
dictionaryVector.allocateNew(3);
2915+
dictionaryVector.set(0, "passion fruit".getBytes()); // spaced not allowed
2916+
dictionaryVector.set(1, "banana".getBytes());
2917+
dictionaryVector.set(2, "cherry".getBytes());
2918+
dictionaryVector.setValueCount(3);
2919+
2920+
Dictionary dictionary =
2921+
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
2922+
2923+
FieldType dictionaryField2 = new FieldType(false, new ArrowType.Int(64, true), null);
2924+
BigIntVector dictionaryVector2 =
2925+
new BigIntVector(new Field("dictionary2", dictionaryField2, null), allocator);
2926+
2927+
dictionaryVector2.allocateNew(3);
2928+
dictionaryVector2.set(0, 0L);
2929+
dictionaryVector2.set(1, 1L);
2930+
dictionaryVector2.set(2, 2L);
2931+
dictionaryVector2.setValueCount(3);
2932+
2933+
Dictionary dictionary2 =
2934+
new Dictionary(dictionaryVector2, new DictionaryEncoding(2L, false, null));
2935+
2936+
DictionaryProvider dictionaries =
2937+
new DictionaryProvider.MapDictionaryProvider(dictionary, dictionary2);
2938+
2939+
// Field definition
2940+
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
2941+
VarCharVector stringVector =
2942+
new VarCharVector(new Field("enumField", stringField, null), allocator);
2943+
stringVector.allocateNew(10);
2944+
stringVector.setSafe(0, "passion fruit".getBytes());
2945+
stringVector.setSafe(1, "banana".getBytes());
2946+
stringVector.setSafe(2, "cherry".getBytes());
2947+
stringVector.setSafe(3, "cherry".getBytes());
2948+
stringVector.setSafe(4, "passion fruit".getBytes());
2949+
stringVector.setSafe(5, "banana".getBytes());
2950+
stringVector.setSafe(6, "passion fruit".getBytes());
2951+
stringVector.setSafe(7, "cherry".getBytes());
2952+
stringVector.setSafe(8, "banana".getBytes());
2953+
stringVector.setSafe(9, "passion fruit".getBytes());
2954+
stringVector.setValueCount(10);
2955+
2956+
FieldType longField = new FieldType(false, new ArrowType.Int(64, true), null);
2957+
BigIntVector longVector = new BigIntVector(new Field("enumField2", longField, null), allocator);
2958+
longVector.allocateNew(10);
2959+
for (int i = 0; i < 10; i++) {
2960+
longVector.setSafe(i, (long) i % 3);
2961+
}
2962+
longVector.setValueCount(10);
2963+
2964+
IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);
2965+
IntVector encodedVector2 = (IntVector) DictionaryEncoder.encode(longVector, dictionary2);
2966+
2967+
// Set up VSR
2968+
List<FieldVector> vectors = Arrays.asList(encodedVector, encodedVector2);
2969+
int rowCount = 10;
2970+
2971+
try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
2972+
2973+
File dataFile = new File(TMP, "testWriteEnumDecodedavro");
2974+
2975+
// Write an AVRO block using the producer classes
2976+
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
2977+
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
2978+
CompositeAvroProducer producer =
2979+
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
2980+
for (int row = 0; row < rowCount; row++) {
2981+
producer.produce(encoder);
2982+
}
2983+
encoder.flush();
2984+
}
2985+
2986+
// Set up reading the AVRO block as a GenericRecord
2987+
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
2988+
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);
2989+
2990+
try (InputStream inputStream = new FileInputStream(dataFile)) {
2991+
2992+
BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
2993+
GenericRecord record = null;
2994+
2995+
// Read and check values
2996+
for (int row = 0; row < rowCount; row++) {
2997+
record = datumReader.read(record, decoder);
2998+
assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
2999+
assertEquals(longVector.getObject(row), record.get("enumField2"));
3000+
}
3001+
}
3002+
}
3003+
}
28203004
}

0 commit comments

Comments
 (0)