4242import com .google .cloud .bigquery .storage .v1 .AppendRowsRequest ;
4343import com .google .cloud .bigquery .storage .v1 .CreateReadSessionRequest ;
4444import com .google .cloud .bigquery .storage .v1 .DataFormat ;
45+ import com .google .cloud .bigquery .storage .v1 .ProtoSchemaConverter ;
4546import com .google .cloud .bigquery .storage .v1 .ReadSession ;
4647import com .google .cloud .bigquery .storage .v1 .ReadStream ;
4748import com .google .gson .JsonArray ;
119120import org .apache .beam .sdk .transforms .PTransform ;
120121import org .apache .beam .sdk .transforms .ParDo ;
121122import org .apache .beam .sdk .transforms .Reshuffle ;
123+ import org .apache .beam .sdk .transforms .SerializableBiFunction ;
122124import org .apache .beam .sdk .transforms .SerializableFunction ;
123125import org .apache .beam .sdk .transforms .SerializableFunctions ;
124126import org .apache .beam .sdk .transforms .SimpleFunction ;
@@ -2297,10 +2299,79 @@ public static <T extends Message> Write<T> writeProtos(Class<T> protoMessageClas
22972299 if (DynamicMessage .class .equals (protoMessageClass )) {
22982300 throw new IllegalArgumentException ("DynamicMessage is not supported." );
22992301 }
2300- return BigQueryIO .<T >write ()
2301- .withFormatFunction (
2302- m -> TableRowToStorageApiProto .tableRowFromMessage (m , false , Predicates .alwaysTrue ()))
2303- .withWriteProtosClass (protoMessageClass );
2302+ try {
2303+ return BigQueryIO .<T >write ()
2304+ .toBuilder ()
2305+ .setFormatFunction (FormatProto .fromClass (protoMessageClass ))
2306+ .build ()
2307+ .withWriteProtosClass (protoMessageClass );
2308+ } catch (Exception e ) {
2309+ throw new RuntimeException (e );
2310+ }
2311+ }
2312+
2313+ abstract static class TableRowFormatFunction <T >
2314+ implements SerializableBiFunction <
2315+ TableRowToStorageApiProto .@ Nullable SchemaInformation , T , TableRow > {
2316+ static <T > TableRowFormatFunction <T > fromSerializableFunction (
2317+ SerializableFunction <T , TableRow > serializableFunction ) {
2318+ return new TableRowFormatFunction <T >() {
2319+ @ Override
2320+ public TableRow apply (
2321+ TableRowToStorageApiProto .@ Nullable SchemaInformation schemaInformation , T t ) {
2322+ return serializableFunction .apply (t );
2323+ }
2324+ };
2325+ }
2326+
2327+ SerializableFunction <T , TableRow > toSerializableFunction () {
2328+ return input -> apply (null , input );
2329+ }
2330+ }
2331+
2332+ private static class FormatProto <T extends Message > extends TableRowFormatFunction <T > {
2333+ transient TableRowToStorageApiProto .SchemaInformation inferredSchemaInformation ;
2334+ final Class <T > protoMessageClass ;
2335+
2336+ FormatProto (Class <T > protoMessageClass ) {
2337+ this .protoMessageClass = protoMessageClass ;
2338+ }
2339+
2340+ TableRowToStorageApiProto .SchemaInformation inferSchemaInformation () {
2341+ try {
2342+ if (inferredSchemaInformation == null ) {
2343+ Descriptors .Descriptor descriptor =
2344+ (Descriptors .Descriptor )
2345+ org .apache .beam .sdk .util .Preconditions .checkStateNotNull (
2346+ protoMessageClass .getMethod ("getDescriptor" ))
2347+ .invoke (null );
2348+ Descriptors .Descriptor convertedDescriptor =
2349+ TableRowToStorageApiProto .wrapDescriptorProto (
2350+ ProtoSchemaConverter .convert (descriptor ).getProtoDescriptor ());
2351+ TableSchema tableSchema =
2352+ TableRowToStorageApiProto .protoSchemaToTableSchema (
2353+ TableRowToStorageApiProto .tableSchemaFromDescriptor (convertedDescriptor ));
2354+ this .inferredSchemaInformation =
2355+ TableRowToStorageApiProto .SchemaInformation .fromTableSchema (tableSchema );
2356+ }
2357+ return inferredSchemaInformation ;
2358+ } catch (Exception e ) {
2359+ throw new RuntimeException (e );
2360+ }
2361+ }
2362+
2363+ static <T extends Message > FormatProto <T > fromClass (Class <T > protoMessageClass )
2364+ throws Exception {
2365+ return new FormatProto <>(protoMessageClass );
2366+ }
2367+
2368+ @ Override
2369+ public TableRow apply (TableRowToStorageApiProto .SchemaInformation schemaInformation , T input ) {
2370+ TableRowToStorageApiProto .SchemaInformation localSchemaInformation =
2371+ schemaInformation != null ? schemaInformation : inferSchemaInformation ();
2372+ return TableRowToStorageApiProto .tableRowFromMessage (
2373+ localSchemaInformation , input , false , Predicates .alwaysTrue ());
2374+ }
23042375 }
23052376
23062377 /** Implementation of {@link #write}. */
@@ -2354,9 +2425,9 @@ public enum Method {
23542425 abstract @ Nullable SerializableFunction <ValueInSingleWindow <T >, TableDestination >
23552426 getTableFunction ();
23562427
2357- abstract @ Nullable SerializableFunction < T , TableRow > getFormatFunction ();
2428+ abstract @ Nullable TableRowFormatFunction < T > getFormatFunction ();
23582429
2359- abstract @ Nullable SerializableFunction < T , TableRow > getFormatRecordOnFailureFunction ();
2430+ abstract @ Nullable TableRowFormatFunction < T > getFormatRecordOnFailureFunction ();
23602431
23612432 abstract RowWriterFactory .@ Nullable AvroRowWriterFactory <T , ?, ?> getAvroRowWriterFactory ();
23622433
@@ -2467,10 +2538,10 @@ abstract static class Builder<T> {
24672538 abstract Builder <T > setTableFunction (
24682539 SerializableFunction <ValueInSingleWindow <T >, TableDestination > tableFunction );
24692540
2470- abstract Builder <T > setFormatFunction (SerializableFunction < T , TableRow > formatFunction );
2541+ abstract Builder <T > setFormatFunction (TableRowFormatFunction < T > formatFunction );
24712542
24722543 abstract Builder <T > setFormatRecordOnFailureFunction (
2473- SerializableFunction < T , TableRow > formatFunction );
2544+ TableRowFormatFunction < T > formatFunction );
24742545
24752546 abstract Builder <T > setAvroRowWriterFactory (
24762547 RowWriterFactory .AvroRowWriterFactory <T , ?, ?> avroRowWriterFactory );
@@ -2718,7 +2789,9 @@ public Write<T> to(DynamicDestinations<T, ?> dynamicDestinations) {
27182789
27192790 /** Formats the user's type into a {@link TableRow} to be written to BigQuery. */
27202791 public Write <T > withFormatFunction (SerializableFunction <T , TableRow > formatFunction ) {
2721- return toBuilder ().setFormatFunction (formatFunction ).build ();
2792+ return toBuilder ()
2793+ .setFormatFunction (TableRowFormatFunction .fromSerializableFunction (formatFunction ))
2794+ .build ();
27222795 }
27232796
27242797 /**
@@ -2733,7 +2806,10 @@ public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunct
27332806 */
27342807 public Write <T > withFormatRecordOnFailureFunction (
27352808 SerializableFunction <T , TableRow > formatFunction ) {
2736- return toBuilder ().setFormatRecordOnFailureFunction (formatFunction ).build ();
2809+ return toBuilder ()
2810+ .setFormatRecordOnFailureFunction (
2811+ TableRowFormatFunction .fromSerializableFunction (formatFunction ))
2812+ .build ();
27372813 }
27382814
27392815 /**
@@ -3599,9 +3675,8 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) {
35993675 private <DestinationT > WriteResult expandTyped (
36003676 PCollection <T > input , DynamicDestinations <T , DestinationT > dynamicDestinations ) {
36013677 boolean optimizeWrites = getOptimizeWrites ();
3602- SerializableFunction <T , TableRow > formatFunction = getFormatFunction ();
3603- SerializableFunction <T , TableRow > formatRecordOnFailureFunction =
3604- getFormatRecordOnFailureFunction ();
3678+ TableRowFormatFunction <T > formatFunction = getFormatFunction ();
3679+ TableRowFormatFunction <T > formatRecordOnFailureFunction = getFormatRecordOnFailureFunction ();
36053680 RowWriterFactory .AvroRowWriterFactory <T , ?, DestinationT > avroRowWriterFactory =
36063681 (RowWriterFactory .AvroRowWriterFactory <T , ?, DestinationT >) getAvroRowWriterFactory ();
36073682
@@ -3623,7 +3698,9 @@ private <DestinationT> WriteResult expandTyped(
36233698 // If no format function set, then we will automatically convert the input type to a
36243699 // TableRow.
36253700 // TODO: it would be trivial to convert to avro records here instead.
3626- formatFunction = BigQueryUtils .toTableRow (input .getToRowFunction ());
3701+ formatFunction =
3702+ TableRowFormatFunction .fromSerializableFunction (
3703+ BigQueryUtils .toTableRow (input .getToRowFunction ()));
36273704 }
36283705 // Infer the TableSchema from the input Beam schema.
36293706 // TODO: If the user provided a schema, we should use that. There are things that can be
@@ -3769,8 +3846,8 @@ private <DestinationT> WriteResult continueExpandTyped(
37693846 getCreateDisposition (),
37703847 dynamicDestinations ,
37713848 elementCoder ,
3772- tableRowWriterFactory .getToRowFn (),
3773- tableRowWriterFactory .getToFailsafeRowFn ())
3849+ tableRowWriterFactory .getToRowFn (). toSerializableFunction () ,
3850+ tableRowWriterFactory .getToFailsafeRowFn (). toSerializableFunction () )
37743851 .withInsertRetryPolicy (retryPolicy )
37753852 .withTestServices (getBigQueryServices ())
37763853 .withExtendedErrorInfo (getExtendedErrorInfo ())
0 commit comments