|
76 | 76 | import org.apache.arrow.vector.complex.StructVector; |
77 | 77 | import org.apache.arrow.vector.complex.writer.BaseWriter; |
78 | 78 | import org.apache.arrow.vector.complex.writer.FieldWriter; |
| 79 | +import org.apache.arrow.vector.dictionary.Dictionary; |
| 80 | +import org.apache.arrow.vector.dictionary.DictionaryEncoder; |
| 81 | +import org.apache.arrow.vector.dictionary.DictionaryProvider; |
79 | 82 | import org.apache.arrow.vector.types.DateUnit; |
80 | 83 | import org.apache.arrow.vector.types.FloatingPointPrecision; |
81 | 84 | import org.apache.arrow.vector.types.TimeUnit; |
82 | 85 | import org.apache.arrow.vector.types.pojo.ArrowType; |
| 86 | +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; |
83 | 87 | import org.apache.arrow.vector.types.pojo.Field; |
84 | 88 | import org.apache.arrow.vector.types.pojo.FieldType; |
85 | 89 | import org.apache.arrow.vector.util.JsonStringArrayList; |
@@ -2817,4 +2821,184 @@ record = datumReader.read(record, decoder); |
2817 | 2821 | } |
2818 | 2822 | } |
2819 | 2823 | } |
| 2824 | + |
| 2825 | + @Test |
| 2826 | + public void testWriteDictEnumEncoded() throws Exception { |
| 2827 | + |
| 2828 | + BufferAllocator allocator = new RootAllocator(); |
| 2829 | + |
| 2830 | + // Create a dictionary |
| 2831 | + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); |
| 2832 | + VarCharVector dictionaryVector = |
| 2833 | + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); |
| 2834 | + |
| 2835 | + dictionaryVector.allocateNew(3); |
| 2836 | + dictionaryVector.set(0, "apple".getBytes()); |
| 2837 | + dictionaryVector.set(1, "banana".getBytes()); |
| 2838 | + dictionaryVector.set(2, "cherry".getBytes()); |
| 2839 | + dictionaryVector.setValueCount(3); |
| 2840 | + |
| 2841 | + Dictionary dictionary = |
| 2842 | + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); |
| 2843 | + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); |
| 2844 | + |
| 2845 | + // Field definition |
| 2846 | + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); |
| 2847 | + VarCharVector stringVector = |
| 2848 | + new VarCharVector(new Field("enumField", stringField, null), allocator); |
| 2849 | + stringVector.allocateNew(10); |
| 2850 | + stringVector.setSafe(0, "apple".getBytes()); |
| 2851 | + stringVector.setSafe(1, "banana".getBytes()); |
| 2852 | + stringVector.setSafe(2, "cherry".getBytes()); |
| 2853 | + stringVector.setSafe(3, "cherry".getBytes()); |
| 2854 | + stringVector.setSafe(4, "apple".getBytes()); |
| 2855 | + stringVector.setSafe(5, "banana".getBytes()); |
| 2856 | + stringVector.setSafe(6, "apple".getBytes()); |
| 2857 | + stringVector.setSafe(7, "cherry".getBytes()); |
| 2858 | + stringVector.setSafe(8, "banana".getBytes()); |
| 2859 | + stringVector.setSafe(9, "apple".getBytes()); |
| 2860 | + stringVector.setValueCount(10); |
| 2861 | + |
| 2862 | + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); |
| 2863 | + |
| 2864 | + // Set up VSR |
| 2865 | + List<FieldVector> vectors = Arrays.asList(encodedVector); |
| 2866 | + int rowCount = 10; |
| 2867 | + |
| 2868 | + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { |
| 2869 | + |
| 2870 | + File dataFile = new File(TMP, "testWriteEnumEncoded.avro"); |
| 2871 | + |
| 2872 | + // Write an AVRO block using the producer classes |
| 2873 | + try (FileOutputStream fos = new FileOutputStream(dataFile)) { |
| 2874 | + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); |
| 2875 | + CompositeAvroProducer producer = |
| 2876 | + ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); |
| 2877 | + for (int row = 0; row < rowCount; row++) { |
| 2878 | + producer.produce(encoder); |
| 2879 | + } |
| 2880 | + encoder.flush(); |
| 2881 | + } |
| 2882 | + |
| 2883 | + // Set up reading the AVRO block as a GenericRecord |
| 2884 | + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); |
| 2885 | + GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema); |
| 2886 | + |
| 2887 | + try (InputStream inputStream = new FileInputStream(dataFile)) { |
| 2888 | + |
| 2889 | + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); |
| 2890 | + GenericRecord record = null; |
| 2891 | + |
| 2892 | + // Read and check values |
| 2893 | + for (int row = 0; row < rowCount; row++) { |
| 2894 | + record = datumReader.read(record, decoder); |
| 2895 | + // Values read from Avro should be the decoded enum values |
| 2896 | + assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); |
| 2897 | + } |
| 2898 | + } |
| 2899 | + } |
| 2900 | + } |
| 2901 | + |
| 2902 | + @Test |
| 2903 | + public void testWriteEnumDecoded() throws Exception { |
| 2904 | + |
| 2905 | + // Dict encoded fields that are not valid Avro enums should be decoded on write |
| 2906 | + |
| 2907 | + BufferAllocator allocator = new RootAllocator(); |
| 2908 | + |
| 2909 | + // Create a dictionary |
| 2910 | + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); |
| 2911 | + VarCharVector dictionaryVector = |
| 2912 | + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); |
| 2913 | + |
| 2914 | + dictionaryVector.allocateNew(3); |
| 2915 | + dictionaryVector.set(0, "passion fruit".getBytes()); // spaced not allowed |
| 2916 | + dictionaryVector.set(1, "banana".getBytes()); |
| 2917 | + dictionaryVector.set(2, "cherry".getBytes()); |
| 2918 | + dictionaryVector.setValueCount(3); |
| 2919 | + |
| 2920 | + Dictionary dictionary = |
| 2921 | + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); |
| 2922 | + |
| 2923 | + FieldType dictionaryField2 = new FieldType(false, new ArrowType.Int(64, true), null); |
| 2924 | + BigIntVector dictionaryVector2 = |
| 2925 | + new BigIntVector(new Field("dictionary2", dictionaryField2, null), allocator); |
| 2926 | + |
| 2927 | + dictionaryVector2.allocateNew(3); |
| 2928 | + dictionaryVector2.set(0, 0L); |
| 2929 | + dictionaryVector2.set(1, 1L); |
| 2930 | + dictionaryVector2.set(2, 2L); |
| 2931 | + dictionaryVector2.setValueCount(3); |
| 2932 | + |
| 2933 | + Dictionary dictionary2 = |
| 2934 | + new Dictionary(dictionaryVector2, new DictionaryEncoding(2L, false, null)); |
| 2935 | + |
| 2936 | + DictionaryProvider dictionaries = |
| 2937 | + new DictionaryProvider.MapDictionaryProvider(dictionary, dictionary2); |
| 2938 | + |
| 2939 | + // Field definition |
| 2940 | + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); |
| 2941 | + VarCharVector stringVector = |
| 2942 | + new VarCharVector(new Field("enumField", stringField, null), allocator); |
| 2943 | + stringVector.allocateNew(10); |
| 2944 | + stringVector.setSafe(0, "passion fruit".getBytes()); |
| 2945 | + stringVector.setSafe(1, "banana".getBytes()); |
| 2946 | + stringVector.setSafe(2, "cherry".getBytes()); |
| 2947 | + stringVector.setSafe(3, "cherry".getBytes()); |
| 2948 | + stringVector.setSafe(4, "passion fruit".getBytes()); |
| 2949 | + stringVector.setSafe(5, "banana".getBytes()); |
| 2950 | + stringVector.setSafe(6, "passion fruit".getBytes()); |
| 2951 | + stringVector.setSafe(7, "cherry".getBytes()); |
| 2952 | + stringVector.setSafe(8, "banana".getBytes()); |
| 2953 | + stringVector.setSafe(9, "passion fruit".getBytes()); |
| 2954 | + stringVector.setValueCount(10); |
| 2955 | + |
| 2956 | + FieldType longField = new FieldType(false, new ArrowType.Int(64, true), null); |
| 2957 | + BigIntVector longVector = new BigIntVector(new Field("enumField2", longField, null), allocator); |
| 2958 | + longVector.allocateNew(10); |
| 2959 | + for (int i = 0; i < 10; i++) { |
| 2960 | + longVector.setSafe(i, (long) i % 3); |
| 2961 | + } |
| 2962 | + longVector.setValueCount(10); |
| 2963 | + |
| 2964 | + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); |
| 2965 | + IntVector encodedVector2 = (IntVector) DictionaryEncoder.encode(longVector, dictionary2); |
| 2966 | + |
| 2967 | + // Set up VSR |
| 2968 | + List<FieldVector> vectors = Arrays.asList(encodedVector, encodedVector2); |
| 2969 | + int rowCount = 10; |
| 2970 | + |
| 2971 | + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { |
| 2972 | + |
| 2973 | + File dataFile = new File(TMP, "testWriteEnumDecodedavro"); |
| 2974 | + |
| 2975 | + // Write an AVRO block using the producer classes |
| 2976 | + try (FileOutputStream fos = new FileOutputStream(dataFile)) { |
| 2977 | + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); |
| 2978 | + CompositeAvroProducer producer = |
| 2979 | + ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); |
| 2980 | + for (int row = 0; row < rowCount; row++) { |
| 2981 | + producer.produce(encoder); |
| 2982 | + } |
| 2983 | + encoder.flush(); |
| 2984 | + } |
| 2985 | + |
| 2986 | + // Set up reading the AVRO block as a GenericRecord |
| 2987 | + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); |
| 2988 | + GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema); |
| 2989 | + |
| 2990 | + try (InputStream inputStream = new FileInputStream(dataFile)) { |
| 2991 | + |
| 2992 | + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); |
| 2993 | + GenericRecord record = null; |
| 2994 | + |
| 2995 | + // Read and check values |
| 2996 | + for (int row = 0; row < rowCount; row++) { |
| 2997 | + record = datumReader.read(record, decoder); |
| 2998 | + assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); |
| 2999 | + assertEquals(longVector.getObject(row), record.get("enumField2")); |
| 3000 | + } |
| 3001 | + } |
| 3002 | + } |
| 3003 | + } |
2820 | 3004 | } |
0 commit comments