Skip to content

Commit f2c600c

Browse files
authored
[DataStreamToSpanner] Improvement: Adding group by key step before Spanner write step (#3249)
* Adding group by key step before Spanner write step * Renaming class name * Added unit tests * Correcting UTs * Correcting a bug * Fixing a bug * Correcting a bug * Adding more Unit test cases * Addressing comments
1 parent 6125ad5 commit f2c600c

File tree

7 files changed

+504
-27
lines changed

7 files changed

+504
-27
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (C) 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
package com.google.cloud.teleport.v2.templates;
17+
18+
import static com.google.cloud.teleport.v2.templates.constants.DatastreamToSpannerConstants.CONVERSION_ERRORS_COUNTER_NAME;
19+
20+
import com.fasterxml.jackson.databind.DeserializationFeature;
21+
import com.fasterxml.jackson.databind.JsonNode;
22+
import com.fasterxml.jackson.databind.ObjectMapper;
23+
import com.google.cloud.teleport.v2.spanner.ddl.Ddl;
24+
import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventSpannerConvertor;
25+
import com.google.cloud.teleport.v2.templates.constants.DatastreamToSpannerConstants;
26+
import com.google.cloud.teleport.v2.templates.datastream.ChangeEventConvertor;
27+
import com.google.cloud.teleport.v2.templates.datastream.DatastreamConstants;
28+
import com.google.cloud.teleport.v2.values.FailsafeElement;
29+
import org.apache.beam.sdk.metrics.Counter;
30+
import org.apache.beam.sdk.metrics.Metrics;
31+
import org.apache.beam.sdk.transforms.DoFn;
32+
import org.apache.beam.sdk.values.KV;
33+
import org.apache.beam.sdk.values.PCollectionView;
34+
import org.slf4j.Logger;
35+
import org.slf4j.LoggerFactory;
36+
37+
public class CreateKeyValuePairsWithPrimaryKeyHashDoFn
38+
extends DoFn<FailsafeElement<String, String>, KV<Long, FailsafeElement<String, String>>> {
39+
40+
private static final Logger LOG =
41+
LoggerFactory.getLogger(CreateKeyValuePairsWithPrimaryKeyHashDoFn.class);
42+
43+
private final PCollectionView<Ddl> ddlView;
44+
45+
// Jackson Object mapper.
46+
private transient ObjectMapper mapper;
47+
48+
private final Counter conversionErrors =
49+
Metrics.counter(SpannerTransactionWriterDoFn.class, CONVERSION_ERRORS_COUNTER_NAME);
50+
51+
public CreateKeyValuePairsWithPrimaryKeyHashDoFn(PCollectionView<Ddl> ddlView) {
52+
this.ddlView = ddlView;
53+
}
54+
55+
@Setup
56+
public void setup() {
57+
mapper = new ObjectMapper();
58+
mapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
59+
}
60+
61+
@ProcessElement
62+
public void processElement(ProcessContext c) {
63+
FailsafeElement<String, String> msg = c.element();
64+
String tableName = "";
65+
try {
66+
// msg.getPayload() contains transformed change event and hence msg.getPayload() should be
67+
// used and not msg.getOriginalPayload()
68+
JsonNode changeEvent = mapper.readTree(msg.getPayload());
69+
Ddl ddl = c.sideInput(ddlView);
70+
71+
tableName = changeEvent.get(DatastreamConstants.EVENT_TABLE_NAME_KEY).asText();
72+
ChangeEventConvertor.convertChangeEventColumnKeysToLowerCase(changeEvent);
73+
ChangeEventConvertor.verifySpannerSchema(ddl, changeEvent);
74+
com.google.cloud.spanner.Key primaryKey =
75+
ChangeEventSpannerConvertor.changeEventToPrimaryKey(
76+
tableName, ddl, changeEvent, /* convertNameToLowerCase= */ true);
77+
String finalKeyString = tableName + "_" + primaryKey.toString();
78+
Long finalKey = (long) finalKeyString.hashCode();
79+
c.output(KV.of(finalKey, msg));
80+
} catch (Exception e) {
81+
LOG.error(
82+
"Error while converting change event to primary key hash for tableName=" + tableName, e);
83+
// Errors that result during Event conversions are not retryable.
84+
// Making a copy, as the input must not be mutated.
85+
FailsafeElement<String, String> output = FailsafeElement.of(msg);
86+
output.setErrorMessage(e.getMessage());
87+
c.output(DatastreamToSpannerConstants.PERMANENT_ERROR_TAG, output);
88+
conversionErrors.inc();
89+
}
90+
}
91+
}

v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SpannerTransactionWriter.java

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,26 @@
1717

1818
import com.google.auto.value.AutoValue;
1919
import com.google.cloud.Timestamp;
20+
import com.google.cloud.teleport.v2.coders.FailsafeElementCoder;
2021
import com.google.cloud.teleport.v2.spanner.ddl.Ddl;
2122
import com.google.cloud.teleport.v2.templates.constants.DatastreamToSpannerConstants;
2223
import com.google.cloud.teleport.v2.values.FailsafeElement;
2324
import com.google.common.base.Preconditions;
2425
import com.google.common.collect.ImmutableMap;
2526
import java.util.Arrays;
27+
import java.util.List;
2628
import java.util.Map;
2729
import org.apache.beam.sdk.Pipeline;
30+
import org.apache.beam.sdk.coders.KvCoder;
31+
import org.apache.beam.sdk.coders.StringUtf8Coder;
32+
import org.apache.beam.sdk.coders.VarLongCoder;
2833
import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig;
34+
import org.apache.beam.sdk.transforms.Flatten;
2935
import org.apache.beam.sdk.transforms.PTransform;
3036
import org.apache.beam.sdk.transforms.ParDo;
37+
import org.apache.beam.sdk.transforms.Reshuffle;
3138
import org.apache.beam.sdk.values.PCollection;
39+
import org.apache.beam.sdk.values.PCollectionList;
3240
import org.apache.beam.sdk.values.PCollectionTuple;
3341
import org.apache.beam.sdk.values.PCollectionView;
3442
import org.apache.beam.sdk.values.PInput;
@@ -93,29 +101,53 @@ public SpannerTransactionWriter(
93101
@Override
94102
public SpannerTransactionWriter.Result expand(
95103
PCollection<FailsafeElement<String, String>> input) {
96-
PCollectionTuple spannerWriteResults =
104+
PCollectionTuple keyedEvents =
97105
input.apply(
98-
"Write Mutations",
99-
ParDo.of(
100-
new SpannerTransactionWriterDoFn(
101-
spannerConfig,
102-
shadowTableSpannerConfig,
103-
ddlView,
104-
shadowTableDdlView,
105-
shadowTablePrefix,
106-
sourceType,
107-
isRegularRunMode))
108-
.withSideInputs(ddlView, shadowTableDdlView)
106+
"Key By PK Hash",
107+
ParDo.of(new CreateKeyValuePairsWithPrimaryKeyHashDoFn(ddlView))
108+
.withSideInputs(ddlView)
109109
.withOutputTags(
110-
DatastreamToSpannerConstants.SUCCESSFUL_EVENT_TAG,
111-
TupleTagList.of(
112-
Arrays.asList(
113-
DatastreamToSpannerConstants.PERMANENT_ERROR_TAG,
114-
DatastreamToSpannerConstants.RETRYABLE_ERROR_TAG))));
110+
DatastreamToSpannerConstants.SUCCESSFUL_KEYED_EVENT_TAG,
111+
TupleTagList.of(List.of(DatastreamToSpannerConstants.PERMANENT_ERROR_TAG))));
112+
PCollectionTuple spannerWriteResults =
113+
keyedEvents
114+
.get(DatastreamToSpannerConstants.SUCCESSFUL_KEYED_EVENT_TAG)
115+
.setCoder(
116+
KvCoder.of(
117+
VarLongCoder.of(),
118+
FailsafeElementCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())))
119+
.apply("Reshuffle Keyed Events", Reshuffle.of())
120+
.apply(
121+
"Write Mutations",
122+
ParDo.of(
123+
new SpannerTransactionWriterDoFn(
124+
spannerConfig,
125+
shadowTableSpannerConfig,
126+
ddlView,
127+
shadowTableDdlView,
128+
shadowTablePrefix,
129+
sourceType,
130+
isRegularRunMode))
131+
.withSideInputs(ddlView, shadowTableDdlView)
132+
.withOutputTags(
133+
DatastreamToSpannerConstants.SUCCESSFUL_EVENT_TAG,
134+
TupleTagList.of(
135+
Arrays.asList(
136+
DatastreamToSpannerConstants.PERMANENT_ERROR_TAG,
137+
DatastreamToSpannerConstants.RETRYABLE_ERROR_TAG))));
138+
139+
PCollection<FailsafeElement<String, String>> keyedEventsErrorRecords =
140+
keyedEvents.get(DatastreamToSpannerConstants.PERMANENT_ERROR_TAG);
141+
142+
PCollection<FailsafeElement<String, String>> permanentErrorRecords =
143+
PCollectionList.of(
144+
spannerWriteResults.get(DatastreamToSpannerConstants.PERMANENT_ERROR_TAG))
145+
.and(keyedEventsErrorRecords)
146+
.apply(Flatten.pCollections());
115147

