
Commit 96e79cb

Concat protos in BQStorageWriteAPI - solve edge cases during merging of nested repeated fields (#34436)
* concat unknown fields to proto - solve edge cases
* refactoring
* spotless
1 parent b687870 commit 96e79cb
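Context for the change: the Storage Write API sinks used to append "unknown" fields by concatenating serialized protos, which (per the TODO removed below) works for top-level primitive fields but misbehaves when the unknown fields live inside nested or repeated records. The standalone sketch below (a hypothetical MergeSemanticsSketch class, not part of this commit) is a simplified restatement of the element-wise merge semantics the commit introduces for TableRow values.

import com.google.api.services.bigquery.model.TableRow;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class MergeSemanticsSketch {

  // Simplified restatement of the rules added in TableRowToStorageApiProto.mergeNewFields:
  // keep primitives that are already set, recurse into structs, and merge repeated structs
  // element-wise instead of appending new elements.
  static TableRow merge(TableRow original, TableRow newRow) {
    for (Map.Entry<String, Object> entry : newRow.entrySet()) {
      Object existing = original.get(entry.getKey());
      Object incoming = entry.getValue();
      if (existing == null) {
        original.set(entry.getKey(), incoming); // brand-new field: just add it
      } else if (existing instanceof TableRow && incoming instanceof TableRow) {
        merge((TableRow) existing, (TableRow) incoming); // nested struct: recurse
      } else if (existing instanceof List && incoming instanceof List) {
        List<?> a = (List<?>) existing;
        List<?> b = (List<?>) incoming;
        if (!a.isEmpty()
            && a.get(0) instanceof TableRow
            && !b.isEmpty()
            && b.get(0) instanceof TableRow) {
          // repeated struct: merge the i-th elements with each other
          List<TableRow> out = new ArrayList<>();
          for (int i = 0; i < Math.min(a.size(), b.size()); i++) {
            out.add(merge((TableRow) a.get(i), (TableRow) b.get(i)));
          }
          original.set(entry.getKey(), out);
        } else {
          original.set(entry.getKey(), incoming); // primitive list: take the new value
        }
      }
      // primitive scalar already set: keep the original value
    }
    return original;
  }

  public static void main(String[] args) {
    // One repeated struct element already encoded into the proto payload ...
    TableRow known =
        new TableRow().set("id", "r1").set("items", Arrays.asList(new TableRow().set("a", 1)));
    // ... and an "unknown fields" row that adds a sub-field to that same element.
    TableRow unknown = new TableRow().set("items", Arrays.asList(new TableRow().set("b", 2)));

    // items[0] now carries both "a" and "b"; plain proto-byte concatenation would instead
    // have appended a second element to "items".
    System.out.println(merge(known, unknown));
  }
}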


6 files changed: +317, -19 lines changed


sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java

Lines changed: 12 additions & 0 deletions
@@ -167,6 +167,18 @@ Descriptors.Descriptor getDescriptorIgnoreRequired() {
     }
   }
 
+  public ByteString mergeNewFields(
+      ByteString payloadBytes, TableRow unknownFields, boolean ignoreUnknownValues)
+      throws TableRowToStorageApiProto.SchemaConversionException {
+    return TableRowToStorageApiProto.mergeNewFields(
+        payloadBytes,
+        getDescriptor(),
+        getTableSchema(),
+        getSchemaInformation(),
+        unknownFields,
+        ignoreUnknownValues);
+  }
+
   public TableRow toTableRow(ByteString protoBytes, Predicate<String> includeField) {
     try {
       return TableRowToStorageApiProto.tableRowFromMessage(

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java

Lines changed: 7 additions & 12 deletions
@@ -49,38 +49,34 @@ abstract static class Value {
     abstract List<@Nullable TableRow> getFailsafeTableRows();
   }
 
-  interface ConvertUnknownFields {
-    ByteString convert(TableRow tableRow, boolean ignoreUnknownValues)
+  interface ConcatFields {
+    ByteString concat(ByteString bytes, TableRow tableRows)
         throws TableRowToStorageApiProto.SchemaConversionException;
   }
 
   private final Iterable<StorageApiWritePayload> underlying;
   private final long splitSize;
 
-  private final ConvertUnknownFields unknownFieldsToMessage;
+  private final ConcatFields concatProtoAndTableRow;
   private final Function<ByteString, TableRow> protoToTableRow;
   private final BiConsumer<TimestampedValue<TableRow>, String> failedRowsConsumer;
   private final boolean autoUpdateSchema;
-  private final boolean ignoreUnknownValues;
-
   private final Instant elementsTimestamp;
 
   public SplittingIterable(
       Iterable<StorageApiWritePayload> underlying,
       long splitSize,
-      ConvertUnknownFields unknownFieldsToMessage,
+      ConcatFields concatProtoAndTableRow,
      Function<ByteString, TableRow> protoToTableRow,
      BiConsumer<TimestampedValue<TableRow>, String> failedRowsConsumer,
      boolean autoUpdateSchema,
-      boolean ignoreUnknownValues,
      Instant elementsTimestamp) {
    this.underlying = underlying;
    this.splitSize = splitSize;
-    this.unknownFieldsToMessage = unknownFieldsToMessage;
+    this.concatProtoAndTableRow = concatProtoAndTableRow;
    this.protoToTableRow = protoToTableRow;
    this.failedRowsConsumer = failedRowsConsumer;
    this.autoUpdateSchema = autoUpdateSchema;
-    this.ignoreUnknownValues = ignoreUnknownValues;
    this.elementsTimestamp = elementsTimestamp;
  }

@@ -128,10 +124,9 @@ public Value next() {
            // Protocol buffer serialization format supports concatenation. We serialize any new
            // "known" fields
            // into a proto and concatenate to the existing proto.
+
            try {
-              byteString =
-                  byteString.concat(
-                      unknownFieldsToMessage.convert(unknownFields, ignoreUnknownValues));
+              byteString = concatProtoAndTableRow.concat(byteString, unknownFields);
            } catch (TableRowToStorageApiProto.SchemaConversionException e) {
              // This generally implies that ignoreUnknownValues=false and there were still
              // unknown values here.
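The comment kept in this hunk rests on a protocol buffer property: parsing the concatenation of two serialized messages behaves like a merge of the two. The short sketch below (a hypothetical ConcatSketch class, using the well-known ListValue message rather than Beam's generated descriptors) shows why that is sufficient for appending new fields but not for enriching an element of an existing repeated struct: concatenation appends repeated elements rather than merging them.

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ListValue;
import com.google.protobuf.Value;

public class ConcatSketch {
  public static void main(String[] args) throws InvalidProtocolBufferException {
    // Two separately serialized messages that each populate the same repeated field.
    ByteString first =
        ListValue.newBuilder()
            .addValues(Value.newBuilder().setStringValue("a").build())
            .build()
            .toByteString();
    ByteString second =
        ListValue.newBuilder()
            .addValues(Value.newBuilder().setStringValue("b").build())
            .build()
            .toByteString();

    // Parsing the concatenated bytes merges the two messages: the repeated field ends up
    // with both elements ["a", "b"]. Concatenation can only append new repeated elements;
    // it cannot add sub-fields to an element that is already present.
    ListValue merged = ListValue.parseFrom(first.concat(second));
    System.out.println(merged.getValuesCount()); // prints 2
  }
}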

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java

Lines changed: 4 additions & 5 deletions
@@ -585,13 +585,12 @@ void addMessage(
          }
          @Nullable TableRow unknownFields = payload.getUnknownFields();
          if (unknownFields != null && !unknownFields.isEmpty()) {
+            // check if unknownFields contains repeated struct, merge
+            // otherwise use concat
            try {
-              // TODO(34145, radoslaws): concat will work for unknownFields that are primitive type,
-              // will cause issues with nested and repeated fields
              payloadBytes =
-                  payloadBytes.concat(
-                      Preconditions.checkStateNotNull(appendClientInfo)
-                          .encodeUnknownFields(unknownFields, ignoreUnknownValues));
+                  Preconditions.checkStateNotNull(appendClientInfo)
+                      .mergeNewFields(payloadBytes, unknownFields, ignoreUnknownValues);
            } catch (TableRowToStorageApiProto.SchemaConversionException e) {
              @Nullable TableRow tableRow = payload.getFailsafeTableRow();
              if (tableRow == null) {

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java

Lines changed: 2 additions & 2 deletions
@@ -613,7 +613,8 @@ public void process(
            new SplittingIterable(
                element.getValue(),
                splitSize,
-                (fields, ignore) -> appendClientInfo.get().encodeUnknownFields(fields, ignore),
+                (bytes, tableRow) ->
+                    appendClientInfo.get().mergeNewFields(bytes, tableRow, ignoreUnknownValues),
                bytes -> appendClientInfo.get().toTableRow(bytes, Predicates.alwaysTrue()),
                (failedRow, errorMessage) -> {
                  o.get(failedRowsTag)
@@ -628,7 +629,6 @@ public void process(
                      .inc(1);
                },
                autoUpdateSchema,
-                ignoreUnknownValues,
                elementTs);
 
        // Initialize stream names and offsets for all contexts. This will be called initially, but

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java

Lines changed: 154 additions & 0 deletions
@@ -39,6 +39,7 @@
 import com.google.protobuf.Descriptors.FieldDescriptor;
 import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.DynamicMessage;
+import com.google.protobuf.InvalidProtocolBufferException;
 import com.google.protobuf.Message;
 import java.math.BigDecimal;
 import java.math.BigInteger;
@@ -65,6 +66,8 @@
 import java.util.stream.StreamSupport;
 import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Functions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
@@ -868,6 +871,157 @@ private static void fieldDescriptorFromTableField(
     descriptorBuilder.addField(fieldDescriptorBuilder.build());
   }
 
+  /**
+   * mergeNewFields(original, newRow), unlike a proto merge or concatenating proto bytes, skips
+   * primitive fields that are already set and merges structs and lists recursively. The method
+   * mutates its input.
+   *
+   * @param original original table row
+   * @param newRow table row carrying the new fields
+   * @return merged table row
+   */
+  private static TableRow mergeNewFields(TableRow original, TableRow newRow) {
+    if (original == null) {
+      return newRow;
+    }
+    if (newRow == null) {
+      return original;
+    }
+
+    for (Map.Entry<String, Object> entry : newRow.entrySet()) {
+      String key = entry.getKey();
+      Object value2 = entry.getValue();
+      Object value1 = original.get(key);
+
+      if (value1 == null) {
+        original.set(key, value2);
+      } else {
+        if (value1 instanceof List && value2 instanceof List) {
+          List<?> list1 = (List<?>) value1;
+          List<?> list2 = (List<?>) value2;
+          if (!list1.isEmpty()
+              && list1.get(0) instanceof TableRow
+              && !list2.isEmpty()
+              && list2.get(0) instanceof TableRow) {
+            original.set(key, mergeRepeatedStructs((List<TableRow>) list1, (List<TableRow>) list2));
+          } else {
+            // primitive lists
+            original.set(key, value2);
+          }
+        } else if (value1 instanceof TableRow && value2 instanceof TableRow) {
+          original.set(key, mergeNewFields((TableRow) value1, (TableRow) value2));
+        }
+      }
+    }
+
+    return original;
+  }
+
+  private static List<TableRow> mergeRepeatedStructs(List<TableRow> list1, List<TableRow> list2) {
+    List<TableRow> mergedList = new ArrayList<>();
+    int length = Math.min(list1.size(), list2.size());
+
+    for (int i = 0; i < length; i++) {
+      TableRow orig = (i < list1.size()) ? list1.get(i) : null;
+      TableRow delta = (i < list2.size()) ? list2.get(i) : null;
+      // fail if either list is shorter
+      Preconditions.checkArgumentNotNull(orig);
+      Preconditions.checkArgumentNotNull(delta);
+
+      mergedList.add(mergeNewFields(orig, delta));
+    }
+    return mergedList;
+  }
+
+  public static ByteString mergeNewFields(
+      ByteString tableRowProto,
+      DescriptorProtos.DescriptorProto descriptorProto,
+      TableSchema tableSchema,
+      SchemaInformation schemaInformation,
+      TableRow unknownFields,
+      boolean ignoreUnknownValues)
+      throws TableRowToStorageApiProto.SchemaConversionException {
+    if (unknownFields == null || unknownFields.isEmpty()) {
+      // nothing to do here
+      return tableRowProto;
+    }
+    // check if unknownFields contains a repeated struct; if so, merge instead of concatenating
+    boolean hasRepeatedStruct =
+        unknownFields.entrySet().stream()
+            .anyMatch(
+                entry ->
+                    entry.getValue() instanceof List
+                        && !((List<?>) entry.getValue()).isEmpty()
+                        && ((List<?>) entry.getValue()).get(0) instanceof TableRow);
+    if (!hasRepeatedStruct) {
+      Descriptor descriptorIgnoreRequired = null;
+      try {
+        descriptorIgnoreRequired =
+            TableRowToStorageApiProto.getDescriptorFromTableSchema(tableSchema, false, false);
+      } catch (DescriptorValidationException e) {
+        throw new RuntimeException(e);
+      }
+      ByteString unknownFieldsProto =
+          messageFromTableRow(
+                  schemaInformation,
+                  descriptorIgnoreRequired,
+                  unknownFields,
+                  ignoreUnknownValues,
+                  true,
+                  null,
+                  null,
+                  null)
+              .toByteString();
+      return tableRowProto.concat(unknownFieldsProto);
+    }
+
+    DynamicMessage message = null;
+    Descriptor descriptor = null;
+    try {
+      descriptor = wrapDescriptorProto(descriptorProto);
+    } catch (DescriptorValidationException e) {
+      throw new RuntimeException(e);
+    }
+    try {
+      message = DynamicMessage.parseFrom(descriptor, tableRowProto);
+    } catch (InvalidProtocolBufferException e) {
+      throw new RuntimeException(e);
+    }
+    TableRow original =
+        TableRowToStorageApiProto.tableRowFromMessage(message, true, Predicates.alwaysTrue());
+    Map<String, Descriptors.FieldDescriptor> fieldDescriptors =
+        descriptor.getFields().stream()
+            .collect(Collectors.toMap(Descriptors.FieldDescriptor::getName, Functions.identity()));
+    // recover cdc data
+    String cdcType = null;
+    String sequence = null;
+    if (fieldDescriptors.get(StorageApiCDC.CHANGE_TYPE_COLUMN) != null
+        && fieldDescriptors.get(StorageApiCDC.CHANGE_SQN_COLUMN) != null) {
+      cdcType =
+          (String)
+              message.getField(
+                  Preconditions.checkStateNotNull(
+                      fieldDescriptors.get(StorageApiCDC.CHANGE_TYPE_COLUMN)));
+      sequence =
+          (String)
+              message.getField(
+                  Preconditions.checkStateNotNull(
+                      fieldDescriptors.get(StorageApiCDC.CHANGE_SQN_COLUMN)));
+    }
+    TableRow merged = TableRowToStorageApiProto.mergeNewFields(original, unknownFields);
+    DynamicMessage dynamicMessage =
+        TableRowToStorageApiProto.messageFromTableRow(
+            schemaInformation,
+            descriptor,
+            merged,
+            ignoreUnknownValues,
+            false,
+            null,
+            cdcType,
+            sequence);
+    return dynamicMessage.toByteString();
+  }
+
   private static @Nullable Object messageValueFromFieldValue(
       SchemaInformation schemaInformation,
       FieldDescriptor fieldDescriptor,
