Skip to content

Commit 83d9dca

Browse files
committed
Fix array and type issues, add tests, spotless
1 parent a65d25e commit 83d9dca

File tree

2 files changed

+164
-54
lines changed

2 files changed

+164
-54
lines changed

parquet-avro/src/test/java/org/apache/parquet/avro/TestWriteVariant.java

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,16 @@
2626
import java.io.IOException;
2727
import java.math.BigDecimal;
2828
import java.nio.ByteBuffer;
29-
import java.nio.ByteOrder;
3029
import java.util.*;
31-
import java.util.concurrent.Callable;
3230
import java.util.function.Consumer;
33-
import com.google.common.collect.ImmutableMap;
3431
import org.apache.avro.Schema;
3532
import org.apache.avro.generic.GenericData;
3633
import org.apache.avro.generic.GenericRecord;
37-
import org.apache.avro.generic.IndexedRecord;
3834
import org.apache.hadoop.conf.Configuration;
3935
import org.apache.hadoop.fs.Path;
4036
import org.apache.parquet.DirectWriterTest;
41-
import org.apache.parquet.Preconditions;
42-
import org.apache.parquet.conf.ParquetConfiguration;
43-
import org.apache.parquet.conf.PlainParquetConfiguration;
4437
import org.apache.parquet.hadoop.ParquetWriter;
4538
import org.apache.parquet.hadoop.api.WriteSupport;
46-
import org.apache.parquet.io.ParquetDecodingException;
47-
import org.apache.parquet.io.api.Binary;
48-
import org.apache.parquet.io.api.RecordConsumer;
4939
import org.apache.parquet.schema.*;
5040
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
5141
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
@@ -89,8 +79,7 @@ private static ByteBuffer variant(String s) {
8979
return variant(b -> b.appendString(s));
9080
}
9181