116148
return Result.create(
117149
spannerWriteResults.get(DatastreamToSpannerConstants.SUCCESSFUL_EVENT_TAG),
118-
spannerWriteResults.get(DatastreamToSpannerConstants.PERMANENT_ERROR_TAG),
150+
permanentErrorRecords,
119151
spannerWriteResults.get(DatastreamToSpannerConstants.RETRYABLE_ERROR_TAG));
120152
}
121153

v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/SpannerTransactionWriterDoFn.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.apache.beam.sdk.metrics.Metrics;
5858
import org.apache.beam.sdk.options.PipelineOptions;
5959
import org.apache.beam.sdk.transforms.DoFn;
60+
import org.apache.beam.sdk.values.KV;
6061
import org.apache.beam.sdk.values.PCollectionView;
6162
import org.apache.beam.sdk.values.TupleTag;
6263
import org.joda.time.Duration;
@@ -77,8 +78,8 @@
7778
* <p>Change events that failed to be written will be pushed onto the secondary output tagged with
7879
* PERMANENT_ERROR_TAG/RETRYABLE_ERROR_TAG along with the exception that caused the failure.
7980
*/
80-
class SpannerTransactionWriterDoFn extends DoFn<FailsafeElement<String, String>, Timestamp>
81-
implements Serializable {
81+
class SpannerTransactionWriterDoFn
82+
extends DoFn<KV<Long, FailsafeElement<String, String>>, Timestamp> implements Serializable {
8283

8384
// TODO - Change Cloud Spanner nomenclature in code used to read DDL.
8485

@@ -239,7 +240,7 @@ public void teardown() {
239240

240241
@ProcessElement
241242
public void processElement(ProcessContext c) {
242-
FailsafeElement<String, String> msg = c.element();
243+
FailsafeElement<String, String> msg = c.element().getValue();
243244
Ddl ddl = c.sideInput(ddlView);
244245
// TODO: pass shadow table ddl to shdaow tble mutaiton generator and sequence reader.
245246
Ddl shadowTableDdl = c.sideInput(shadowTableDdlView);

v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/constants/DatastreamToSpannerConstants.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.google.cloud.Timestamp;
1919
import com.google.cloud.teleport.v2.values.FailsafeElement;
20+
import org.apache.beam.sdk.values.KV;
2021
import org.apache.beam.sdk.values.TupleTag;
2122

2223
/** Class to maintain all the constants used in the pipeline. */
@@ -42,6 +43,10 @@ public class DatastreamToSpannerConstants {
4243
/* The Tag for Successful mutations. */
4344
public static final TupleTag<Timestamp> SUCCESSFUL_EVENT_TAG = new TupleTag<Timestamp>() {};
4445

46+
/* The tag for successfully keyed events. */
47+
public static final TupleTag<KV<Long, FailsafeElement<String, String>>>
48+
SUCCESSFUL_KEYED_EVENT_TAG = new TupleTag<KV<Long, FailsafeElement<String, String>>>() {};
49+
4550
/* Max DoFns per dataflow worker in a streaming pipeline. */
4651
public static final int MAX_DOFN_PER_WORKER = 500;
4752

v2/datastream-to-spanner/src/main/java/com/google/cloud/teleport/v2/templates/datastream/ChangeEventConvertor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class ChangeEventConvertor {
3737

3838
private ChangeEventConvertor() {}
3939

40-
static void verifySpannerSchema(Ddl ddl, JsonNode changeEvent)
40+
public static void verifySpannerSchema(Ddl ddl, JsonNode changeEvent)
4141
throws ChangeEventConvertorException, InvalidChangeEventException, DroppedTableException {
4242
String tableName = changeEvent.get(DatastreamConstants.EVENT_TABLE_NAME_KEY).asText();
4343
if (ddl.table(tableName) == null) {
@@ -66,7 +66,7 @@ static void verifySpannerSchema(Ddl ddl, JsonNode changeEvent)
6666
}
6767
}
6868

69-
static void convertChangeEventColumnKeysToLowerCase(JsonNode changeEvent)
69+
public static void convertChangeEventColumnKeysToLowerCase(JsonNode changeEvent)
7070
throws ChangeEventConvertorException, InvalidChangeEventException {
7171
List<String> changeEventKeys = ChangeEventUtils.getEventColumnKeys(changeEvent);
7272
ObjectNode jsonNode = (ObjectNode) changeEvent;

0 commit comments

Comments
 (0)