1919package org .apache .parquet .avro ;
2020
2121import static org .apache .parquet .avro .AvroTestUtil .*;
22+ import static org .apache .parquet .schema .PrimitiveType .PrimitiveTypeName .*;
2223import static org .junit .Assert .*;
2324
2425import java .io .File ;
@@ -99,7 +100,7 @@ private static ByteBuffer variant(String s) {
99100
100101 private static MessageType parquetSchema (GroupType variantGroup ) {
101102 return Types .buildMessage ()
102- .required (PrimitiveTypeName . INT32 )
103+ .required (INT32 )
103104 .named ("id" )
104105 .addField (variantGroup )
105106 .named ("table" );
@@ -231,7 +232,7 @@ public void testUnshreddedValues() throws IOException {
231232 public void testShreddedValues () throws IOException {
232233 for (Variant v : VARIANTS ) {
233234 GenericRecord record = createRecord (1 , v );
234- MessageType writeSchema = shreddingFromValue (v );
235+ MessageType writeSchema = shreddingSchema (v );
235236 TestSchema testSchema = new TestSchema (writeSchema , readSchema );
236237
237238 GenericRecord actual = writeAndRead (testSchema , record );
@@ -241,24 +242,21 @@ public void testShreddedValues() throws IOException {
241242 }
242243 }
243244
244- // @Test
245- // public void testMixedShredding() throws IOException {
246- // for (Variant v : VARIANTS) {
247- // List<Record> expected =
248- // IntStream.range(0, VARIANTS.length)
249- // .mapToObj(i -> RECORD.copy("id", i, "var", VARIANTS[i]))
250- // .collect(Collectors.toList());
251- //
252- // List<Record> actual =
253- // writeAndRead((id, name) -> ParquetVariantUtil.toParquetSchema(v.value()), expected);
254- //
255- // assertThat(actual.size()).isEqualTo(expected.size());
256- //
257- // for (int i = 0; i < expected.size(); i += 1) {
258- // InternalTestHelpers.assertEquals(SCHEMA.asStruct(), expected.get(i), actual.get(i));
259- // }
260- // }
261- // }
245+ @ Test
246+ public void testMixedShredding () throws IOException {
247+ for (Variant v : VARIANTS ) {
248+ List <GenericRecord > expected = new ArrayList <>();
249+ for (int i = 0 ; i < VARIANTS .length ; i ++) {
250+ expected .add (createRecord (i , VARIANTS [i ]));
251+ }
252+
253+ MessageType writeSchema = shreddingSchema (v );
254+ TestSchema testSchema = new TestSchema (writeSchema , readSchema );
255+
256+ List <GenericRecord > actual = writeAndRead (testSchema , expected );
257+ // TODO: CHECK RESULTS
258+ }
259+ }
262260
263261 // Write schema contains the full shredding schema. Read schema should just be a value/metadata pair.
264262 private static class TestSchema {
@@ -336,4 +334,149 @@ private List<GenericRecord> writeAndRead(
336334 }
337335 return result ;
338336 }
337+
338+ /**
339+ * Build a shredding schema that will perfectly shred the provided value.
340+ */
341+ private static MessageType shreddingSchema (Variant v ) {
342+ Type shreddedType = shreddedType (v );
343+ Types .GroupBuilder <GroupType > partialType = Types .buildGroup (Type .Repetition .OPTIONAL )
344+ .as (LogicalTypeAnnotation .variantType ((byte ) 1 ))
345+ .required (BINARY )
346+ .named ("metadata" )
347+ .optional (BINARY )
348+ .named ("value" );
349+ Type variantType ;
350+ if (shreddedType == null ) {
351+ variantType = partialType .named ("var" );
352+ } else {
353+ variantType = partialType .addField (shreddedType ).named ("var" );
354+ }
355+ return Types .buildMessage ()
356+ .required (INT32 )
357+ .named ("id" )
358+ .addField (variantType )
359+ .named ("table" );
360+ }
361+
362+ private static GroupType shreddedGroup (Variant v , String name ) {
363+ Type shreddedType = shreddedType (v );
364+ if (shreddedType == null ) {
365+ return Types .buildGroup (Type .Repetition .OPTIONAL )
366+ .optional (BINARY )
367+ .named ("value" )
368+ .named (name );
369+ } else {
370+ return Types .buildGroup (Type .Repetition .OPTIONAL )
371+ .optional (BINARY )
372+ .named ("value" )
373+ .addField (shreddedType )
374+ .named (name );
375+ }
376+ }
377+
378+ /**
379+ * @return A shredded type, or null if there is no valid shredded type.
380+ */
381+ private static Type shreddedType (Variant v ) {
382+ switch (v .getType ()) {
383+ case NULL :
384+ return null ;
385+ case BOOLEAN :
386+ return Types .optional (BOOLEAN ).named ("typed_value" );
387+ case BYTE :
388+ return Types .optional (INT32 ).as (LogicalTypeAnnotation .intType (8 )).named ("typed_value" );
389+ case SHORT :
390+ return Types .optional (INT32 ).as (LogicalTypeAnnotation .intType (16 )).named ("typed_value" );
391+ case INT :
392+ return Types .optional (INT32 ).named ("typed_value" );
393+ case LONG :
394+ return Types .optional (INT64 ).named ("typed_value" );
395+ case FLOAT :
396+ return Types .optional (FLOAT ).named ("typed_value" );
397+ case DOUBLE :
398+ return Types .optional (DOUBLE ).named ("typed_value" );
399+ case DECIMAL4 :
400+ return Types .optional (INT32 )
401+ .as (LogicalTypeAnnotation .decimalType (v .getDecimal ().scale (), 9 ))
402+ .named ("typed_value" );
403+ case DECIMAL8 :
404+ return Types .optional (INT64 )
405+ .as (LogicalTypeAnnotation .decimalType (v .getDecimal ().scale (), 18 ))
406+ .named ("typed_value" );
407+ case DECIMAL16 :
408+ return Types .optional (BINARY )
409+ .as (LogicalTypeAnnotation .decimalType (v .getDecimal ().scale (), 38 ))
410+ .named ("typed_value" );
411+ case DATE :
412+ return Types .optional (INT32 )
413+ .as (LogicalTypeAnnotation .dateType ())
414+ .named ("typed_value" );
415+ case TIMESTAMP_TZ :
416+ return Types .optional (INT64 )
417+ .as (LogicalTypeAnnotation .timestampType (true , TimeUnit .MICROS ))
418+ .named ("typed_value" );
419+ case TIMESTAMP_NTZ :
420+ return Types .optional (INT64 )
421+ .as (LogicalTypeAnnotation .timestampType (false , TimeUnit .MICROS ))
422+ .named ("typed_value" );
423+ case BINARY :
424+ return Types .optional (BINARY ).named ("typed_value" );
425+ case STRING :
426+ return Types .optional (BINARY )
427+ .as (LogicalTypeAnnotation .stringType ())
428+ .named ("typed_value" );
429+ case TIME :
430+ return Types .optional (INT64 )
431+ .as (LogicalTypeAnnotation .timeType (false , TimeUnit .MICROS ))
432+ .named ("typed_value" );
433+ case TIMESTAMP_NANOS_TZ :
434+ return Types .optional (INT64 )
435+ .as (LogicalTypeAnnotation .timestampType (true , TimeUnit .NANOS ))
436+ .named ("typed_value" );
437+ case TIMESTAMP_NANOS_NTZ :
438+ return Types .optional (INT64 )
439+ .as (LogicalTypeAnnotation .timestampType (false , TimeUnit .NANOS ))
440+ .named ("typed_value" );
441+ case UUID :
442+ return Types .optional (FIXED_LEN_BYTE_ARRAY )
443+ .as (LogicalTypeAnnotation .uuidType ())
444+ .named ("typed_value" );
445+ case OBJECT :
446+ return shreddedObjectType (v );
447+ case ARRAY :
448+ return shreddedArrayType (v );
449+ default :
450+ throw new UnsupportedOperationException ("Unsupported shredding type: " + v .getType ());
451+ }
452+ }
453+
454+ private static Type shreddedObjectType (Variant v ) {
455+ if (v .numObjectElements () == 0 ) {
456+ // Parquet can't represent empty groups.
457+ return null ;
458+ }
459+ Types .GroupBuilder <GroupType > builder = Types .optionalGroup ();
460+ for (int i = 0 ; i < v .numObjectElements (); i ++) {
461+ Variant .ObjectField field = v .getFieldAtIndex (i );
462+ Types .GroupBuilder <GroupType > fieldBuilder = Types .optionalGroup ();
463+ Type fieldType = shreddedGroup (field .value , field .key );
464+ builder .addField (fieldType );
465+ }
466+ return builder .named ("typed_value" );
467+ }
468+
469+ private static Type shreddedArrayType (Variant v ) {
470+ // Use the first element to determine the array element type
471+ Variant firstElement ;
472+ if (v .numArrayElements () > 0 ) {
473+ firstElement = v .getElementAtIndex (0 );
474+ } else {
475+ // Use null as a dummy value, which will omit typed_value from the schema.
476+ firstElement = fullVariant (b -> b .appendNull ());
477+ }
478+
479+ Type elementType = shreddedGroup (firstElement , "element" );
480+ return Types .optionalList ().setElementType (elementType ).named ("typed_value" );
481+ }
339482}
0 commit comments