|
39 | 39 | import com.google.protobuf.Descriptors.FieldDescriptor; |
40 | 40 | import com.google.protobuf.Descriptors.FileDescriptor; |
41 | 41 | import com.google.protobuf.DynamicMessage; |
| 42 | +import com.google.protobuf.InvalidProtocolBufferException; |
42 | 43 | import com.google.protobuf.Message; |
43 | 44 | import java.math.BigDecimal; |
44 | 45 | import java.math.BigInteger; |
|
65 | 66 | import java.util.stream.StreamSupport; |
66 | 67 | import org.apache.beam.sdk.util.Preconditions; |
67 | 68 | import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; |
| 69 | +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Functions; |
| 70 | +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates; |
68 | 71 | import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; |
69 | 72 | import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; |
70 | 73 | import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; |
@@ -868,6 +871,157 @@ private static void fieldDescriptorFromTableField( |
868 | 871 | descriptorBuilder.addField(fieldDescriptorBuilder.build()); |
869 | 872 | } |
870 | 873 |
|
| 874 | + /** |
| 875 | + * mergeNewFields(original, newFields) unlike proto merge or concatenating proto bytes is merging |
| 876 | + * the main differences is skipping primitive fields that are already set and merging structs and |
| 877 | + * lists recursively. Method mutates input. |
| 878 | + * |
| 879 | + * @param original original table row |
| 880 | + * @param newRow |
| 881 | + * @return merged table row |
| 882 | + */ |
| 883 | + private static TableRow mergeNewFields(TableRow original, TableRow newRow) { |
| 884 | + if (original == null) { |
| 885 | + return newRow; |
| 886 | + } |
| 887 | + if (newRow == null) { |
| 888 | + return original; |
| 889 | + } |
| 890 | + |
| 891 | + for (Map.Entry<String, Object> entry : newRow.entrySet()) { |
| 892 | + String key = entry.getKey(); |
| 893 | + Object value2 = entry.getValue(); |
| 894 | + Object value1 = original.get(key); |
| 895 | + |
| 896 | + if (value1 == null) { |
| 897 | + original.set(key, value2); |
| 898 | + } else { |
| 899 | + if (value1 instanceof List && value2 instanceof List) { |
| 900 | + List<?> list1 = (List<?>) value1; |
| 901 | + List<?> list2 = (List<?>) value2; |
| 902 | + if (!list1.isEmpty() |
| 903 | + && list1.get(0) instanceof TableRow |
| 904 | + && !list2.isEmpty() |
| 905 | + && list2.get(0) instanceof TableRow) { |
| 906 | + original.set(key, mergeRepeatedStructs((List<TableRow>) list1, (List<TableRow>) list2)); |
| 907 | + } else { |
| 908 | + // primitive lists |
| 909 | + original.set(key, value2); |
| 910 | + } |
| 911 | + } else if (value1 instanceof TableRow && value2 instanceof TableRow) { |
| 912 | + original.set(key, mergeNewFields((TableRow) value1, (TableRow) value2)); |
| 913 | + } |
| 914 | + } |
| 915 | + } |
| 916 | + |
| 917 | + return original; |
| 918 | + } |
| 919 | + |
| 920 | + private static List<TableRow> mergeRepeatedStructs(List<TableRow> list1, List<TableRow> list2) { |
| 921 | + List<TableRow> mergedList = new ArrayList<>(); |
| 922 | + int length = Math.min(list1.size(), list2.size()); |
| 923 | + |
| 924 | + for (int i = 0; i < length; i++) { |
| 925 | + TableRow orig = (i < list1.size()) ? list1.get(i) : null; |
| 926 | + TableRow delta = (i < list2.size()) ? list2.get(i) : null; |
| 927 | + // fail if any is shorter |
| 928 | + Preconditions.checkArgumentNotNull(orig); |
| 929 | + Preconditions.checkArgumentNotNull(delta); |
| 930 | + |
| 931 | + mergedList.add(mergeNewFields(orig, delta)); |
| 932 | + } |
| 933 | + return mergedList; |
| 934 | + } |
| 935 | + |
| 936 | + public static ByteString mergeNewFields( |
| 937 | + ByteString tableRowProto, |
| 938 | + DescriptorProtos.DescriptorProto descriptorProto, |
| 939 | + TableSchema tableSchema, |
| 940 | + SchemaInformation schemaInformation, |
| 941 | + TableRow unknownFields, |
| 942 | + boolean ignoreUnknownValues) |
| 943 | + throws TableRowToStorageApiProto.SchemaConversionException { |
| 944 | + if (unknownFields == null || unknownFields.isEmpty()) { |
| 945 | + // nothing to do here |
| 946 | + return tableRowProto; |
| 947 | + } |
| 948 | + // check if unknownFields contains repeated struct, merge |
| 949 | + boolean hasRepeatedStruct = |
| 950 | + unknownFields.entrySet().stream() |
| 951 | + .anyMatch( |
| 952 | + entry -> |
| 953 | + entry.getValue() instanceof List |
| 954 | + && !((List<?>) entry.getValue()).isEmpty() |
| 955 | + && ((List<?>) entry.getValue()).get(0) instanceof TableRow); |
| 956 | + if (!hasRepeatedStruct) { |
| 957 | + Descriptor descriptorIgnoreRequired = null; |
| 958 | + try { |
| 959 | + descriptorIgnoreRequired = |
| 960 | + TableRowToStorageApiProto.getDescriptorFromTableSchema(tableSchema, false, false); |
| 961 | + } catch (DescriptorValidationException e) { |
| 962 | + throw new RuntimeException(e); |
| 963 | + } |
| 964 | + ByteString unknownFieldsProto = |
| 965 | + messageFromTableRow( |
| 966 | + schemaInformation, |
| 967 | + descriptorIgnoreRequired, |
| 968 | + unknownFields, |
| 969 | + ignoreUnknownValues, |
| 970 | + true, |
| 971 | + null, |
| 972 | + null, |
| 973 | + null) |
| 974 | + .toByteString(); |
| 975 | + return tableRowProto.concat(unknownFieldsProto); |
| 976 | + } |
| 977 | + |
| 978 | + DynamicMessage message = null; |
| 979 | + Descriptor descriptor = null; |
| 980 | + try { |
| 981 | + descriptor = wrapDescriptorProto(descriptorProto); |
| 982 | + } catch (DescriptorValidationException e) { |
| 983 | + throw new RuntimeException(e); |
| 984 | + } |
| 985 | + try { |
| 986 | + message = DynamicMessage.parseFrom(descriptor, tableRowProto); |
| 987 | + } catch (InvalidProtocolBufferException e) { |
| 988 | + throw new RuntimeException(e); |
| 989 | + } |
| 990 | + TableRow original = |
| 991 | + TableRowToStorageApiProto.tableRowFromMessage(message, true, Predicates.alwaysTrue()); |
| 992 | + Map<String, Descriptors.FieldDescriptor> fieldDescriptors = |
| 993 | + descriptor.getFields().stream() |
| 994 | + .collect(Collectors.toMap(Descriptors.FieldDescriptor::getName, Functions.identity())); |
| 995 | + // recover cdc data |
| 996 | + String cdcType = null; |
| 997 | + String sequence = null; |
| 998 | + if (fieldDescriptors.get(StorageApiCDC.CHANGE_TYPE_COLUMN) != null |
| 999 | + && fieldDescriptors.get(StorageApiCDC.CHANGE_SQN_COLUMN) != null) { |
| 1000 | + cdcType = |
| 1001 | + (String) |
| 1002 | + message.getField( |
| 1003 | + Preconditions.checkStateNotNull( |
| 1004 | + fieldDescriptors.get(StorageApiCDC.CHANGE_TYPE_COLUMN))); |
| 1005 | + sequence = |
| 1006 | + (String) |
| 1007 | + message.getField( |
| 1008 | + Preconditions.checkStateNotNull( |
| 1009 | + fieldDescriptors.get(StorageApiCDC.CHANGE_SQN_COLUMN))); |
| 1010 | + } |
| 1011 | + TableRow merged = TableRowToStorageApiProto.mergeNewFields(original, unknownFields); |
| 1012 | + DynamicMessage dynamicMessage = |
| 1013 | + TableRowToStorageApiProto.messageFromTableRow( |
| 1014 | + schemaInformation, |
| 1015 | + descriptor, |
| 1016 | + merged, |
| 1017 | + ignoreUnknownValues, |
| 1018 | + false, |
| 1019 | + null, |
| 1020 | + cdcType, |
| 1021 | + sequence); |
| 1022 | + return dynamicMessage.toByteString(); |
| 1023 | + } |
| 1024 | + |
871 | 1025 | private static @Nullable Object messageValueFromFieldValue( |
872 | 1026 | SchemaInformation schemaInformation, |
873 | 1027 | FieldDescriptor fieldDescriptor, |
|
0 commit comments