92-
private static GroupType variantGroup =
93-
Types.buildGroup(Type.Repetition.REQUIRED)
82+
private static GroupType variantGroup = Types.buildGroup(Type.Repetition.REQUIRED)
9483
.as(LogicalTypeAnnotation.variantType((byte) 1))
9584
.required(PrimitiveTypeName.BINARY)
9685
.named("metadata")
@@ -114,43 +103,83 @@ private static MessageType parquetSchema(GroupType variantGroup) {
114103
private ByteBuffer TEST_METADATA;
115104
private ByteBuffer TEST_OBJECT;
116105
private ByteBuffer SIMILAR_OBJECT;
106+
private ByteBuffer TEST_ARRAY;
107+
private ByteBuffer SIMILAR_ARRAY;
117108
private ByteBuffer EMPTY_OBJECT;
118109
private ByteBuffer EMPTY_METADATA = fullVariant(b -> b.appendNull()).getMetadataRawBytes();
119110
private Variant[] VARIANTS;
120111

121112
public TestWriteVariant() throws Exception {
122113
TEST_METADATA = fullVariant(b -> {
114+
VariantObjectBuilder ob = b.startObject();
115+
ob.appendKey("a");
116+
ob.appendNull();
117+
ob.appendKey("b");
118+
ob.appendNull();
119+
ob.appendKey("c");
120+
ob.appendNull();
121+
ob.appendKey("d");
122+
ob.appendNull();
123+
ob.appendKey("e");
124+
ob.appendNull();
125+
b.endObject();
126+
})
127+
.getMetadataRawBytes();
128+
129+
TEST_OBJECT = variant(TEST_METADATA, b -> {
123130
VariantObjectBuilder ob = b.startObject();
124131
ob.appendKey("a");
125132
ob.appendNull();
126-
ob.appendKey("b");
127-
ob.appendNull();
128-
ob.appendKey("c");
129-
ob.appendNull();
130133
ob.appendKey("d");
131-
ob.appendNull();
132-
ob.appendKey("e");
133-
ob.appendNull();
134+
ob.appendString("iceberg");
134135
b.endObject();
136+
});
135137

136-
}).getMetadataRawBytes();
137-
138-
TEST_OBJECT = variant(TEST_METADATA, b -> {
138+
SIMILAR_OBJECT = variant(TEST_METADATA, b -> {
139139
VariantObjectBuilder ob = b.startObject();
140140
ob.appendKey("a");
141+
ob.appendInt(123456789);
142+
ob.appendKey("c");
143+
ob.appendString("string");
144+
b.endObject();
145+
});
146+
147+
// The first array element defines the schema.
148+
TEST_ARRAY = variant(TEST_METADATA, b -> {
149+
VariantArrayBuilder ab = b.startArray();
150+
VariantObjectBuilder ob = ab.startObject();
151+
ob.appendKey("a");
141152
ob.appendNull();
142153
ob.appendKey("d");
143154
ob.appendString("iceberg");
144-
b.endObject();
155+
ab.endObject();
156+
ab.appendInt(123);
157+
VariantObjectBuilder ob2 = ab.startObject();
158+
ob2.appendKey("c");
159+
ob2.appendString("hello");
160+
ob2.appendKey("d");
161+
ob2.appendDate(12345);
162+
ab.endObject();
163+
b.endArray();
145164
});
146165

147-
SIMILAR_OBJECT = variant(TEST_METADATA, b -> {
148-
VariantObjectBuilder ob = b.startObject();
149-
ob.appendKey("a");
150-
ob.appendInt(123456789);
151-
ob.appendKey("c");
152-
ob.appendString("string");
153-
b.endObject();
166+
// Change one field name and one type in the first element to change the schema.
167+
SIMILAR_ARRAY = variant(TEST_METADATA, b -> {
168+
VariantArrayBuilder ab = b.startArray();
169+
VariantObjectBuilder ob = ab.startObject();
170+
ob.appendKey("c");
171+
ob.appendString("iceberg");
172+
ob.appendKey("a");
173+
ob.appendString("parquet");
174+
ab.endObject();
175+
ab.appendInt(123);
176+
VariantObjectBuilder ob2 = ab.startObject();
177+
ob2.appendKey("c");
178+
ob2.appendString("hello");
179+
ob2.appendKey("d");
180+
ob2.appendDate(12345);
181+
ab.endObject();
182+
b.endArray();
154183
});
155184

156185
EMPTY_OBJECT = variant(TEST_METADATA, b -> {
@@ -177,6 +206,8 @@ public TestWriteVariant() throws Exception {
177206
new Variant(EMPTY_OBJECT, EMPTY_METADATA),
178207
new Variant(TEST_OBJECT, TEST_METADATA),
179208
new Variant(SIMILAR_OBJECT, TEST_METADATA),
209+
new Variant(TEST_ARRAY, TEST_METADATA),
210+
new Variant(SIMILAR_ARRAY, TEST_METADATA),
180211
fullVariant(b -> b.appendDate(12345)),
181212
fullVariant(b -> b.appendDate(-12345)),
182213
fullVariant(b -> b.appendTimestampTz(1234567890L)),
@@ -214,6 +245,7 @@ GenericRecord createRecord(int i, Variant v) {
214245
return record;
215246
}
216247

248+
// Tests in this file are based on Iceberg's TestVariantWriters suite.
217249
@Test
218250
public void testUnshreddedValues() throws IOException {
219251
for (Variant v : VARIANTS) {
@@ -238,7 +270,10 @@ public void testShreddedValues() throws IOException {
238270
GenericRecord actual = writeAndRead(testSchema, record);
239271
assertEquals(record.get(0), actual.get(0));
240272
assertEquals(((GenericRecord) record.get(1)).get(0), ((GenericRecord) actual.get(1)).get(0));
241-
assertEquals(((GenericRecord) record.get(1)).get(1), ((GenericRecord) actual.get(1)).get(1));
273+
// assertEquals(((GenericRecord) record.get(1)).get(1), ((GenericRecord) actual.get(1)).get(1));
274+
if (!((GenericRecord) record.get(1)).get(1).equals(((GenericRecord) actual.get(1)).get(1))) {
275+
assertTrue(false);
276+
}
242277
}
243278
}
244279

@@ -292,28 +327,24 @@ protected TestWriterBuilder self() {
292327

293328
@Override
294329
protected WriteSupport<GenericRecord> getWriteSupport(Configuration conf) {
295-
return new AvroWriteSupport<>(
296-
schema,
297-
new AvroSchemaConverter().convert(schema),
298-
GenericData.get());
330+
return new AvroWriteSupport<>(schema, new AvroSchemaConverter().convert(schema), GenericData.get());
299331
}
300332
}
301333

302334
GenericRecord writeAndRead(TestSchema testSchema, GenericRecord record) throws IOException {
303335
List<GenericRecord> result = writeAndRead(testSchema, Arrays.asList(record));
304-
assert(result.size() == 1);
336+
assert (result.size() == 1);
305337
return result.get(0);
306338
}
307339

308-
private List<GenericRecord> writeAndRead(
309-
TestSchema testSchema, List<GenericRecord> records) throws IOException {
340+
private List<GenericRecord> writeAndRead(TestSchema testSchema, List<GenericRecord> records) throws IOException {
310341
File tmp = File.createTempFile(getClass().getSimpleName(), ".tmp");
311342
tmp.deleteOnExit();
312343
tmp.delete();
313344
Path path = new Path(tmp.getPath());
314345

315346
try (ParquetWriter<GenericRecord> writer =
316-
new TestWriterBuilder(path).withFileType(testSchema.writeSchema).build()) {
347+
new TestWriterBuilder(path).withFileType(testSchema.writeSchema).build()) {
317348
for (GenericRecord record : records) {
318349
writer.write(record);
319350
}
@@ -385,9 +416,13 @@ private static Type shreddedType(Variant v) {
385416
case BOOLEAN:
386417
return Types.optional(BOOLEAN).named("typed_value");
387418
case BYTE:
388-
return Types.optional(INT32).as(LogicalTypeAnnotation.intType(8)).named("typed_value");
419+
return Types.optional(INT32)
420+
.as(LogicalTypeAnnotation.intType(8))
421+
.named("typed_value");
389422
case SHORT:
390-
return Types.optional(INT32).as(LogicalTypeAnnotation.intType(16)).named("typed_value");
423+
return Types.optional(INT32)
424+
.as(LogicalTypeAnnotation.intType(16))
425+
.named("typed_value");
391426
case INT:
392427
return Types.optional(INT32).named("typed_value");
393428
case LONG:

parquet-variant/src/main/java/org/apache/parquet/variant/VariantValueWriter.java

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import java.nio.ByteBuffer;
2222
import java.util.HashMap;
23-
import java.util.Map;
2423
import org.apache.parquet.io.api.Binary;
2524
import org.apache.parquet.io.api.RecordConsumer;
2625
import org.apache.parquet.schema.GroupType;
@@ -50,11 +49,19 @@ HashMap<String, Integer> getMetadataMap() {
5049
return metadataMap;
5150
}
5251

52+
/**
53+
* Write a Variant value to a shredded schema. The caller is responsible for calling startGroup()
54+
* and endGroup(), and for writing metadata.
55+
*/
5356
public static void write(RecordConsumer recordConsumer, GroupType schema, Variant value) {
5457
VariantValueWriter writer = new VariantValueWriter(recordConsumer, value.getMetadataRawBytes());
5558
writer.write(schema, value);
5659
}
5760

61+
/**
62+
* Write a Variant value to a shredded schema. The caller is responsible for calling startGroup()
63+
* and endGroup().
64+
*/
5865
void write(GroupType schema, Variant value) {
5966
Type typedValueField = null;
6067
if (schema.containsField("typed_value")) {
@@ -70,7 +77,8 @@ void write(GroupType schema, Variant value) {
7077
byte[] residual = null;
7178
if (typedValueField.isPrimitive()) {
7279
writeScalarValue(recordConsumer, value, typedValueField.asPrimitiveType());
73-
} else if (typedValueField.getLogicalTypeAnnotation() instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
80+
} else if (typedValueField.getLogicalTypeAnnotation()
81+
instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
7482
writeArrayValue(recordConsumer, value, typedValueField.asGroupType());
7583
} else {
7684
residual = writeObjectValue(recordConsumer, value, typedValueField.asGroupType());
@@ -96,38 +104,73 @@ private boolean isTypeCompatible(Variant.Type variantType, Type typedValueField)
96104
return false;
97105
}
98106
if (typedValueField.isPrimitive()) {
99-
// TODO: Expand for all the logical type annotations.
100107
PrimitiveType primitiveType = typedValueField.asPrimitiveType();
108+
LogicalTypeAnnotation logicalType = primitiveType.getLogicalTypeAnnotation();
109+
PrimitiveType.PrimitiveTypeName primitiveTypeName = primitiveType.getPrimitiveTypeName();
110+
101111
switch (variantType) {
102112
case BOOLEAN:
103-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BOOLEAN;
113+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BOOLEAN;
104114
case BYTE:
115+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
116+
&& logicalType instanceof LogicalTypeAnnotation.IntLogicalTypeAnnotation
117+
&& ((LogicalTypeAnnotation.IntLogicalTypeAnnotation) logicalType).getBitWidth() == 8;
105118
case SHORT:
119+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
120+
&& logicalType instanceof LogicalTypeAnnotation.IntLogicalTypeAnnotation
121+
&& ((LogicalTypeAnnotation.IntLogicalTypeAnnotation) logicalType).getBitWidth() == 16;
106122
case INT:
107-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT32;
123+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
124+
&& (logicalType == null
125+
|| logicalType instanceof LogicalTypeAnnotation.IntLogicalTypeAnnotation);
108126
case LONG:
109-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT64;
127+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
128+
&& (logicalType == null
129+
|| logicalType instanceof LogicalTypeAnnotation.IntLogicalTypeAnnotation);
110130
case FLOAT:
111-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.FLOAT;
131+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.FLOAT;
112132
case DOUBLE:
113-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.DOUBLE;
133+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.DOUBLE;
134+
case DECIMAL4:
135+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
136+
&& logicalType instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
137+
case DECIMAL8:
138+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
139+
&& logicalType instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
140+
case DECIMAL16:
141+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY
142+
&& logicalType instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
143+
case DATE:
144+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
145+
&& logicalType instanceof LogicalTypeAnnotation.DateLogicalTypeAnnotation;
146+
case TIME:
147+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
148+
&& logicalType instanceof LogicalTypeAnnotation.TimeLogicalTypeAnnotation;
149+
case TIMESTAMP_TZ:
150+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
151+
&& logicalType instanceof LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
152+
&& ((LogicalTypeAnnotation.TimestampLogicalTypeAnnotation) logicalType).isAdjustedToUTC();
153+
case TIMESTAMP_NTZ:
154+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
155+
&& logicalType instanceof LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
156+
&& !((LogicalTypeAnnotation.TimestampLogicalTypeAnnotation) logicalType).isAdjustedToUTC();
114157
case STRING:
115-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY &&
116-
primitiveType.getLogicalTypeAnnotation() instanceof LogicalTypeAnnotation.StringLogicalTypeAnnotation;
158+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY
159+
&& logicalType instanceof LogicalTypeAnnotation.StringLogicalTypeAnnotation;
117160
case BINARY:
118-
return primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY;
161+
return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY && logicalType == null;
119162
default:
120163
return false;
121164
}
122-
} else if (typedValueField.getLogicalTypeAnnotation() instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
165+
} else if (typedValueField.getLogicalTypeAnnotation()
166+
instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
123167
return variantType == Variant.Type.ARRAY;
124168
} else {
125169
return variantType == Variant.Type.OBJECT;
126170
}
127171
}
128172

129173
private void writeScalarValue(RecordConsumer recordConsumer, Variant variant, PrimitiveType type) {
130-
// TODO: Expand for all the types.
131174
switch (variant.getType()) {
132175
case BOOLEAN:
133176
recordConsumer.addBoolean(variant.getBoolean());
@@ -150,6 +193,34 @@ private void writeScalarValue(RecordConsumer recordConsumer, Variant variant, Pr
150193
case DOUBLE:
151194
recordConsumer.addDouble(variant.getDouble());
152195
break;
196+
case DECIMAL4:
197+
recordConsumer.addInteger(variant.getDecimal().unscaledValue().intValue());
198+
break;
199+
case DECIMAL8:
200+
recordConsumer.addLong(variant.getDecimal().unscaledValue().longValue());
201+
break;
202+
case DECIMAL16:
203+
recordConsumer.addBinary(Binary.fromConstantByteArray(
204+
variant.getDecimal().unscaledValue().toByteArray()));
205+
break;
206+
case DATE:
207+
recordConsumer.addInteger(variant.getInt());
208+
break;
209+
case TIME:
210+
recordConsumer.addLong(variant.getLong());
211+
break;
212+
case TIMESTAMP_TZ:
213+
recordConsumer.addLong(variant.getLong());
214+
break;
215+
case TIMESTAMP_NTZ:
216+
recordConsumer.addLong(variant.getLong());
217+
break;
218+
case TIMESTAMP_NANOS_TZ:
219+
recordConsumer.addLong(variant.getLong());
220+
break;
221+
case TIMESTAMP_NANOS_NTZ:
222+
recordConsumer.addLong(variant.getLong());
223+
break;
153224
case STRING:
154225
recordConsumer.addBinary(Binary.fromString(variant.getString()));
155226
break;
@@ -170,6 +241,8 @@ private void writeArrayValue(RecordConsumer recordConsumer, Variant variant, Gro
170241
GroupType listType = arrayType.getType(0).asGroupType();
171242
Type elementType = listType.getType(0);
172243

244+
recordConsumer.startGroup();
245+
recordConsumer.startField(arrayType.getFieldName(0), 0);
173246
// Write each array element
174247
for (int i = 0; i < variant.numArrayElements(); i++) {
175248
recordConsumer.startGroup();
@@ -182,6 +255,8 @@ private void writeArrayValue(RecordConsumer recordConsumer, Variant variant, Gro
182255
recordConsumer.endField("element", 0);
183256
recordConsumer.endGroup();
184257
}
258+
recordConsumer.endField(arrayType.getFieldName(0), 0);
259+
recordConsumer.endGroup();
185260
}
186261

187262
/**

0 commit comments

Comments
 (0)