Skip to content

Commit a65d25e

Browse files
committed
Fix a bunch of issues, still WIP
1 parent 7c4c3ad commit a65d25e

File tree

4 files changed

+259
-50
lines changed

4 files changed

+259
-50
lines changed

parquet-avro/src/test/java/org/apache/parquet/avro/TestWriteVariant.java

Lines changed: 163 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.parquet.avro;
2020

2121
import static org.apache.parquet.avro.AvroTestUtil.*;
22+
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*;
2223
import static org.junit.Assert.*;
2324

2425
import java.io.File;
@@ -99,7 +100,7 @@ private static ByteBuffer variant(String s) {
99100

100101
private static MessageType parquetSchema(GroupType variantGroup) {
101102
return Types.buildMessage()
102-
.required(PrimitiveTypeName.INT32)
103+
.required(INT32)
103104
.named("id")
104105
.addField(variantGroup)
105106
.named("table");
@@ -231,7 +232,7 @@ public void testUnshreddedValues() throws IOException {
231232
public void testShreddedValues() throws IOException {
232233
for (Variant v : VARIANTS) {
233234
GenericRecord record = createRecord(1, v);
234-
MessageType writeSchema = shreddingFromValue(v);
235+
MessageType writeSchema = shreddingSchema(v);
235236
TestSchema testSchema = new TestSchema(writeSchema, readSchema);
236237

237238
GenericRecord actual = writeAndRead(testSchema, record);
@@ -241,24 +242,21 @@ public void testShreddedValues() throws IOException {
241242
}
242243
}
243244

244-
// @Test
245-
// public void testMixedShredding() throws IOException {
246-
// for (Variant v : VARIANTS) {
247-
// List<Record> expected =
248-
// IntStream.range(0, VARIANTS.length)
249-
// .mapToObj(i -> RECORD.copy("id", i, "var", VARIANTS[i]))
250-
// .collect(Collectors.toList());
251-
//
252-
// List<Record> actual =
253-
// writeAndRead((id, name) -> ParquetVariantUtil.toParquetSchema(v.value()), expected);
254-
//
255-
// assertThat(actual.size()).isEqualTo(expected.size());
256-
//
257-
// for (int i = 0; i < expected.size(); i += 1) {
258-
// InternalTestHelpers.assertEquals(SCHEMA.asStruct(), expected.get(i), actual.get(i));
259-
// }
260-
// }
261-
// }
245+
@Test
246+
public void testMixedShredding() throws IOException {
247+
for (Variant v : VARIANTS) {
248+
List<GenericRecord> expected = new ArrayList<>();
249+
for (int i = 0; i < VARIANTS.length; i++) {
250+
expected.add(createRecord(i, VARIANTS[i]));
251+
}
252+
253+
MessageType writeSchema = shreddingSchema(v);
254+
TestSchema testSchema = new TestSchema(writeSchema, readSchema);
255+
256+
List<GenericRecord> actual = writeAndRead(testSchema, expected);
257+
// TODO: CHECK RESULTS
258+
}
259+
}
262260

263261
// Write schema contains the full shredding schema. Read schema should just be a value/metadata pair.
264262
private static class TestSchema {
@@ -336,4 +334,149 @@ private List<GenericRecord> writeAndRead(
336334
}
337335
return result;
338336
}
337+
338+
/**
339+
* Build a shredding schema that will perfectly shred the provided value.
340+
*/
341+
private static MessageType shreddingSchema(Variant v) {
342+
Type shreddedType = shreddedType(v);
343+
Types.GroupBuilder<GroupType> partialType = Types.buildGroup(Type.Repetition.OPTIONAL)
344+
.as(LogicalTypeAnnotation.variantType((byte) 1))
345+
.required(BINARY)
346+
.named("metadata")
347+
.optional(BINARY)
348+
.named("value");
349+
Type variantType;
350+
if (shreddedType == null) {
351+
variantType = partialType.named("var");
352+
} else {
353+
variantType = partialType.addField(shreddedType).named("var");
354+
}
355+
return Types.buildMessage()
356+
.required(INT32)
357+
.named("id")
358+
.addField(variantType)
359+
.named("table");
360+
}
361+
362+
private static GroupType shreddedGroup(Variant v, String name) {
363+
Type shreddedType = shreddedType(v);
364+
if (shreddedType == null) {
365+
return Types.buildGroup(Type.Repetition.OPTIONAL)
366+
.optional(BINARY)
367+
.named("value")
368+
.named(name);
369+
} else {
370+
return Types.buildGroup(Type.Repetition.OPTIONAL)
371+
.optional(BINARY)
372+
.named("value")
373+
.addField(shreddedType)
374+
.named(name);
375+
}
376+
}
377+
378+
/**
379+
* @return A shredded type, or null if there is no valid shredded type.
380+
*/
381+
private static Type shreddedType(Variant v) {
382+
switch (v.getType()) {
383+
case NULL:
384+
return null;
385+
case BOOLEAN:
386+
return Types.optional(BOOLEAN).named("typed_value");
387+
case BYTE:
388+
return Types.optional(INT32).as(LogicalTypeAnnotation.intType(8)).named("typed_value");
389+
case SHORT:
390+
return Types.optional(INT32).as(LogicalTypeAnnotation.intType(16)).named("typed_value");
391+
case INT:
392+
return Types.optional(INT32).named("typed_value");
393+
case LONG:
394+
return Types.optional(INT64).named("typed_value");
395+
case FLOAT:
396+
return Types.optional(FLOAT).named("typed_value");
397+
case DOUBLE:
398+
return Types.optional(DOUBLE).named("typed_value");
399+
case DECIMAL4:
400+
return Types.optional(INT32)
401+
.as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 9))
402+
.named("typed_value");
403+
case DECIMAL8:
404+
return Types.optional(INT64)
405+
.as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 18))
406+
.named("typed_value");
407+
case DECIMAL16:
408+
return Types.optional(BINARY)
409+
.as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 38))
410+
.named("typed_value");
411+
case DATE:
412+
return Types.optional(INT32)
413+
.as(LogicalTypeAnnotation.dateType())
414+
.named("typed_value");
415+
case TIMESTAMP_TZ:
416+
return Types.optional(INT64)
417+
.as(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS))
418+
.named("typed_value");
419+
case TIMESTAMP_NTZ:
420+
return Types.optional(INT64)
421+
.as(LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS))
422+
.named("typed_value");
423+
case BINARY:
424+
return Types.optional(BINARY).named("typed_value");
425+
case STRING:
426+
return Types.optional(BINARY)
427+
.as(LogicalTypeAnnotation.stringType())
428+
.named("typed_value");
429+
case TIME:
430+
return Types.optional(INT64)
431+
.as(LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS))
432+
.named("typed_value");
433+
case TIMESTAMP_NANOS_TZ:
434+
return Types.optional(INT64)
435+
.as(LogicalTypeAnnotation.timestampType(true, TimeUnit.NANOS))
436+
.named("typed_value");
437+
case TIMESTAMP_NANOS_NTZ:
438+
return Types.optional(INT64)
439+
.as(LogicalTypeAnnotation.timestampType(false, TimeUnit.NANOS))
440+
.named("typed_value");
441+
case UUID:
442+
return Types.optional(FIXED_LEN_BYTE_ARRAY)
443+
.as(LogicalTypeAnnotation.uuidType())
444+
.named("typed_value");
445+
case OBJECT:
446+
return shreddedObjectType(v);
447+
case ARRAY:
448+
return shreddedArrayType(v);
449+
default:
450+
throw new UnsupportedOperationException("Unsupported shredding type: " + v.getType());
451+
}
452+
}
453+
454+
private static Type shreddedObjectType(Variant v) {
455+
if (v.numObjectElements() == 0) {
456+
// Parquet can't represent empty groups.
457+
return null;
458+
}
459+
Types.GroupBuilder<GroupType> builder = Types.optionalGroup();
460+
for (int i = 0; i < v.numObjectElements(); i++) {
461+
Variant.ObjectField field = v.getFieldAtIndex(i);
462+
Types.GroupBuilder<GroupType> fieldBuilder = Types.optionalGroup();
463+
Type fieldType = shreddedGroup(field.value, field.key);
464+
builder.addField(fieldType);
465+
}
466+
return builder.named("typed_value");
467+
}
468+
469+
private static Type shreddedArrayType(Variant v) {
470+
// Use the first element to determine the array element type
471+
Variant firstElement;
472+
if (v.numArrayElements() > 0) {
473+
firstElement = v.getElementAtIndex(0);
474+
} else {
475+
// Use null as a dummy value, which will omit typed_value from the schema.
476+
firstElement = fullVariant(b -> b.appendNull());
477+
}
478+
479+
Type elementType = shreddedGroup(firstElement, "element");
480+
return Types.optionalList().setElementType(elementType).named("typed_value");
481+
}
339482
}

parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ public byte[] valueWithoutMetadata() {
151151
*/
152152
void shallowAppendVariant(Binary value) {
153153
onAppend();
154-
int size = value.length();
155-
checkCapacity(size);
156154
byte[] buf = value.getBytes();
155+
int size = VariantUtil.valueSize(ByteBuffer.wrap(buf), 0);
156+
checkCapacity(size);
157157
System.arraycopy(buf, 0, writeBuffer, writePos, size);
158158
writePos += size;
159159
}

parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,4 +843,73 @@ static HashMap<String, Integer> getMetadataMap(ByteBuffer metadata) {
843843
}
844844
return result;
845845
}
846+
847+
/**
848+
* Computes the actual size (in bytes) of the Variant value at `value[pos...]`
849+
* @param value The Variant value
850+
* @param pos The starting index of the Variant value
851+
* @return The size (in bytes) of the Variant value
852+
*/
853+
public static int valueSize(ByteBuffer value, int pos) {
854+
checkIndex(pos, value.limit());
855+
int basicType = value.get(pos) & BASIC_TYPE_MASK;
856+
int typeInfo = (value.get(pos) >> BASIC_TYPE_BITS) & PRIMITIVE_TYPE_MASK;
857+
switch (basicType) {
858+
case SHORT_STR:
859+
return 1 + typeInfo;
860+
case OBJECT: {
861+
VariantUtil.ObjectInfo info = VariantUtil.getObjectInfo(slice(value, pos));
862+
return info.dataStartOffset
863+
+ readUnsigned(
864+
value,
865+
pos + info.offsetStartOffset + info.numElements * info.offsetSize,
866+
info.offsetSize);
867+
}
868+
case ARRAY: {
869+
VariantUtil.ArrayInfo info = VariantUtil.getArrayInfo(slice(value, pos));
870+
return info.dataStartOffset
871+
+ readUnsigned(
872+
value,
873+
pos + info.offsetStartOffset + info.numElements * info.offsetSize,
874+
info.offsetSize);
875+
}
876+
default:
877+
switch (typeInfo) {
878+
case NULL:
879+
case TRUE:
880+
case FALSE:
881+
return 1;
882+
case INT8:
883+
return 2;
884+
case INT16:
885+
return 3;
886+
case INT32:
887+
case DATE:
888+
case FLOAT:
889+
return 5;
890+
case INT64:
891+
case DOUBLE:
892+
case TIMESTAMP_TZ:
893+
case TIMESTAMP_NTZ:
894+
case TIME:
895+
case TIMESTAMP_NANOS_TZ:
896+
case TIMESTAMP_NANOS_NTZ:
897+
return 9;
898+
case DECIMAL4:
899+
return 6;
900+
case DECIMAL8:
901+
return 10;
902+
case DECIMAL16:
903+
return 18;
904+
case BINARY:
905+
case LONG_STR:
906+
return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
907+
case UUID:
908+
return 1 + UUID_SIZE;
909+
default:
910+
throw new UnsupportedOperationException(
911+
String.format("Unknown type in Variant. primitive type: %d", typeInfo));
912+
}
913+
}
914+
}
846915
}

0 commit comments

Comments
 (0)