Skip to content

Commit 45295a5

Browse files
GH-731: Avro adapter, output dictionary-encoded fields as enums (#779)
## What's Changed Updated ArrowToAvro to output dictionary-encoded string vectors as Avro enums, where possible. Apologies for the delay - busy as usual! To output dict encoded vectors as enums, a dictionary provider must be supplied to the top level methods with all the required dictionaries. All dictionary values must be present when the schema is written, i.e. before the data blocks are produced. If data is being written as a schema followed by multiple blocks, values added to a dictionary in between blocks will not be included in the schema resulting in an invalid Avro file (in general supply an invalid dictionary mapping will result in invalid output). Dictionary encoded fields are checked to ensure they are valid Avro enums. If the dictionary encoded field is not a string field, or the string values are not valid Avro enums, the field is decoded and output as literal values. This is done by calling DictionaryEncoder.decode(vector, dictionary), which will consume memory for the vector. An alternative approach would be to decode values one-by-one, however this would require a significant change to the producer pattern since the current producers expect concrete vectors of the output type. Another option would be to throw an error if there are dictionary-encoded vectors that are not string types, i.e. push the responsibility onto client code. I'm not sure which approach is best - happy to take any guidance and I will update the code accordingly. To read enums back the current approach for decoding is unchanged (the AvroToArrow config has to be set up with a MapDictionaryProvider which is populated when data is read). The last part of the Avro work is to add the capability for reading / writing whole files block-by-block, so there is an opportunity to do something with the top level APIs there, for now the current API works and I've used it in the round trip tests. Please let me know any feedback, happy to update as needed! Closes #731.
1 parent 43b6b6c commit 45295a5

File tree

7 files changed

+618
-50
lines changed

7 files changed

+618
-50
lines changed

adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java

Lines changed: 193 additions & 33 deletions
Large diffs are not rendered by default.

adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@
1717
package org.apache.arrow.adapter.avro.producers;
1818

1919
import java.io.IOException;
20-
import org.apache.arrow.vector.IntVector;
20+
import org.apache.arrow.vector.BaseIntVector;
2121
import org.apache.avro.io.Encoder;
2222

2323
/**
24-
* Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an
25-
* Avro encoder.
24+
* Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data
25+
* to an Avro encoder.
2626
*/
27-
public class AvroEnumProducer extends BaseAvroProducer<IntVector> {
27+
public class AvroEnumProducer extends BaseAvroProducer<BaseIntVector> {
2828

2929
/** Instantiate an AvroEnumProducer. */
30-
public AvroEnumProducer(IntVector vector) {
30+
public AvroEnumProducer(BaseIntVector vector) {
3131
super(vector);
3232
}
3333

3434
@Override
3535
public void produce(Encoder encoder) throws IOException {
36-
encoder.writeEnum(vector.get(currentIndex++));
36+
encoder.writeEnum((int) vector.getValueAsLong(currentIndex++));
3737
}
3838
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.arrow.adapter.avro.producers;
18+
19+
import java.io.IOException;
20+
import org.apache.arrow.vector.BaseIntVector;
21+
import org.apache.arrow.vector.FieldVector;
22+
import org.apache.avro.io.Encoder;
23+
24+
/**
25+
* Producer that decodes values from a dictionary-encoded {@link FieldVector}, writes the resulting
26+
* values to an Avro encoder.
27+
*
28+
* @param <T> Type of the underlying dictionary vector
29+
*/
30+
public class DictionaryDecodingProducer<T extends FieldVector>
31+
extends BaseAvroProducer<BaseIntVector> {
32+
33+
private final Producer<T> dictProducer;
34+
35+
/** Instantiate a DictionaryDecodingProducer. */
36+
public DictionaryDecodingProducer(BaseIntVector indexVector, Producer<T> dictProducer) {
37+
super(indexVector);
38+
this.dictProducer = dictProducer;
39+
}
40+
41+
@Override
42+
public void produce(Encoder encoder) throws IOException {
43+
int dicIndex = (int) vector.getValueAsLong(currentIndex++);
44+
dictProducer.setPosition(dicIndex);
45+
dictProducer.produce(encoder);
46+
}
47+
}

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,14 @@
7676
import org.apache.arrow.vector.complex.StructVector;
7777
import org.apache.arrow.vector.complex.writer.BaseWriter;
7878
import org.apache.arrow.vector.complex.writer.FieldWriter;
79+
import org.apache.arrow.vector.dictionary.Dictionary;
80+
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
81+
import org.apache.arrow.vector.dictionary.DictionaryProvider;
7982
import org.apache.arrow.vector.types.DateUnit;
8083
import org.apache.arrow.vector.types.FloatingPointPrecision;
8184
import org.apache.arrow.vector.types.TimeUnit;
8285
import org.apache.arrow.vector.types.pojo.ArrowType;
86+
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
8387
import org.apache.arrow.vector.types.pojo.Field;
8488
import org.apache.arrow.vector.types.pojo.FieldType;
8589
import org.apache.arrow.vector.util.JsonStringArrayList;
@@ -2817,4 +2821,81 @@ record = datumReader.read(record, decoder);
28172821
}
28182822
}
28192823
}
2824+
2825+
@Test
2826+
public void testWriteDictEnumEncoded() throws Exception {
2827+
2828+
BufferAllocator allocator = new RootAllocator();
2829+
2830+
// Create a dictionary
2831+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
2832+
VarCharVector dictionaryVector =
2833+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
2834+
2835+
dictionaryVector.allocateNew(3);
2836+
dictionaryVector.set(0, "apple".getBytes());
2837+
dictionaryVector.set(1, "banana".getBytes());
2838+
dictionaryVector.set(2, "cherry".getBytes());
2839+
dictionaryVector.setValueCount(3);
2840+
2841+
Dictionary dictionary =
2842+
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
2843+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
2844+
2845+
// Field definition
2846+
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
2847+
VarCharVector stringVector =
2848+
new VarCharVector(new Field("enumField", stringField, null), allocator);
2849+
stringVector.allocateNew(10);
2850+
stringVector.setSafe(0, "apple".getBytes());
2851+
stringVector.setSafe(1, "banana".getBytes());
2852+
stringVector.setSafe(2, "cherry".getBytes());
2853+
stringVector.setSafe(3, "cherry".getBytes());
2854+
stringVector.setSafe(4, "apple".getBytes());
2855+
stringVector.setSafe(5, "banana".getBytes());
2856+
stringVector.setSafe(6, "apple".getBytes());
2857+
stringVector.setSafe(7, "cherry".getBytes());
2858+
stringVector.setSafe(8, "banana".getBytes());
2859+
stringVector.setSafe(9, "apple".getBytes());
2860+
stringVector.setValueCount(10);
2861+
2862+
IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);
2863+
2864+
// Set up VSR
2865+
List<FieldVector> vectors = Arrays.asList(encodedVector);
2866+
int rowCount = 10;
2867+
2868+
try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
2869+
2870+
File dataFile = new File(TMP, "testWriteEnumEncoded.avro");
2871+
2872+
// Write an AVRO block using the producer classes
2873+
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
2874+
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
2875+
CompositeAvroProducer producer =
2876+
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
2877+
for (int row = 0; row < rowCount; row++) {
2878+
producer.produce(encoder);
2879+
}
2880+
encoder.flush();
2881+
}
2882+
2883+
// Set up reading the AVRO block as a GenericRecord
2884+
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
2885+
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);
2886+
2887+
try (InputStream inputStream = new FileInputStream(dataFile)) {
2888+
2889+
BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
2890+
GenericRecord record = null;
2891+
2892+
// Read and check values
2893+
for (int row = 0; row < rowCount; row++) {
2894+
record = datumReader.read(record, decoder);
2895+
// Values read from Avro should be the decoded enum values
2896+
assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
2897+
}
2898+
}
2899+
}
2900+
}
28202901
}

adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,18 @@
2020

2121
import java.util.Arrays;
2222
import java.util.List;
23+
import org.apache.arrow.memory.BufferAllocator;
24+
import org.apache.arrow.memory.RootAllocator;
25+
import org.apache.arrow.vector.BigIntVector;
26+
import org.apache.arrow.vector.VarCharVector;
27+
import org.apache.arrow.vector.dictionary.Dictionary;
28+
import org.apache.arrow.vector.dictionary.DictionaryProvider;
2329
import org.apache.arrow.vector.types.DateUnit;
2430
import org.apache.arrow.vector.types.FloatingPointPrecision;
2531
import org.apache.arrow.vector.types.TimeUnit;
2632
import org.apache.arrow.vector.types.UnionMode;
2733
import org.apache.arrow.vector.types.pojo.ArrowType;
34+
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
2835
import org.apache.arrow.vector.types.pojo.Field;
2936
import org.apache.arrow.vector.types.pojo.FieldType;
3037
import org.apache.avro.LogicalTypes;
@@ -1389,4 +1396,126 @@ public void testConvertUnionTypes() {
13891396
Schema.Type.STRING,
13901397
schema.getField("nullableDenseUnionField").schema().getTypes().get(3).getType());
13911398
}
1399+
1400+
@Test
1401+
public void testWriteDictEnumEncoded() {
1402+
1403+
BufferAllocator allocator = new RootAllocator();
1404+
1405+
// Create a dictionary
1406+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
1407+
VarCharVector dictionaryVector =
1408+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
1409+
1410+
dictionaryVector.allocateNew(3);
1411+
dictionaryVector.set(0, "apple".getBytes());
1412+
dictionaryVector.set(1, "banana".getBytes());
1413+
dictionaryVector.set(2, "cherry".getBytes());
1414+
dictionaryVector.setValueCount(3);
1415+
1416+
Dictionary dictionary =
1417+
new Dictionary(
1418+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1419+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1420+
1421+
List<Field> fields =
1422+
Arrays.asList(
1423+
new Field(
1424+
"enumField",
1425+
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
1426+
null));
1427+
1428+
Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
1429+
1430+
assertEquals(Schema.Type.RECORD, schema.getType());
1431+
assertEquals(1, schema.getFields().size());
1432+
1433+
Schema.Field enumField = schema.getField("enumField");
1434+
1435+
assertEquals(Schema.Type.ENUM, enumField.schema().getType());
1436+
assertEquals(3, enumField.schema().getEnumSymbols().size());
1437+
assertEquals("apple", enumField.schema().getEnumSymbols().get(0));
1438+
assertEquals("banana", enumField.schema().getEnumSymbols().get(1));
1439+
assertEquals("cherry", enumField.schema().getEnumSymbols().get(2));
1440+
}
1441+
1442+
@Test
1443+
public void testWriteDictEnumInvalid() {
1444+
1445+
BufferAllocator allocator = new RootAllocator();
1446+
1447+
// Create a dictionary
1448+
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
1449+
VarCharVector dictionaryVector =
1450+
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
1451+
1452+
dictionaryVector.allocateNew(3);
1453+
dictionaryVector.set(0, "passion fruit".getBytes());
1454+
dictionaryVector.set(1, "banana".getBytes());
1455+
dictionaryVector.set(2, "cherry".getBytes());
1456+
dictionaryVector.setValueCount(3);
1457+
1458+
Dictionary dictionary =
1459+
new Dictionary(
1460+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1461+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1462+
1463+
List<Field> fields =
1464+
Arrays.asList(
1465+
new Field(
1466+
"enumField",
1467+
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
1468+
null));
1469+
1470+
// Dictionary field contains values that are not valid enums
1471+
// Should be decoded and output as a string field
1472+
1473+
Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
1474+
1475+
assertEquals(Schema.Type.RECORD, schema.getType());
1476+
assertEquals(1, schema.getFields().size());
1477+
1478+
Schema.Field enumField = schema.getField("enumField");
1479+
assertEquals(Schema.Type.STRING, enumField.schema().getType());
1480+
}
1481+
1482+
@Test
1483+
public void testWriteDictEnumInvalid2() {
1484+
1485+
BufferAllocator allocator = new RootAllocator();
1486+
1487+
// Create a dictionary
1488+
FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null);
1489+
BigIntVector dictionaryVector =
1490+
new BigIntVector(new Field("dictionary", dictionaryField, null), allocator);
1491+
1492+
dictionaryVector.allocateNew(3);
1493+
dictionaryVector.set(0, 123L);
1494+
dictionaryVector.set(1, 456L);
1495+
dictionaryVector.set(2, 789L);
1496+
dictionaryVector.setValueCount(3);
1497+
1498+
Dictionary dictionary =
1499+
new Dictionary(
1500+
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
1501+
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
1502+
1503+
List<Field> fields =
1504+
Arrays.asList(
1505+
new Field(
1506+
"enumField",
1507+
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
1508+
null));
1509+
1510+
// Dictionary field encodes LONG values rather than STRING
1511+
// Should be doecded and output as a LONG field
1512+
1513+
Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
1514+
1515+
assertEquals(Schema.Type.RECORD, schema.getType());
1516+
assertEquals(1, schema.getFields().size());
1517+
1518+
Schema.Field enumField = schema.getField("enumField");
1519+
assertEquals(Schema.Type.LONG, enumField.schema().getType());
1520+
}
13921521
}

0 commit comments

Comments
 (0)