Skip to content

Commit 4f577a6

Browse files
authored
GH-3149: Enable ParquetAvroReader to handle decimal types for int32/64 (#3306)
1 parent 1818380 commit 4f577a6

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import java.lang.reflect.Constructor;
2222
import java.lang.reflect.InvocationTargetException;
23+
import java.math.BigDecimal;
24+
import java.math.BigInteger;
2325
import java.nio.ByteBuffer;
2426
import org.apache.avro.Schema;
2527
import org.apache.avro.generic.GenericData;
@@ -29,6 +31,7 @@
2931
import org.apache.parquet.io.api.Binary;
3032
import org.apache.parquet.io.api.GroupConverter;
3133
import org.apache.parquet.io.api.PrimitiveConverter;
34+
import org.apache.parquet.schema.LogicalTypeAnnotation;
3235
import org.apache.parquet.schema.PrimitiveStringifier;
3336
import org.apache.parquet.schema.PrimitiveType;
3437

@@ -339,4 +342,36 @@ public String convert(Binary binary) {
339342
return stringifier.stringify(binary);
340343
}
341344
}
345+
346+
static final class FieldDecimalIntConverter extends AvroPrimitiveConverter {
347+
private final int scale;
348+
349+
public FieldDecimalIntConverter(ParentValueContainer parent, PrimitiveType type) {
350+
super(parent);
351+
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
352+
(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
353+
this.scale = decimalType.getScale();
354+
}
355+
356+
@Override
357+
public void addInt(int value) {
358+
parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
359+
}
360+
}
361+
362+
static final class FieldDecimalLongConverter extends AvroPrimitiveConverter {
363+
private final int scale;
364+
365+
public FieldDecimalLongConverter(ParentValueContainer parent, PrimitiveType type) {
366+
super(parent);
367+
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
368+
(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
369+
this.scale = decimalType.getScale();
370+
}
371+
372+
@Override
373+
public void addLong(long value) {
374+
parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
375+
}
376+
}
342377
}

parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,14 @@ private static Converter newConverter(
337337
return newConverter(schema, type, model, null, setter, validator);
338338
}
339339

340+
private static boolean isDecimalType(Type type) {
341+
if (!type.isPrimitive()) {
342+
return false;
343+
}
344+
LogicalTypeAnnotation annotation = type.getLogicalTypeAnnotation();
345+
return annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
346+
}
347+
340348
private static Converter newConverter(
341349
Schema schema,
342350
Type type,
@@ -359,6 +367,9 @@ private static Converter newConverter(
359367
case BOOLEAN:
360368
return new AvroConverters.FieldBooleanConverter(parent);
361369
case INT:
370+
if (isDecimalType(type)) {
371+
return new AvroConverters.FieldDecimalIntConverter(parent, type.asPrimitiveType());
372+
}
362373
Class<?> intDatumClass = getDatumClass(conversion, knownClass, schema, model);
363374
if (intDatumClass == null) {
364375
return new AvroConverters.FieldIntegerConverter(parent);
@@ -374,6 +385,9 @@ private static Converter newConverter(
374385
}
375386
return new AvroConverters.FieldIntegerConverter(parent);
376387
case LONG:
388+
if (isDecimalType(type)) {
389+
return new AvroConverters.FieldDecimalLongConverter(parent, type.asPrimitiveType());
390+
}
377391
return new AvroConverters.FieldLongConverter(parent);
378392
case FLOAT:
379393
return new AvroConverters.FieldFloatConverter(parent);

parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
package org.apache.parquet.avro;
2020

2121
import static org.apache.parquet.avro.AvroTestUtil.optional;
22+
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
23+
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
24+
import static org.apache.parquet.schema.Type.Repetition.REQUIRED;
2225
import static org.junit.Assert.assertEquals;
2326
import static org.junit.Assert.assertNotNull;
2427

@@ -61,17 +64,23 @@
6164
import org.apache.parquet.conf.ParquetConfiguration;
6265
import org.apache.parquet.conf.PlainParquetConfiguration;
6366
import org.apache.parquet.example.data.Group;
67+
import org.apache.parquet.example.data.GroupFactory;
68+
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
6469
import org.apache.parquet.hadoop.ParquetReader;
6570
import org.apache.parquet.hadoop.ParquetWriter;
6671
import org.apache.parquet.hadoop.api.WriteSupport;
72+
import org.apache.parquet.hadoop.example.ExampleParquetWriter;
6773
import org.apache.parquet.hadoop.example.GroupReadSupport;
6874
import org.apache.parquet.hadoop.util.HadoopCodecs;
6975
import org.apache.parquet.io.InputFile;
7076
import org.apache.parquet.io.LocalInputFile;
7177
import org.apache.parquet.io.LocalOutputFile;
7278
import org.apache.parquet.io.api.Binary;
7379
import org.apache.parquet.io.api.RecordConsumer;
80+
import org.apache.parquet.schema.LogicalTypeAnnotation;
81+
import org.apache.parquet.schema.MessageType;
7482
import org.apache.parquet.schema.MessageTypeParser;
83+
import org.apache.parquet.schema.PrimitiveType;
7584
import org.junit.Assert;
7685
import org.junit.Rule;
7786
import org.junit.Test;
@@ -400,6 +409,68 @@ public void testFixedDecimalValues() throws Exception {
400409
Assert.assertEquals("Content should match", expected, records);
401410
}
402411

412+
@Test
413+
public void testDecimalIntegerValues() throws Exception {
414+
415+
File file = temp.newFile("test_decimal_integer_values.parquet");
416+
file.delete();
417+
Path path = new Path(file.toString());
418+
419+
MessageType parquetSchema = new MessageType(
420+
"test_decimal_integer_values",
421+
new PrimitiveType(REQUIRED, INT32, "decimal_age")
422+
.withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(2, 5)),
423+
new PrimitiveType(REQUIRED, INT64, "decimal_salary")
424+
.withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(1, 10)));
425+
426+
try (ParquetWriter<Group> writer =
427+
ExampleParquetWriter.builder(path).withType(parquetSchema).build()) {
428+
429+
GroupFactory factory = new SimpleGroupFactory(parquetSchema);
430+
431+
Group group1 = factory.newGroup();
432+
group1.add("decimal_age", 2534);
433+
group1.add("decimal_salary", 234L);
434+
writer.write(group1);
435+
436+
Group group2 = factory.newGroup();
437+
group2.add("decimal_age", 4267);
438+
group2.add("decimal_salary", 1203L);
439+
writer.write(group2);
440+
}
441+
442+
GenericData decimalSupport = new GenericData();
443+
decimalSupport.addLogicalTypeConversion(new Conversions.DecimalConversion());
444+
445+
List<GenericRecord> records = Lists.newArrayList();
446+
try (ParquetReader<GenericRecord> reader = AvroParquetReader.<GenericRecord>builder(path)
447+
.withDataModel(decimalSupport)
448+
.build()) {
449+
GenericRecord rec;
450+
while ((rec = reader.read()) != null) {
451+
records.add(rec);
452+
}
453+
}
454+
455+
Assert.assertEquals("Should read 2 records", 2, records.size());
456+
457+
// INT32 values
458+
Object firstAge = records.get(0).get("decimal_age");
459+
Object secondAge = records.get(1).get("decimal_age");
460+
461+
Assert.assertTrue("Should be BigDecimal, but is " + firstAge.getClass(), firstAge instanceof BigDecimal);
462+
Assert.assertEquals("Should be 25.34, but is " + firstAge, new BigDecimal("25.34"), firstAge);
463+
Assert.assertEquals("Should be 42.67, but is " + secondAge, new BigDecimal("42.67"), secondAge);
464+
465+
// INT64 values
466+
Object firstSalary = records.get(0).get("decimal_salary");
467+
Object secondSalary = records.get(1).get("decimal_salary");
468+
469+
Assert.assertTrue("Should be BigDecimal, but is " + firstSalary.getClass(), firstSalary instanceof BigDecimal);
470+
Assert.assertEquals("Should be 23.4, but is " + firstSalary, new BigDecimal("23.4"), firstSalary);
471+
Assert.assertEquals("Should be 120.3, but is " + secondSalary, new BigDecimal("120.3"), secondSalary);
472+
}
473+
403474
@Test
404475
public void testAll() throws Exception {
405476
Schema schema =

0 commit comments

Comments
 (0)