
Commit 43397a8: fix tableRowFromMessage
1 parent af748d0

17 files changed, +1944 -430 lines
SerializableBiFunctions.java (new file, package org.apache.beam.sdk.transforms)

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Useful {@link SerializableFunction} overrides. */
+public class SerializableBiFunctions {
+  /** Always returns the first argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableBiFunction<FirstInputT, SecondInputT, FirstInputT> select1st(
+          SerializableBiFunction<@Nullable FirstInputT, @Nullable SecondInputT, OutputT>
+              biFunction) {
+    return (t, u) -> t;
+  }
+
+  /** Always returns the second argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableBiFunction<FirstInputT, SecondInputT, SecondInputT> select2nd(
+          SerializableBiFunction<@Nullable FirstInputT, @Nullable SecondInputT, OutputT>
+              biFunction) {
+    return (t, u) -> u;
+  }
+
+  /** Convert to a unary function by fixing the first argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableFunction<SecondInputT, OutputT> fix1st(
+          SerializableBiFunction<@Nullable FirstInputT, @Nullable SecondInputT, OutputT> biFunction,
+          @Nullable FirstInputT value) {
+    return u -> biFunction.apply(value, u);
+  }
+
+  /** Convert to a unary function by fixing the second argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableFunction<FirstInputT, OutputT> fix2nd(
+          SerializableBiFunction<@Nullable FirstInputT, @Nullable SecondInputT, OutputT> biFunction,
+          @Nullable SecondInputT value) {
+    return t -> biFunction.apply(t, value);
+  }
+
+  /** Convert from a unary function by ignoring the first argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableBiFunction<FirstInputT, SecondInputT, OutputT> ignore1st(
+          SerializableFunction<@Nullable SecondInputT, OutputT> function) {
+    return (t, u) -> function.apply(u);
+  }
+
+  /** Convert from a unary function by ignoring the second argument. */
+  public static <FirstInputT, SecondInputT, OutputT>
+      SerializableBiFunction<FirstInputT, SecondInputT, OutputT> ignore2nd(
+          SerializableFunction<@Nullable FirstInputT, OutputT> function) {
+    return (t, u) -> function.apply(t);
+  }
+}
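
These adapters are what the rest of the commit leans on (for example, SerializableBiFunctions.ignore1st inside BigQueryIO.withFormatFunction below). A minimal standalone sketch of how they compose; the lambdas and values are made up for illustration and are not part of the commit:

import org.apache.beam.sdk.transforms.SerializableBiFunction;
import org.apache.beam.sdk.transforms.SerializableBiFunctions;
import org.apache.beam.sdk.transforms.SerializableFunction;

public class AdapterSketch {
  public static void main(String[] args) {
    // Lift a unary function into a bi-function that ignores its first argument.
    SerializableFunction<String, Integer> length = s -> s.length();
    SerializableBiFunction<Void, String, Integer> lifted =
        SerializableBiFunctions.ignore1st(length);
    System.out.println(lifted.apply(null, "hello")); // 5

    // Fix the first argument of a bi-function to get a unary function back.
    SerializableBiFunction<Integer, Integer, Integer> add = (a, b) -> a + b;
    SerializableFunction<Integer, Integer> addTen = SerializableBiFunctions.fix1st(add, 10);
    System.out.println(addTen.apply(5)); // 15
  }
}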

sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto

Lines changed: 46 additions & 1 deletion
@@ -33,6 +33,51 @@ import "proto3_schema_options.proto";
 
 option java_package = "org.apache.beam.sdk.extensions.protobuf";
 
+message PrimitiveEncodedFields {
+  int64 encoded_timestamp = 1;
+  int32 encoded_date = 2;
+  bytes encoded_numeric = 3;
+  bytes encoded_bignumeric = 4;
+  int64 encoded_packed_datetime = 5;
+  int64 encoded_packed_time = 6;
+}
+
+message NestedEncodedFields {
+  PrimitiveEncodedFields nested = 1;
+  repeated PrimitiveEncodedFields nested_list = 2;
+}
+
+message PrimitiveUnEncodedFields {
+  string timestamp = 1;
+  string date = 2;
+  string numeric = 3;
+  string bignumeric = 4;
+  string datetime = 5;
+  string time = 6;
+}
+
+message NestedUnEncodedFields {
+  PrimitiveUnEncodedFields nested = 1;
+  repeated PrimitiveUnEncodedFields nested_list = 2;
+}
+
+message WrapperUnEncodedFields {
+  google.protobuf.FloatValue float = 1;
+  google.protobuf.DoubleValue double = 2;
+  google.protobuf.BoolValue bool = 3;
+  google.protobuf.Int32Value int32 = 4;
+  google.protobuf.Int64Value int64 = 5;
+  google.protobuf.UInt32Value uint32 = 6;
+  google.protobuf.UInt64Value uint64 = 7;
+  google.protobuf.BytesValue bytes = 8;
+  google.protobuf.Timestamp timestamp = 9;
+}
+
+message NestedWrapperUnEncodedFields {
+  WrapperUnEncodedFields nested = 1;
+  repeated WrapperUnEncodedFields nested_list = 2;
+}
+
 message Primitive {
   double primitive_double = 1;
   float primitive_float = 2;
@@ -287,4 +332,4 @@ message NoWrapPrimitive {
   optional bool bool = 13;
   optional string string = 14;
   optional bytes bytes = 15;
-}
+}
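
These test messages appear to mirror BigQuery's encoded and unencoded representations of TIMESTAMP, DATE, NUMERIC, DATETIME, and TIME fields, plus the well-known wrapper types. A rough sketch of building one from Java, assuming the file compiles into a generated outer class named Proto3SchemaMessages; the class name and the encodings described in the comments are assumptions, not taken from the commit:

import com.google.protobuf.ByteString;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.PrimitiveEncodedFields;

public class EncodedFieldsSketch {
  public static void main(String[] args) {
    PrimitiveEncodedFields fields =
        PrimitiveEncodedFields.newBuilder()
            .setEncodedTimestamp(1_700_000_000_000_000L) // assumed: microseconds since epoch
            .setEncodedDate(19_000)                      // assumed: days since epoch
            .setEncodedNumeric(ByteString.EMPTY)         // assumed: scaled big-endian bytes
            .build();
    System.out.println(fields);
  }
}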

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java

Lines changed: 1 addition & 0 deletions
@@ -182,6 +182,7 @@ public ByteString mergeNewFields(
   public TableRow toTableRow(ByteString protoBytes, Predicate<String> includeField) {
     try {
       return TableRowToStorageApiProto.tableRowFromMessage(
+          getSchemaInformation(),
           DynamicMessage.parseFrom(
               TableRowToStorageApiProto.wrapDescriptorProto(getDescriptor()), protoBytes),
           true,
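
The extra getSchemaInformation() argument matches the other call sites touched in this commit (see FormatProto.apply below), which suggests the following shape for the updated method. This is only an inference from the call sites, not copied from the commit, and the parameter names are guesses:

// Inferred sketch of the updated TableRowToStorageApiProto entry point.
public static TableRow tableRowFromMessage(
    SchemaInformation schemaInformation,  // new leading argument carrying the resolved schema
    Message message,
    boolean includeCdcColumns,            // hypothetical name for the pre-existing boolean flag
    Predicate<String> includeField);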

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java

Lines changed: 84 additions & 16 deletions
@@ -42,6 +42,7 @@
 import com.google.cloud.bigquery.storage.v1.AppendRowsRequest;
 import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest;
 import com.google.cloud.bigquery.storage.v1.DataFormat;
+import com.google.cloud.bigquery.storage.v1.ProtoSchemaConverter;
 import com.google.cloud.bigquery.storage.v1.ReadSession;
 import com.google.cloud.bigquery.storage.v1.ReadStream;
 import com.google.gson.JsonArray;
@@ -119,6 +120,8 @@
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.SerializableBiFunction;
+import org.apache.beam.sdk.transforms.SerializableBiFunctions;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -2297,10 +2300,61 @@ public static <T extends Message> Write<T> writeProtos(Class<T> protoMessageClas
     if (DynamicMessage.class.equals(protoMessageClass)) {
       throw new IllegalArgumentException("DynamicMessage is not supported.");
     }
-    return BigQueryIO.<T>write()
-        .withFormatFunction(
-            m -> TableRowToStorageApiProto.tableRowFromMessage(m, false, Predicates.alwaysTrue()))
-        .withWriteProtosClass(protoMessageClass);
+    try {
+      return BigQueryIO.<T>write()
+          .toBuilder()
+          .setFormatFunction(FormatProto.fromClass(protoMessageClass))
+          .build()
+          .withWriteProtosClass(protoMessageClass);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private static class FormatProto<T extends Message>
+      implements SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow> {
+    transient TableRowToStorageApiProto.SchemaInformation inferredSchemaInformation;
+    final Class<T> protoMessageClass;
+
+    FormatProto(Class<T> protoMessageClass) {
+      this.protoMessageClass = protoMessageClass;
+    }
+
+    TableRowToStorageApiProto.SchemaInformation inferSchemaInformation() {
+      try {
+        if (inferredSchemaInformation == null) {
+          Descriptors.Descriptor descriptor =
+              (Descriptors.Descriptor)
+                  org.apache.beam.sdk.util.Preconditions.checkStateNotNull(
+                          protoMessageClass.getMethod("getDescriptor"))
+                      .invoke(null);
+          Descriptors.Descriptor convertedDescriptor =
+              TableRowToStorageApiProto.wrapDescriptorProto(
+                  ProtoSchemaConverter.convert(descriptor).getProtoDescriptor());
+          TableSchema tableSchema =
+              TableRowToStorageApiProto.protoSchemaToTableSchema(
+                  TableRowToStorageApiProto.tableSchemaFromDescriptor(convertedDescriptor));
+          this.inferredSchemaInformation =
+              TableRowToStorageApiProto.SchemaInformation.fromTableSchema(tableSchema);
+        }
+        return inferredSchemaInformation;
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    static <T extends Message> FormatProto<T> fromClass(Class<T> protoMessageClass)
+        throws Exception {
+      return new FormatProto<>(protoMessageClass);
+    }
+
+    @Override
+    public TableRow apply(TableRowToStorageApiProto.SchemaInformation schemaInformation, T input) {
+      TableRowToStorageApiProto.SchemaInformation localSchemaInformation =
+          schemaInformation != null ? schemaInformation : inferSchemaInformation();
+      return TableRowToStorageApiProto.tableRowFromMessage(
+          localSchemaInformation, input, false, Predicates.alwaysTrue());
+    }
   }
 
   /** Implementation of {@link #write}. */
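
FormatProto defers schema inference until first use and caches the result in a transient field, so each worker re-derives it after deserialization. The reflective descriptor lookup at the heart of inferSchemaInformation can be tried on its own; a small self-contained sketch using a well-known proto type (Duration is just a convenient stand-in, not something this commit uses):

import com.google.protobuf.Descriptors;
import com.google.protobuf.Duration;
import com.google.protobuf.Message;

public class DescriptorLookupSketch {
  // Generated protobuf message classes expose a static getDescriptor() method;
  // reflection reaches it from a Class<? extends Message> without needing an instance.
  static Descriptors.Descriptor descriptorFor(Class<? extends Message> clazz) throws Exception {
    return (Descriptors.Descriptor) clazz.getMethod("getDescriptor").invoke(null);
  }

  public static void main(String[] args) throws Exception {
    System.out.println(descriptorFor(Duration.class).getFullName()); // google.protobuf.Duration
  }
}
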
@@ -2354,9 +2408,13 @@ public enum Method {
     abstract @Nullable SerializableFunction<ValueInSingleWindow<T>, TableDestination>
         getTableFunction();
 
-    abstract @Nullable SerializableFunction<T, TableRow> getFormatFunction();
+    abstract @Nullable SerializableBiFunction<
+            TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+        getFormatFunction();
 
-    abstract @Nullable SerializableFunction<T, TableRow> getFormatRecordOnFailureFunction();
+    abstract @Nullable SerializableBiFunction<
+            TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+        getFormatRecordOnFailureFunction();
 
     abstract RowWriterFactory.@Nullable AvroRowWriterFactory<T, ?, ?> getAvroRowWriterFactory();
 
@@ -2467,10 +2525,13 @@ abstract static class Builder<T> {
       abstract Builder<T> setTableFunction(
           SerializableFunction<ValueInSingleWindow<T>, TableDestination> tableFunction);
 
-      abstract Builder<T> setFormatFunction(SerializableFunction<T, TableRow> formatFunction);
+      abstract Builder<T> setFormatFunction(
+          SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+              formatFunction);
 
       abstract Builder<T> setFormatRecordOnFailureFunction(
-          SerializableFunction<T, TableRow> formatFunction);
+          SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+              formatFunction);
 
       abstract Builder<T> setAvroRowWriterFactory(
           RowWriterFactory.AvroRowWriterFactory<T, ?, ?> avroRowWriterFactory);
@@ -2718,7 +2779,9 @@ public Write<T> to(DynamicDestinations<T, ?> dynamicDestinations) {
 
     /** Formats the user's type into a {@link TableRow} to be written to BigQuery. */
     public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunction) {
-      return toBuilder().setFormatFunction(formatFunction).build();
+      return toBuilder()
+          .setFormatFunction(SerializableBiFunctions.ignore1st(formatFunction))
+          .build();
     }
 
     /**
@@ -2733,7 +2796,9 @@ public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunct
      */
     public Write<T> withFormatRecordOnFailureFunction(
         SerializableFunction<T, TableRow> formatFunction) {
-      return toBuilder().setFormatRecordOnFailureFunction(formatFunction).build();
+      return toBuilder()
+          .setFormatRecordOnFailureFunction(SerializableBiFunctions.ignore1st(formatFunction))
+          .build();
     }
 
     /**
@@ -3599,9 +3664,10 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) {
     private <DestinationT> WriteResult expandTyped(
         PCollection<T> input, DynamicDestinations<T, DestinationT> dynamicDestinations) {
       boolean optimizeWrites = getOptimizeWrites();
-      SerializableFunction<T, TableRow> formatFunction = getFormatFunction();
-      SerializableFunction<T, TableRow> formatRecordOnFailureFunction =
-          getFormatRecordOnFailureFunction();
+      SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+          formatFunction = getFormatFunction();
+      SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
+          formatRecordOnFailureFunction = getFormatRecordOnFailureFunction();
       RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT> avroRowWriterFactory =
           (RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT>) getAvroRowWriterFactory();
 
@@ -3623,7 +3689,8 @@ private <DestinationT> WriteResult expandTyped(
         // If no format function set, then we will automatically convert the input type to a
         // TableRow.
        // TODO: it would be trivial to convert to avro records here instead.
-        formatFunction = BigQueryUtils.toTableRow(input.getToRowFunction());
+        formatFunction =
+            SerializableBiFunctions.ignore1st(BigQueryUtils.toTableRow(input.getToRowFunction()));
       }
       // Infer the TableSchema from the input Beam schema.
       // TODO: If the user provided a schema, we should use that. There are things that can be
@@ -3769,8 +3836,9 @@ private <DestinationT> WriteResult continueExpandTyped(
                 getCreateDisposition(),
                 dynamicDestinations,
                 elementCoder,
-                tableRowWriterFactory.getToRowFn(),
-                tableRowWriterFactory.getToFailsafeRowFn())
+                SerializableBiFunctions.fix1st(tableRowWriterFactory.getToRowFn(), null),
+                SerializableBiFunctions.fix1st(
+                    tableRowWriterFactory.getToFailsafeRowFn(), null))
             .withInsertRetryPolicy(retryPolicy)
             .withTestServices(getBigQueryServices())
             .withExtendedErrorInfo(getExtendedErrorInfo())
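
The user-facing entry points keep their unary shape; only the stored representation changes to a SerializableBiFunction whose SchemaInformation argument is ignored for user-supplied functions. A caller-side sketch that is unchanged by this commit; the project, dataset, table, and schema below are invented for illustration:

import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import java.util.Collections;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;

public class WriteWithFormatFunctionExample {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    p.apply(Create.of("alice", "bob"))
        .apply(
            BigQueryIO.<String>write()
                .to("my-project:my_dataset.names") // hypothetical destination table
                .withSchema(
                    new TableSchema()
                        .setFields(
                            Collections.singletonList(
                                new TableFieldSchema().setName("name").setType("STRING"))))
                // Still a unary SerializableFunction<T, TableRow>; after this commit the builder
                // wraps it with SerializableBiFunctions.ignore1st(...) internally.
                .withFormatFunction(name -> new TableRow().set("name", name)));
    p.run().waitUntilFinish();
  }
}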

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java

Lines changed: 4 additions & 2 deletions
@@ -55,6 +55,7 @@
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SerializableBiFunction;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
 import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
@@ -641,14 +642,15 @@ public Write<?> fromConfigRow(Row configRow, PipelineOptions options) {
       if (formatFunctionBytes != null) {
         builder =
             builder.setFormatFunction(
-                (SerializableFunction<?, TableRow>) fromByteArray(formatFunctionBytes));
+                (SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, ?, TableRow>)
+                    fromByteArray(formatFunctionBytes));
       }
       byte[] formatRecordOnFailureFunctionBytes =
           configRow.getBytes("format_record_on_failure_function");
       if (formatRecordOnFailureFunctionBytes != null) {
         builder =
             builder.setFormatRecordOnFailureFunction(
-                (SerializableFunction<?, TableRow>)
+                (SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, ?, TableRow>)
                     fromByteArray(formatRecordOnFailureFunctionBytes));
       }
       byte[] avroRowWriterFactoryBytes = configRow.getBytes("avro_row_writer_factory");
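
BigQueryIOTranslation re-hydrates the format functions from serialized bytes, so the only thing that changes here is the target type of the unchecked cast. A standalone sketch of that kind of round trip using plain Java serialization; the toBytes/fromBytes helpers are hypothetical, the translation code uses its own fromByteArray:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import org.apache.beam.sdk.transforms.SerializableBiFunction;

public class RoundTripSketch {
  // Hypothetical helpers standing in for the translation layer's byte[] round trip.
  static byte[] toBytes(Object o) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(o);
    }
    return bos.toByteArray();
  }

  @SuppressWarnings("unchecked")
  static <T> T fromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
      return (T) ois.readObject();
    }
  }

  public static void main(String[] args) throws Exception {
    // A serializable lambda: SerializableBiFunction extends Serializable.
    SerializableBiFunction<String, Integer, String> fn = (prefix, n) -> prefix + ":" + n;
    byte[] bytes = toBytes(fn);
    // The declared target type is all the unchecked cast has to go on, which is why
    // fromConfigRow now casts to SerializableBiFunction instead of SerializableFunction.
    SerializableBiFunction<String, Integer, String> restored = fromBytes(bytes);
    System.out.println(restored.apply("row", 7)); // row:7
  }
}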
