diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle
index 9ed46a367735..cc9eae786b90 100644
--- a/sdks/java/io/parquet/build.gradle
+++ b/sdks/java/io/parquet/build.gradle
@@ -47,6 +47,7 @@ dependencies {
   implementation "org.apache.parquet:parquet-column:$parquet_version"
   implementation "org.apache.parquet:parquet-common:$parquet_version"
   implementation "org.apache.parquet:parquet-hadoop:$parquet_version"
+  implementation "org.apache.parquet:parquet-protobuf:$parquet_version"
   implementation library.java.avro
   provided library.java.hadoop_client
   permitUnusedDeclared library.java.hadoop_client
@@ -54,6 +55,7 @@ dependencies {
   testImplementation library.java.hadoop_client
   testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
   testImplementation project(path: ":sdks:java:extensions:avro")
+  testImplementation "org.apache.avro:avro-protobuf:1.10.2"
   testImplementation library.java.junit
   testRuntimeOnly library.java.slf4j_jdk14
   testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
diff --git a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
index 24c18f382817..0005be840385 100644
--- a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
+++ b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
@@ -64,6 +64,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
 import org.apache.parquet.HadoopReadOptions;
 import org.apache.parquet.ParquetReadOptions;
 import org.apache.parquet.avro.AvroParquetReader;
@@ -73,6 +74,7 @@
 import org.apache.parquet.filter2.compat.FilterCompat;
 import org.apache.parquet.filter2.compat.FilterCompat.Filter;
 import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.hadoop.ParquetReader;
 import org.apache.parquet.hadoop.ParquetWriter;
 import org.apache.parquet.hadoop.api.InitContext;
 import org.apache.parquet.hadoop.api.ReadSupport;
@@ -89,6 +91,8 @@
 import org.apache.parquet.io.RecordReader;
 import org.apache.parquet.io.SeekableInputStream;
 import org.apache.parquet.io.api.RecordMaterializer;
+import org.apache.parquet.proto.ProtoParquetReader;
+import org.apache.parquet.proto.ProtoReadSupport;
 import org.apache.parquet.schema.MessageType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -244,6 +248,12 @@ public class ParquetIO {
 
   private static final Logger LOG = LoggerFactory.getLogger(ParquetIO.class);
 
+  // New: reader formats supported by ReadFiles; an enum so more formats can be added later.
+  public enum ReaderFormat {
+    AVRO,
+    PROTO
+  }
+
   /**
    * Reads {@link GenericRecord} from a Parquet file (or multiple Parquet files matching the
    * pattern).
@@ -263,6 +273,7 @@ public static ReadFiles readFiles(Schema schema) {
     return new AutoValue_ParquetIO_ReadFiles.Builder()
         .setSchema(schema)
         .setInferBeamSchema(false)
+        .setReaderFormat(ReaderFormat.AVRO) // New: default to the existing Avro reader.
         .build();
   }
 
@@ -616,6 +627,9 @@ public abstract static class ReadFiles
 
     abstract Builder toBuilder();
 
+    // New: reader format in use, either AVRO (the default) or PROTO for protobuf data.
+    abstract ReaderFormat getReaderFormat();
+
     @AutoValue.Builder
     abstract static class Builder {
       abstract Builder setSchema(Schema schema);
@@ -630,6 +644,8 @@ abstract static class Builder {
 
       abstract Builder setInferBeamSchema(boolean inferBeamSchema);
 
+      abstract Builder setReaderFormat(ReaderFormat readerFormat);
+
       abstract ReadFiles build();
     }
 
@@ -663,6 +679,16 @@ public ReadFiles withBeamSchemas(boolean inferBeamSchema) {
       return toBuilder().setInferBeamSchema(inferBeamSchema).build();
     }
 
+    // New: opt into the existing Avro-based reader (the default).
+    public ReadFiles withAvroReader() {
+      return toBuilder().setReaderFormat(ReaderFormat.AVRO).build();
+    }
+
+    // New: opt into the protobuf-based reader.
+    public ReadFiles withProtoReader() {
+      return toBuilder().setReaderFormat(ReaderFormat.PROTO).build();
+    }
+
     @Override
     public PCollection<GenericRecord> expand(PCollection<FileIO.ReadableFile> input) {
       checkNotNull(getSchema(), "Schema can not be null");
@@ -673,7 +699,9 @@ public PCollection<GenericRecord> expand(PCollection<FileIO.ReadableFile> input) {
                       getAvroDataModel(),
                       getProjectionSchema(),
                       GenericRecordPassthroughFn.create(),
-                      getConfiguration())))
+                      getConfiguration(),
+                      getReaderFormat() // New: pass the configured reader format.
+                      )))
           .setCoder(getCollectionCoder());
     }
 
@@ -686,7 +714,10 @@ public void populateDisplayData(DisplayData.Builder builder) {
               DisplayData.item("inferBeamSchema", getInferBeamSchema())
                   .withLabel("Infer Beam Schema"))
          .addIfNotNull(DisplayData.item("projectionSchema", String.valueOf(getProjectionSchema())))
-          .addIfNotNull(DisplayData.item("avroDataModel", String.valueOf(getAvroDataModel())));
+          .addIfNotNull(DisplayData.item("avroDataModel", String.valueOf(getAvroDataModel())))
+          .add(
+              DisplayData.item("readerFormat", getReaderFormat().name())
+                  .withLabel("Reader format (AVRO|PROTO)"));
       if (this.getConfiguration() != null) {
         Configuration configuration = this.getConfiguration().get();
         for (Entry<String, String> entry : configuration) {
@@ -718,16 +749,30 @@ static class SplitReadFn<T> extends DoFn<FileIO.ReadableFile, T> {
 
     private final SerializableFunction<GenericRecord, T> parseFn;
 
+    private final ReaderFormat readerFormat; // New: selects the Avro or protobuf reader.
+
     SplitReadFn(
         GenericData model,
         Schema requestSchema,
         SerializableFunction<GenericRecord, T> parseFn,
-        @Nullable SerializableConfiguration configuration) {
+        @Nullable SerializableConfiguration configuration,
+        ReaderFormat readerFormat // New: reader format selector.
+        ) {
 
       this.modelClass = model != null ? model.getClass() : null;
       this.requestSchemaString = requestSchema != null ? requestSchema.toString() : null;
       this.parseFn = checkNotNull(parseFn, "GenericRecord parse function can't be null");
       this.configuration = configuration;
+      this.readerFormat = readerFormat; // New: remember the configured format.
+    }
+
+    // New: overloaded constructor kept for backward compatibility; defaults to the Avro reader.
+    SplitReadFn(
+        GenericData model,
+        Schema requestSchema,
+        SerializableFunction<GenericRecord, T> parseFn,
+        @Nullable SerializableConfiguration configuration) {
+      this(model, requestSchema, parseFn, configuration, ReaderFormat.AVRO);
     }
 
     private ParquetFileReader getParquetFileReader(ReadableFile file) throws Exception {
@@ -746,96 +791,123 @@ public void processElement(
           tracker.currentRestriction().getFrom(),
           tracker.currentRestriction().getTo());
       Configuration conf = getConfWithModelClass();
-      GenericData model = null;
-      if (modelClass != null) {
-        model = (GenericData) modelClass.getMethod("get").invoke(null);
-      }
-      AvroReadSupport<GenericRecord> readSupport = new AvroReadSupport<>(model);
-      if (requestSchemaString != null) {
-        AvroReadSupport.setRequestedProjection(
-            conf, new Schema.Parser().parse(requestSchemaString));
-      }
-      ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
-      try (ParquetFileReader reader =
-          ParquetFileReader.open(new BeamParquetInputFile(file.openSeekable()), options)) {
-        Filter filter = checkNotNull(options.getRecordFilter(), "filter");
-        Configuration hadoopConf = ((HadoopReadOptions) options).getConf();
-        FileMetaData parquetFileMetadata = reader.getFooter().getFileMetaData();
-        MessageType fileSchema = parquetFileMetadata.getSchema();
-        Map<String, String> fileMetadata = parquetFileMetadata.getKeyValueMetaData();
-        ReadSupport.ReadContext readContext =
-            readSupport.init(
-                new InitContext(
-                    hadoopConf,
-                    Maps.transformValues(fileMetadata, ImmutableSet::of),
-                    fileSchema));
-        ColumnIOFactory columnIOFactory = new ColumnIOFactory(parquetFileMetadata.getCreatedBy());
-
-        RecordMaterializer<GenericRecord> recordConverter =
-            readSupport.prepareForRead(hadoopConf, fileMetadata, fileSchema, readContext);
-        reader.setRequestedSchema(readContext.getRequestedSchema());
-        MessageColumnIO columnIO =
-            columnIOFactory.getColumnIO(readContext.getRequestedSchema(), fileSchema, true);
-        long currentBlock = tracker.currentRestriction().getFrom();
-        for (int i = 0; i < currentBlock; i++) {
-          reader.skipNextRowGroup();
-        }
-        while (tracker.tryClaim(currentBlock)) {
-          PageReadStore pages = reader.readNextRowGroup();
-          LOG.debug("block {} read in memory. row count = {}", currentBlock, pages.getRowCount());
-          currentBlock += 1;
-          RecordReader<GenericRecord> recordReader =
-              columnIO.getRecordReader(
-                  pages, recordConverter, options.useRecordFilter() ? filter : FilterCompat.NOOP);
-          long currentRow = 0;
-          long totalRows = pages.getRowCount();
-          while (currentRow < totalRows) {
-            try {
-              GenericRecord record;
-              currentRow += 1;
-              try {
-                record = recordReader.read();
-              } catch (RecordMaterializer.RecordMaterializationException e) {
-                LOG.warn(
-                    "skipping a corrupt record at {} in block {} in file {}",
-                    currentRow,
-                    currentBlock,
-                    file.toString());
-                continue;
-              }
-              if (record == null) {
-                // it happens when a record is filtered out in this block
-                LOG.debug(
-                    "record is filtered out by reader in block {} in file {}",
-                    currentBlock,
-                    file.toString());
-                continue;
-              }
-              if (recordReader.shouldSkipCurrentRecord()) {
-                // this record is being filtered via the filter2 package
-                LOG.debug(
-                    "skipping record at {} in block {} in file {}",
-                    currentRow,
-                    currentBlock,
-                    file.toString());
-                continue;
-              }
-              outputReceiver.output(parseFn.apply(record));
-            } catch (RuntimeException e) {
-
-              throw new ParquetDecodingException(
-                  format(
-                      "Can not read value at %d in block %d in file %s",
-                      currentRow, currentBlock, file.toString()),
-                  e);
-            }
-          }
-          LOG.debug(
-              "Finish processing {} rows from block {} in file {}",
-              currentRow,
-              currentBlock - 1,
-              file.toString());
-        }
-      }
+
+      switch (readerFormat) {
+        case PROTO:
+          // Use ProtoParquetReader to read protobuf data.
+          // Derive a Hadoop Path from the file metadata. Adjust as needed.
+          Path path = new Path(file.getMetadata().resourceId().toString());
+
+          // Use the builder overload that takes a ReadSupport and a Path.
+          try (ParquetReader<GenericRecord> reader =
+              ProtoParquetReader.builder(new ProtoReadSupport(), path).build()) {
+            GenericRecord message;
+            while ((message = reader.read()) != null) {
+              // Cast through Object so that parseFn (which expects GenericRecord)
+              // can accept the DynamicMessage.
+              outputReceiver.output(parseFn.apply((GenericRecord) (Object) message));
+            }
+          }
+          break;
+
+        case AVRO:
+        default:
+          // Existing logic: read Avro records via ParquetFileReader (unchanged, re-indented).
+          GenericData model = null;
+          if (modelClass != null) {
+            model = (GenericData) modelClass.getMethod("get").invoke(null);
+          }
+          AvroReadSupport<GenericRecord> readSupport = new AvroReadSupport<>(model);
+          if (requestSchemaString != null) {
+            AvroReadSupport.setRequestedProjection(
+                conf, new Schema.Parser().parse(requestSchemaString));
+          }
+          ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
+          try (ParquetFileReader reader =
+              ParquetFileReader.open(new BeamParquetInputFile(file.openSeekable()), options)) {
+            Filter filter = checkNotNull(options.getRecordFilter(), "filter");
+            Configuration hadoopConf = ((HadoopReadOptions) options).getConf();
+            FileMetaData parquetFileMetadata = reader.getFooter().getFileMetaData();
+            MessageType fileSchema = parquetFileMetadata.getSchema();
+            Map<String, String> fileMetadata = parquetFileMetadata.getKeyValueMetaData();
+            ReadSupport.ReadContext readContext =
+                readSupport.init(
+                    new InitContext(
+                        hadoopConf,
+                        Maps.transformValues(fileMetadata, ImmutableSet::of),
+                        fileSchema));
+            ColumnIOFactory columnIOFactory =
+                new ColumnIOFactory(parquetFileMetadata.getCreatedBy());
+
+            RecordMaterializer<GenericRecord> recordConverter =
+                readSupport.prepareForRead(hadoopConf, fileMetadata, fileSchema, readContext);
+            reader.setRequestedSchema(readContext.getRequestedSchema());
+            MessageColumnIO columnIO =
+                columnIOFactory.getColumnIO(readContext.getRequestedSchema(), fileSchema, true);
+            long currentBlock = tracker.currentRestriction().getFrom();
+            for (int i = 0; i < currentBlock; i++) {
+              reader.skipNextRowGroup();
+            }
+            while (tracker.tryClaim(currentBlock)) {
+              PageReadStore pages = reader.readNextRowGroup();
+              LOG.debug(
+                  "block {} read in memory. row count = {}", currentBlock, pages.getRowCount());
+              currentBlock += 1;
+              RecordReader<GenericRecord> recordReader =
+                  columnIO.getRecordReader(
+                      pages,
+                      recordConverter,
+                      options.useRecordFilter() ? filter : FilterCompat.NOOP);
+              long currentRow = 0;
+              long totalRows = pages.getRowCount();
+              while (currentRow < totalRows) {
+                try {
+                  GenericRecord record;
+                  currentRow += 1;
+                  try {
+                    record = recordReader.read();
+                  } catch (RecordMaterializer.RecordMaterializationException e) {
+                    LOG.warn(
+                        "skipping a corrupt record at {} in block {} in file {}",
+                        currentRow,
+                        currentBlock,
+                        file.toString());
+                    continue;
+                  }
+                  if (record == null) {
+                    // it happens when a record is filtered out in this block
+                    LOG.debug(
+                        "record is filtered out by reader in block {} in file {}",
+                        currentBlock,
+                        file.toString());
+                    continue;
+                  }
+                  if (recordReader.shouldSkipCurrentRecord()) {
+                    // this record is being filtered via the filter2 package
+                    LOG.debug(
+                        "skipping record at {} in block {} in file {}",
+                        currentRow,
+                        currentBlock,
+                        file.toString());
+                    continue;
+                  }
+                  outputReceiver.output(parseFn.apply(record));
+                } catch (RuntimeException e) {
+
+                  throw new ParquetDecodingException(
+                      format(
+                          "Can not read value at %d in block %d in file %s",
+                          currentRow, currentBlock, file.toString()),
+                      e);
+                }
+              }
+              LOG.debug(
+                  "Finish processing {} rows from block {} in file {}",
+                  currentRow,
+                  currentBlock - 1,
+                  file.toString());
+            }
+          }
+      }
     }
 
diff --git a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
index 7ee3ec5050fd..918a617e096e 100644
--- a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
+++ b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
@@ -39,11 +39,13 @@
 import org.apache.avro.generic.GenericRecordBuilder;
 import org.apache.avro.io.EncoderFactory;
 import org.apache.avro.io.JsonEncoder;
+import org.apache.avro.protobuf.ProtobufData;
 import org.apache.avro.reflect.ReflectData;
 import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.io.FileIO;
 import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn;
+import org.apache.beam.sdk.io.parquet.ParquetIO.ReaderFormat;
 import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.schemas.SchemaCoder;
 import org.apache.beam.sdk.testing.PAssert;
@@ -137,6 +139,32 @@ public void testWriteAndReadWithProjection() {
     readPipeline.run().waitUntilFinish();
   }
 
+  @Test
+  public void testParquetProtobufReadError() {
+    ProtobufData protoData = new ProtobufData() {};
+    Exception thrown =
+        assertThrows(RuntimeException.class, () -> protoData.getSchema(GenericData.Record.class));
+    assertTrue(
+        "Error message should mention 'getDescriptor'",
+        thrown.getMessage().contains("org.apache.avro.generic.GenericData$Record.getDescriptor"));
+  }
+
+  @Test
+  public void testReadFilesWithProtoReaderFlag() {
+    // Create a ReadFiles transform with the proto reader enabled.
+    ParquetIO.ReadFiles readFiles = ParquetIO.readFiles(SCHEMA).withProtoReader();
+    assertEquals(
+        "Proto reader flag should be enabled", ReaderFormat.PROTO, readFiles.getReaderFormat());
+  }
+
+  @Test
+  public void testReadFilesDisplayDataWithProtoReader() {
+    // Create a ReadFiles transform with the proto reader enabled.
+    ParquetIO.ReadFiles readFiles = ParquetIO.readFiles(SCHEMA).withProtoReader();
+    DisplayData displayData = DisplayData.from(readFiles);
+    assertThat(displayData, hasDisplayItem("readerFormat", ReaderFormat.PROTO.name()));
+  }
+
   @Test
   public void testBlockTracker() {
     OffsetRange range = new OffsetRange(0, 1);
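
For context, a minimal usage sketch of the reader-format API this diff adds. The pipeline wiring, file pattern, and SCHEMA below are placeholders rather than part of the change; ReadFiles defaults to ReaderFormat.AVRO, so pipelines that never call withProtoReader() keep the existing Avro code path.

    // Hypothetical wiring; "pipeline", "SCHEMA", and the file pattern are placeholders.
    PCollection<GenericRecord> records =
        pipeline
            .apply(FileIO.match().filepattern("/path/to/input-*.parquet"))
            .apply(FileIO.readMatches())
            // Default is the Avro reader; withProtoReader() switches to the protobuf reader.
            .apply(ParquetIO.readFiles(SCHEMA).withProtoReader());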