|
28 | 28 | import java.time.LocalDate; |
29 | 29 | import java.time.ZonedDateTime; |
30 | 30 | import java.util.Arrays; |
| 31 | +import java.util.HashMap; |
31 | 32 | import java.util.List; |
| 33 | +import java.util.Map; |
32 | 34 | import org.apache.arrow.adapter.avro.producers.CompositeAvroProducer; |
33 | 35 | import org.apache.arrow.memory.BufferAllocator; |
34 | 36 | import org.apache.arrow.memory.RootAllocator; |
|
47 | 49 | import org.apache.arrow.vector.IntVector; |
48 | 50 | import org.apache.arrow.vector.complex.FixedSizeListVector; |
49 | 51 | import org.apache.arrow.vector.complex.ListVector; |
| 52 | +import org.apache.arrow.vector.complex.MapVector; |
50 | 53 | import org.apache.arrow.vector.NullVector; |
51 | 54 | import org.apache.arrow.vector.SmallIntVector; |
52 | 55 | import org.apache.arrow.vector.TimeStampMicroTZVector; |
|
69 | 72 | import org.apache.arrow.vector.VarBinaryVector; |
70 | 73 | import org.apache.arrow.vector.VarCharVector; |
71 | 74 | import org.apache.arrow.vector.VectorSchemaRoot; |
| 75 | +import org.apache.arrow.vector.complex.writer.BaseWriter; |
72 | 76 | import org.apache.arrow.vector.types.DateUnit; |
73 | 77 | import org.apache.arrow.vector.types.FloatingPointPrecision; |
74 | 78 | import org.apache.arrow.vector.types.TimeUnit; |
75 | 79 | import org.apache.arrow.vector.types.pojo.ArrowType; |
76 | 80 | import org.apache.arrow.vector.types.pojo.Field; |
77 | 81 | import org.apache.arrow.vector.types.pojo.FieldType; |
| 82 | +import org.apache.arrow.vector.util.JsonStringArrayList; |
| 83 | +import org.apache.arrow.vector.util.JsonStringHashMap; |
78 | 84 | import org.apache.avro.Conversions; |
79 | 85 | import org.apache.avro.LogicalTypes; |
80 | 86 | import org.apache.avro.Schema; |
|
85 | 91 | import org.apache.avro.io.BinaryEncoder; |
86 | 92 | import org.apache.avro.io.DecoderFactory; |
87 | 93 | import org.apache.avro.io.EncoderFactory; |
| 94 | +import org.apache.avro.util.Utf8; |
88 | 95 | import org.junit.jupiter.api.Test; |
89 | 96 | import org.junit.jupiter.api.io.TempDir; |
90 | 97 |
|
@@ -2052,4 +2059,140 @@ record = datumReader.read(record, decoder); |
2052 | 2059 | } |
2053 | 2060 | } |
2054 | 2061 | } |
| 2062 | + |
| 2063 | + @Test |
| 2064 | + public void testWriteNonNullableMap() throws Exception { |
| 2065 | + |
| 2066 | + // Field definitions |
| 2067 | + FieldType intMapField = new FieldType(false, new ArrowType.Map(false), null); |
| 2068 | + FieldType stringMapField = new FieldType(false, new ArrowType.Map(false), null); |
| 2069 | + FieldType dateMapField = new FieldType(false, new ArrowType.Map(false), null); |
| 2070 | + |
| 2071 | + Field keyField = new Field("key", FieldType.notNullable(new ArrowType.Utf8()), null); |
| 2072 | + Field intField = new Field("value", FieldType.notNullable(new ArrowType.Int(32, true)), null); |
| 2073 | + Field stringField = new Field("value", FieldType.notNullable(new ArrowType.Utf8()), null); |
| 2074 | + Field dateField = new Field("value", FieldType.notNullable(new ArrowType.Date(DateUnit.DAY)), null); |
| 2075 | + |
| 2076 | + Field intEntryField = new Field("entries", FieldType.notNullable(new ArrowType.Struct()), Arrays.asList(keyField, intField)); |
| 2077 | + Field stringEntryField = new Field("entries", FieldType.notNullable(new ArrowType.Struct()), Arrays.asList(keyField, stringField)); |
| 2078 | + Field dateEntryField = new Field("entries", FieldType.notNullable(new ArrowType.Struct()), Arrays.asList(keyField, dateField)); |
| 2079 | + |
| 2080 | + // Create empty vectors |
| 2081 | + BufferAllocator allocator = new RootAllocator(); |
| 2082 | + MapVector intMapVector = new MapVector("intMap", allocator, intMapField, null); |
| 2083 | + MapVector stringMapVector = new MapVector("stringMap", allocator, stringMapField, null); |
| 2084 | + MapVector dateMapVector = new MapVector("dateMap", allocator, dateMapField, null); |
| 2085 | + |
| 2086 | + intMapVector.initializeChildrenFromFields(Arrays.asList(intEntryField)); |
| 2087 | + stringMapVector.initializeChildrenFromFields(Arrays.asList(stringEntryField)); |
| 2088 | + dateMapVector.initializeChildrenFromFields(Arrays.asList(dateEntryField)); |
| 2089 | + |
| 2090 | + // Set up VSR |
| 2091 | + List<FieldVector> vectors = Arrays.asList(intMapVector, stringMapVector, dateMapVector); |
| 2092 | + int rowCount = 3; |
| 2093 | + |
| 2094 | + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { |
| 2095 | + |
| 2096 | + root.setRowCount(rowCount); |
| 2097 | + root.allocateNew(); |
| 2098 | + |
| 2099 | + // Set test data for intList |
| 2100 | + BaseWriter.MapWriter writer = intMapVector.getWriter(); |
| 2101 | + for (int i = 0; i < rowCount; i++) { |
| 2102 | + writer.startMap(); |
| 2103 | + for (int j = 0; j < 5 - i; j++) { |
| 2104 | + writer.startEntry(); |
| 2105 | + writer.key().varChar().writeVarChar("key" + j); |
| 2106 | + writer.value().integer().writeInt(j); |
| 2107 | + writer.endEntry(); |
| 2108 | + } |
| 2109 | + writer.endMap(); |
| 2110 | + } |
| 2111 | + |
| 2112 | + // Set test data for stringList |
| 2113 | + BaseWriter.MapWriter stringWriter = stringMapVector.getWriter(); |
| 2114 | + for (int i = 0; i < rowCount; i++) { |
| 2115 | + stringWriter.startMap(); |
| 2116 | + for (int j = 0; j < 5 - i; j++) { |
| 2117 | + stringWriter.startEntry(); |
| 2118 | + stringWriter.key().varChar().writeVarChar("key" + j); |
| 2119 | + stringWriter.value().varChar().writeVarChar("string" + j); |
| 2120 | + stringWriter.endEntry(); |
| 2121 | + } |
| 2122 | + stringWriter.endMap(); |
| 2123 | + } |
| 2124 | + |
| 2125 | + // Set test data for dateList |
| 2126 | + BaseWriter.MapWriter dateWriter = dateMapVector.getWriter(); |
| 2127 | + for (int i = 0; i < rowCount; i++) { |
| 2128 | + dateWriter.startMap(); |
| 2129 | + for (int j = 0; j < 5 - i; j++) { |
| 2130 | + dateWriter.startEntry(); |
| 2131 | + dateWriter.key().varChar().writeVarChar("key" + j); |
| 2132 | + dateWriter.value().dateDay().writeDateDay((int) LocalDate.now().plusDays(j).toEpochDay()); |
| 2133 | + dateWriter.endEntry(); |
| 2134 | + } |
| 2135 | + dateWriter.endMap(); |
| 2136 | + } |
| 2137 | + |
| 2138 | + File dataFile = new File(TMP, "testWriteNonNullableMap.avro"); |
| 2139 | + |
| 2140 | + // Write an AVRO block using the producer classes |
| 2141 | + try (FileOutputStream fos = new FileOutputStream(dataFile)) { |
| 2142 | + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); |
| 2143 | + CompositeAvroProducer producer = ArrowToAvroUtils.createCompositeProducer(vectors); |
| 2144 | + for (int row = 0; row < rowCount; row++) { |
| 2145 | + producer.produce(encoder); |
| 2146 | + } |
| 2147 | + encoder.flush(); |
| 2148 | + } |
| 2149 | + |
| 2150 | + // Set up reading the AVRO block as a GenericRecord |
| 2151 | + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields()); |
| 2152 | + GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema); |
| 2153 | + |
| 2154 | + try (InputStream inputStream = new FileInputStream(dataFile)) { |
| 2155 | + |
| 2156 | + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); |
| 2157 | + GenericRecord record = null; |
| 2158 | + |
| 2159 | + // Read and check values |
| 2160 | + for (int row = 0; row < rowCount; row++) { |
| 2161 | + record = datumReader.read(record, decoder); |
| 2162 | + Map<String, Object> intMap = convertMap(intMapVector.getObject(row)); |
| 2163 | + Map<String, Object> stringMap = convertMap(stringMapVector.getObject(row)); |
| 2164 | + Map<String, Object> dateMap = convertMap(dateMapVector.getObject(row)); |
| 2165 | + compareMaps(intMap, (Map) record.get("intMap")); |
| 2166 | + compareMaps(stringMap, (Map) record.get("stringMap")); |
| 2167 | + compareMaps(dateMap, (Map) record.get("dateMap")); |
| 2168 | + } |
| 2169 | + } |
| 2170 | + } |
| 2171 | + } |
| 2172 | + |
| 2173 | + private Map<String, Object> convertMap(List<?> entryList) { |
| 2174 | + |
| 2175 | + Map<String, Object> map = new HashMap<>(); |
| 2176 | + JsonStringArrayList<?> structList = (JsonStringArrayList<?>) entryList; |
| 2177 | + for (Object entry : structList) { |
| 2178 | + JsonStringHashMap<String, ?> structEntry = (JsonStringHashMap<String, ?>) entry; |
| 2179 | + String key = structEntry.get(MapVector.KEY_NAME).toString(); |
| 2180 | + Object value = structEntry.get(MapVector.VALUE_NAME); |
| 2181 | + map.put(key, value); |
| 2182 | + } |
| 2183 | + return map; |
| 2184 | + } |
| 2185 | + |
| 2186 | + private void compareMaps(Map<String, ?> expected, Map<?, ?> actual) { |
| 2187 | + assertEquals(expected.size(), actual.size()); |
| 2188 | + for (Object key : actual.keySet()) { |
| 2189 | + assertTrue(expected.containsKey(key.toString())); |
| 2190 | + Object actualValue = actual.get(key); |
| 2191 | + if (actualValue instanceof Utf8) { |
| 2192 | + assertEquals(expected.get(key.toString()).toString(), actualValue.toString()); |
| 2193 | + } else { |
| 2194 | + assertEquals(expected.get(key.toString()), actual.get(key)); |
| 2195 | + } |
| 2196 | + } |
| 2197 | + } |
2055 | 2198 | } |
0 commit comments