
Commit 359c9eb

Introduce Schema Registry Functionality to Managed KafkaIO Write. (#35644)

* Add testing for Managed Schema Registry support
* Add testing that runs on Dataflow
* Clean up test and use Apache Beam testing resources
* Spotless
* Trigger GitHub Actions. No code changes
* Add test to read and write with Managed KafkaIO using SR
* Testing the Write transform
* Add @Ignore for faster testing. WILL REMOVE BEFORE MERGE.
* Finish the Write schema transform provider and add tests
* Refactor the Write class to use a generic method for the conversion function
* Add extra logging and clean up variable names to address comments.

1 parent 60630af commit 359c9eb
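
In practice, this change routes a Managed KafkaIO write through Confluent's KafkaAvroSerializer whenever the producer config contains a "schema.registry.url" entry. The sketch below shows how a pipeline author might exercise the new path; the snake_case config keys are assumed to mirror the provider's configuration fields, and the topic, broker, and registry URL are placeholders, not values taken from this commit.

// Hypothetical usage sketch; config keys and endpoint values are assumptions.
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

class ManagedKafkaAvroWriteExample {
  static void write(PCollection<Row> rows) {
    Map<String, Object> config = new HashMap<>();
    config.put("topic", "my-topic"); // placeholder topic
    config.put("bootstrap_servers", "broker-1:9092"); // placeholder broker
    config.put("format", "AVRO");
    // Supplying "schema.registry.url" is what flips the write onto the new
    // GenericRecord + KafkaAvroSerializer path introduced by this commit.
    config.put(
        "producer_config_updates",
        Collections.singletonMap("schema.registry.url", "https://my-registry.example.com"));
    PCollectionRowTuple.of("input", rows).apply(Managed.write(Managed.KAFKA).withConfig(config));
  }
}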

File tree

4 files changed: +146 −34 lines

sdks/java/io/kafka/build.gradle
sdks/java/io/kafka/kafka-integration-test.gradle
sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java

sdks/java/io/kafka/build.gradle

Lines changed: 3 additions & 0 deletions

@@ -74,6 +74,9 @@ dependencies {
   implementation (group: 'com.google.cloud.hosted.kafka', name: 'managed-kafka-auth-login-handler', version: '1.0.5') {
     // "kafka-clients" has to be provided since user can use its own version.
     exclude group: 'org.apache.kafka', module: 'kafka-clients'
+    // "kafka-schema-registry-client" must be excluded per the Google Cloud documentation:
+    // https://cloud.google.com/managed-service-for-apache-kafka/docs/quickstart-avro#configure_and_run_the_producer
+    exclude group: "io.confluent", module: "kafka-schema-registry-client"
   }
   implementation ("io.confluent:kafka-avro-serializer:${confluentVersion}") {
     // zookeeper depends on "spotbugs-annotations:3.1.9" which clashes with current

sdks/java/io/kafka/kafka-integration-test.gradle

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ dependencies {
   // instead, rely on io/kafka/build.gradle's custom configurations with forced kafka-client resolutionStrategy
   testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
   testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
+  testImplementation library.java.avro
 }
 
 configurations.create("kafkaVersion$undelimited")

sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java

Lines changed: 99 additions & 33 deletions

@@ -21,14 +21,19 @@
 
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
+import io.confluent.kafka.serializers.KafkaAvroSerializer;
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.avro.generic.GenericRecord;
 import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
 import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.metrics.Counter;
@@ -74,6 +79,8 @@ public class KafkaWriteSchemaTransformProvider
   public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
   public static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
       new TupleTag<KV<byte[], byte[]>>() {};
+  public static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+      new TupleTag<KV<byte[], GenericRecord>>() {};
   private static final Logger LOG =
       LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class);
 
@@ -118,29 +125,32 @@ Row getConfigurationRow() {
     }
   }
 
-  public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
-    private final SerializableFunction<Row, byte[]> toBytesFn;
+  public abstract static class BaseKafkaWriterFn<T> extends DoFn<Row, KV<byte[], T>> {
+    private final SerializableFunction<Row, T> conversionFn;
     private final Counter errorCounter;
     private Long errorsInBundle = 0L;
     private final boolean handleErrors;
     private final Schema errorSchema;
+    private final TupleTag<KV<byte[], T>> successTag;
 
-    public ErrorCounterFn(
+    public BaseKafkaWriterFn(
         String name,
-        SerializableFunction<Row, byte[]> toBytesFn,
+        SerializableFunction<Row, T> conversionFn,
         Schema errorSchema,
-        boolean handleErrors) {
-      this.toBytesFn = toBytesFn;
+        boolean handleErrors,
+        TupleTag<KV<byte[], T>> successTag) {
+      this.conversionFn = conversionFn;
       this.errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name);
       this.handleErrors = handleErrors;
      this.errorSchema = errorSchema;
+      this.successTag = successTag;
     }
 
     @ProcessElement
     public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
-      KV<byte[], byte[]> output = null;
+      KV<byte[], T> output = null;
       try {
-        output = KV.of(new byte[1], toBytesFn.apply(row));
+        output = KV.of(new byte[1], conversionFn.apply(row));
       } catch (Exception e) {
         if (!handleErrors) {
           throw new RuntimeException(e);
@@ -150,7 +160,7 @@ public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
         receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e));
       }
       if (output != null) {
-        receiver.get(OUTPUT_TAG).output(output);
+        receiver.get(successTag).output(output);
       }
     }
 
@@ -161,13 +171,35 @@ public void finish() {
     }
   }
 
+  public static class ErrorCounterFn extends BaseKafkaWriterFn<byte[]> {
+    public ErrorCounterFn(
+        String name,
+        SerializableFunction<Row, byte[]> toBytesFn,
+        Schema errorSchema,
+        boolean handleErrors) {
+      super(name, toBytesFn, errorSchema, handleErrors, OUTPUT_TAG);
+    }
+  }
+
+  public static class GenericRecordErrorCounterFn extends BaseKafkaWriterFn<GenericRecord> {
+    public GenericRecordErrorCounterFn(
+        String name,
+        SerializableFunction<Row, GenericRecord> toGenericRecordsFn,
+        Schema errorSchema,
+        boolean handleErrors) {
+      super(name, toGenericRecordsFn, errorSchema, handleErrors, RECORD_OUTPUT_TAG);
+    }
+  }
+
   @SuppressWarnings({
     "nullness" // TODO(https://github.com/apache/beam/issues/20497)
   })
   @Override
   public PCollectionRowTuple expand(PCollectionRowTuple input) {
     Schema inputSchema = input.get("input").getSchema();
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(inputSchema);
     final SerializableFunction<Row, byte[]> toBytesFn;
+    SerializableFunction<Row, GenericRecord> toGenericRecordsFn = null;
     if (configuration.getFormat().equals("RAW")) {
       int numFields = inputSchema.getFields().size();
       if (numFields != 1) {
@@ -198,36 +230,70 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
         throw new IllegalArgumentException(
             "At least a descriptorPath or a proto Schema is required.");
       }
-
     } else {
-      toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+      if (configuration.getProducerConfigUpdates() != null
+          && configuration.getProducerConfigUpdates().containsKey("schema.registry.url")) {
+        toGenericRecordsFn = AvroUtils.getRowToGenericRecordFunction(avroSchema);
+        toBytesFn = null;
+      } else {
+        toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
+      }
     }
 
     boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
     final Map<String, String> configOverrides = configuration.getProducerConfigUpdates();
     Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
-    PCollectionTuple outputTuple =
-        input
-            .get("input")
-            .apply(
-                "Map rows to Kafka messages",
-                ParDo.of(
-                        new ErrorCounterFn(
-                            "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors))
-                    .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
-    outputTuple
-        .get(OUTPUT_TAG)
-        .apply(
-            KafkaIO.<byte[], byte[]>write()
-                .withTopic(configuration.getTopic())
-                .withBootstrapServers(configuration.getBootstrapServers())
-                .withProducerConfigUpdates(
-                    configOverrides == null
-                        ? new HashMap<>()
-                        : new HashMap<String, Object>(configOverrides))
-                .withKeySerializer(ByteArraySerializer.class)
-                .withValueSerializer(ByteArraySerializer.class));
+    PCollectionTuple outputTuple;
+    if (toGenericRecordsFn != null) {
+      LOG.info("Convert to GenericRecord with schema {}", avroSchema);
+      outputTuple =
+          input
+              .get("input")
+              .apply(
+                  "Map rows to Kafka messages",
+                  ParDo.of(
+                          new GenericRecordErrorCounterFn(
+                              "Kafka-write-error-counter",
+                              toGenericRecordsFn,
+                              errorSchema,
+                              handleErrors))
+                      .withOutputTags(RECORD_OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+      HashMap<String, Object> producerConfig = new HashMap<>(configOverrides);
+      outputTuple
+          .get(RECORD_OUTPUT_TAG)
+          .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)))
+          .apply(
+              "Map Rows to GenericRecords",
+              KafkaIO.<byte[], GenericRecord>write()
+                  .withTopic(configuration.getTopic())
+                  .withBootstrapServers(configuration.getBootstrapServers())
+                  .withProducerConfigUpdates(producerConfig)
+                  .withKeySerializer(ByteArraySerializer.class)
+                  .withValueSerializer((Class) KafkaAvroSerializer.class));
+    } else {
+      outputTuple =
+          input
+              .get("input")
+              .apply(
+                  "Map rows to Kafka messages",
+                  ParDo.of(
+                          new ErrorCounterFn(
+                              "Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors))
+                      .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+      outputTuple
+          .get(OUTPUT_TAG)
+          .apply(
+              KafkaIO.<byte[], byte[]>write()
+                  .withTopic(configuration.getTopic())
+                  .withBootstrapServers(configuration.getBootstrapServers())
+                  .withProducerConfigUpdates(
+                      configOverrides == null
+                          ? new HashMap<>()
+                          : new HashMap<String, Object>(configOverrides))
+                  .withKeySerializer(ByteArraySerializer.class)
+                  .withValueSerializer(ByteArraySerializer.class));
+    }
 
     // TODO: include output from KafkaIO Write once updated from PDone
     PCollection<Row> errorOutput =
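
For orientation, the GenericRecord branch above delegates serialization to Confluent's KafkaAvroSerializer, which registers the record's Avro schema with the registry and prefixes each payload with the assigned schema ID. A rough standalone equivalent with a plain Kafka producer is sketched below; the broker and registry addresses, the topic, and the 'record' variable are placeholders, not part of this commit.

// Approximate plain-Kafka equivalent of the new write path (placeholders marked).
import java.util.Properties;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;

class SchemaRegistryProducerSketch {
  static void send(GenericRecord record) {
    Properties props = new Properties();
    props.put("bootstrap.servers", "broker-1:9092"); // placeholder broker
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    // KafkaAvroSerializer registers the record's Avro schema with the registry
    // and prepends the resulting schema ID to each serialized payload.
    props.put("value.serializer", "io.confluent.kafka.serializers.KafkaAvroSerializer");
    props.put("schema.registry.url", "https://my-registry.example.com"); // placeholder registry
    try (KafkaProducer<byte[], GenericRecord> producer = new KafkaProducer<>(props)) {
      producer.send(new ProducerRecord<>("my-topic", new byte[1], record));
    }
  }
}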

sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java

Lines changed: 43 additions & 1 deletion

@@ -24,9 +24,16 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
+import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn;
+import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.GenericRecordErrorCounterFn;
 import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
@@ -53,6 +60,8 @@ public class KafkaWriteSchemaTransformProviderTest {
 
   private static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
       KafkaWriteSchemaTransformProvider.OUTPUT_TAG;
+  private static final TupleTag<KV<byte[], GenericRecord>> RECORD_OUTPUT_TAG =
+      KafkaWriteSchemaTransformProvider.RECORD_OUTPUT_TAG;
   private static final TupleTag<Row> ERROR_TAG = KafkaWriteSchemaTransformProvider.ERROR_TAG;
 
   private static final Schema BEAMSCHEMA =
@@ -126,7 +135,8 @@ public class KafkaWriteSchemaTransformProviderTest {
               getClass().getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
           .getPath(),
       "MyMessage");
-
+  final SerializableFunction<Row, GenericRecord> recordValueMapper =
+      AvroUtils.getRowToGenericRecordFunction(AvroUtils.toAvroSchema(BEAMSCHEMA));
   @Rule public transient TestPipeline p = TestPipeline.create();
 
   @Test
@@ -198,6 +208,38 @@ public void testKafkaErrorFnProtoSuccess() {
           + " bool active = 3;\n"
           + "}";
 
+  @Test
+  public void testKafkaRecordErrorFnSuccess() throws Exception {
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(BEAMSCHEMA);
+
+    GenericRecord record1 = new GenericData.Record(avroSchema);
+    GenericRecord record2 = new GenericData.Record(avroSchema);
+    GenericRecord record3 = new GenericData.Record(avroSchema);
+    record1.put("name", "a");
+    record2.put("name", "b");
+    record3.put("name", "c");
+
+    List<KV<byte[], GenericRecord>> msg =
+        Arrays.asList(
+            KV.of(new byte[1], record1), KV.of(new byte[1], record2), KV.of(new byte[1], record3));
+
+    PCollection<Row> input = p.apply(Create.of(ROWS));
+    Schema errorSchema = ErrorHandling.errorSchema(BEAMSCHEMA);
+    PCollectionTuple output =
+        input.apply(
+            ParDo.of(
+                    new GenericRecordErrorCounterFn(
+                        "Kafka-write-error-counter", recordValueMapper, errorSchema, true))
+                .withOutputTags(RECORD_OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+    output.get(ERROR_TAG).setRowSchema(errorSchema);
+    output
+        .get(RECORD_OUTPUT_TAG)
+        .setCoder(KvCoder.of(ByteArrayCoder.of(), AvroCoder.of(avroSchema)));
+    PAssert.that(output.get(RECORD_OUTPUT_TAG)).containsInAnyOrder(msg);
+    p.run().waitUntilFinish();
+  }
+
   @Test
   public void testBuildTransformWithManaged() {
     List<String> configs =
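
The new test leans on the BEAMSCHEMA and ROWS fixtures defined earlier in the file, which fall outside this diff. Given that the test populates only a "name" field, a compatible definition would presumably look like the following; treat it as an assumption rather than the committed code.

// Hypothetical shape of the fixtures referenced by the new test; the committed
// definitions live outside this diff.
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

class FixtureSketch {
  static final Schema BEAMSCHEMA = Schema.of(Schema.Field.of("name", Schema.FieldType.STRING));
  static final List<Row> ROWS =
      Arrays.asList(
          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "a").build(),
          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "b").build(),
          Row.withSchema(BEAMSCHEMA).withFieldValue("name", "c").build());
}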
