2121
2222import com .google .auto .service .AutoService ;
2323import com .google .auto .value .AutoValue ;
24+ import io .confluent .kafka .serializers .KafkaAvroSerializer ;
2425import java .io .Serializable ;
2526import java .util .Collections ;
2627import java .util .HashMap ;
2728import java .util .List ;
2829import java .util .Map ;
2930import java .util .Set ;
3031import javax .annotation .Nullable ;
32+ import org .apache .avro .generic .GenericRecord ;
3133import org .apache .beam .model .pipeline .v1 .ExternalTransforms ;
34+ import org .apache .beam .sdk .coders .ByteArrayCoder ;
35+ import org .apache .beam .sdk .coders .KvCoder ;
36+ import org .apache .beam .sdk .extensions .avro .coders .AvroCoder ;
3237import org .apache .beam .sdk .extensions .avro .schemas .utils .AvroUtils ;
3338import org .apache .beam .sdk .extensions .protobuf .ProtoByteUtils ;
3439import org .apache .beam .sdk .metrics .Counter ;
@@ -74,6 +79,8 @@ public class KafkaWriteSchemaTransformProvider
7479 public static final TupleTag <Row > ERROR_TAG = new TupleTag <Row >() {};
7580 public static final TupleTag <KV <byte [], byte []>> OUTPUT_TAG =
7681 new TupleTag <KV <byte [], byte []>>() {};
82+ public static final TupleTag <KV <byte [], GenericRecord >> RECORD_OUTPUT_TAG =
83+ new TupleTag <KV <byte [], GenericRecord >>() {};
7784 private static final Logger LOG =
7885 LoggerFactory .getLogger (KafkaWriteSchemaTransformProvider .class );
7986
@@ -118,29 +125,32 @@ Row getConfigurationRow() {
118125 }
119126 }
120127
121- public static class ErrorCounterFn extends DoFn <Row , KV <byte [], byte [] >> {
122- private final SerializableFunction <Row , byte []> toBytesFn ;
128+ public abstract static class BaseKafkaWriterFn < T > extends DoFn <Row , KV <byte [], T >> {
129+ private final SerializableFunction <Row , T > conversionFn ;
123130 private final Counter errorCounter ;
124131 private Long errorsInBundle = 0L ;
125132 private final boolean handleErrors ;
126133 private final Schema errorSchema ;
134+ private final TupleTag <KV <byte [], T >> successTag ;
127135
128- public ErrorCounterFn (
136+ public BaseKafkaWriterFn (
129137 String name ,
130- SerializableFunction <Row , byte []> toBytesFn ,
138+ SerializableFunction <Row , T > conversionFn ,
131139 Schema errorSchema ,
132- boolean handleErrors ) {
133- this .toBytesFn = toBytesFn ;
140+ boolean handleErrors ,
141+ TupleTag <KV <byte [], T >> successTag ) {
142+ this .conversionFn = conversionFn ;
134143 this .errorCounter = Metrics .counter (KafkaWriteSchemaTransformProvider .class , name );
135144 this .handleErrors = handleErrors ;
136145 this .errorSchema = errorSchema ;
146+ this .successTag = successTag ;
137147 }
138148
139149 @ ProcessElement
140150 public void process (@ DoFn .Element Row row , MultiOutputReceiver receiver ) {
141- KV <byte [], byte [] > output = null ;
151+ KV <byte [], T > output = null ;
142152 try {
143- output = KV .of (new byte [1 ], toBytesFn .apply (row ));
153+ output = KV .of (new byte [1 ], conversionFn .apply (row ));
144154 } catch (Exception e ) {
145155 if (!handleErrors ) {
146156 throw new RuntimeException (e );
@@ -150,7 +160,7 @@ public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
150160 receiver .get (ERROR_TAG ).output (ErrorHandling .errorRecord (errorSchema , row , e ));
151161 }
152162 if (output != null ) {
153- receiver .get (OUTPUT_TAG ).output (output );
163+ receiver .get (successTag ).output (output );
154164 }
155165 }
156166
@@ -161,13 +171,35 @@ public void finish() {
161171 }
162172 }
163173
174+ public static class ErrorCounterFn extends BaseKafkaWriterFn <byte []> {
175+ public ErrorCounterFn (
176+ String name ,
177+ SerializableFunction <Row , byte []> toBytesFn ,
178+ Schema errorSchema ,
179+ boolean handleErrors ) {
180+ super (name , toBytesFn , errorSchema , handleErrors , OUTPUT_TAG );
181+ }
182+ }
183+
184+ public static class GenericRecordErrorCounterFn extends BaseKafkaWriterFn <GenericRecord > {
185+ public GenericRecordErrorCounterFn (
186+ String name ,
187+ SerializableFunction <Row , GenericRecord > toGenericRecordsFn ,
188+ Schema errorSchema ,
189+ boolean handleErrors ) {
190+ super (name , toGenericRecordsFn , errorSchema , handleErrors , RECORD_OUTPUT_TAG );
191+ }
192+ }
193+
164194 @ SuppressWarnings ({
165195 "nullness" // TODO(https://github.com/apache/beam/issues/20497)
166196 })
167197 @ Override
168198 public PCollectionRowTuple expand (PCollectionRowTuple input ) {
169199 Schema inputSchema = input .get ("input" ).getSchema ();
200+ org .apache .avro .Schema avroSchema = AvroUtils .toAvroSchema (inputSchema );
170201 final SerializableFunction <Row , byte []> toBytesFn ;
202+ SerializableFunction <Row , GenericRecord > toGenericRecordsFn = null ;
171203 if (configuration .getFormat ().equals ("RAW" )) {
172204 int numFields = inputSchema .getFields ().size ();
173205 if (numFields != 1 ) {
@@ -198,36 +230,70 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
198230 throw new IllegalArgumentException (
199231 "At least a descriptorPath or a proto Schema is required." );
200232 }
201-
202233 } else {
203- toBytesFn = AvroUtils .getRowToAvroBytesFunction (inputSchema );
234+ if (configuration .getProducerConfigUpdates () != null
235+ && configuration .getProducerConfigUpdates ().containsKey ("schema.registry.url" )) {
236+ toGenericRecordsFn = AvroUtils .getRowToGenericRecordFunction (avroSchema );
237+ toBytesFn = null ;
238+ } else {
239+ toBytesFn = AvroUtils .getRowToAvroBytesFunction (inputSchema );
240+ }
204241 }
205242
206243 boolean handleErrors = ErrorHandling .hasOutput (configuration .getErrorHandling ());
207244 final Map <String , String > configOverrides = configuration .getProducerConfigUpdates ();
208245 Schema errorSchema = ErrorHandling .errorSchema (inputSchema );
209- PCollectionTuple outputTuple =
210- input
211- .get ("input" )
212- .apply (
213- "Map rows to Kafka messages" ,
214- ParDo .of (
215- new ErrorCounterFn (
216- "Kafka-write-error-counter" , toBytesFn , errorSchema , handleErrors ))
217- .withOutputTags (OUTPUT_TAG , TupleTagList .of (ERROR_TAG )));
218-
219- outputTuple
220- .get (OUTPUT_TAG )
221- .apply (
222- KafkaIO .<byte [], byte []>write ()
223- .withTopic (configuration .getTopic ())
224- .withBootstrapServers (configuration .getBootstrapServers ())
225- .withProducerConfigUpdates (
226- configOverrides == null
227- ? new HashMap <>()
228- : new HashMap <String , Object >(configOverrides ))
229- .withKeySerializer (ByteArraySerializer .class )
230- .withValueSerializer (ByteArraySerializer .class ));
246+ PCollectionTuple outputTuple ;
247+ if (toGenericRecordsFn != null ) {
248+ LOG .info ("Convert to GenericRecord with schema {}" , avroSchema );
249+ outputTuple =
250+ input
251+ .get ("input" )
252+ .apply (
253+ "Map rows to Kafka messages" ,
254+ ParDo .of (
255+ new GenericRecordErrorCounterFn (
256+ "Kafka-write-error-counter" ,
257+ toGenericRecordsFn ,
258+ errorSchema ,
259+ handleErrors ))
260+ .withOutputTags (RECORD_OUTPUT_TAG , TupleTagList .of (ERROR_TAG )));
261+ HashMap <String , Object > producerConfig = new HashMap <>(configOverrides );
262+ outputTuple
263+ .get (RECORD_OUTPUT_TAG )
264+ .setCoder (KvCoder .of (ByteArrayCoder .of (), AvroCoder .of (avroSchema )))
265+ .apply (
266+ "Map Rows to GenericRecords" ,
267+ KafkaIO .<byte [], GenericRecord >write ()
268+ .withTopic (configuration .getTopic ())
269+ .withBootstrapServers (configuration .getBootstrapServers ())
270+ .withProducerConfigUpdates (producerConfig )
271+ .withKeySerializer (ByteArraySerializer .class )
272+ .withValueSerializer ((Class ) KafkaAvroSerializer .class ));
273+ } else {
274+ outputTuple =
275+ input
276+ .get ("input" )
277+ .apply (
278+ "Map rows to Kafka messages" ,
279+ ParDo .of (
280+ new ErrorCounterFn (
281+ "Kafka-write-error-counter" , toBytesFn , errorSchema , handleErrors ))
282+ .withOutputTags (OUTPUT_TAG , TupleTagList .of (ERROR_TAG )));
283+
284+ outputTuple
285+ .get (OUTPUT_TAG )
286+ .apply (
287+ KafkaIO .<byte [], byte []>write ()
288+ .withTopic (configuration .getTopic ())
289+ .withBootstrapServers (configuration .getBootstrapServers ())
290+ .withProducerConfigUpdates (
291+ configOverrides == null
292+ ? new HashMap <>()
293+ : new HashMap <String , Object >(configOverrides ))
294+ .withKeySerializer (ByteArraySerializer .class )
295+ .withValueSerializer (ByteArraySerializer .class ));
296+ }
231297
232298 // TODO: include output from KafkaIO Write once updated from PDone
233299 PCollection <Row > errorOutput =
0 commit comments