2929import java .io .IOException ;
3030import java .io .Serializable ;
3131import java .lang .reflect .InvocationTargetException ;
32+ import java .time .Duration ;
3233import java .util .ArrayList ;
3334import java .util .Arrays ;
3435import java .util .Collection ;
4849import org .apache .beam .sdk .transforms .splittabledofn .WatermarkEstimator ;
4950import org .apache .beam .sdk .transforms .splittabledofn .WatermarkEstimators ;
5051import org .apache .beam .sdk .transforms .windowing .BoundedWindow ;
52+ import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .MoreObjects ;
53+ import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Stopwatch ;
5154import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableMap ;
5255import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .Lists ;
5356import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .Streams ;
6063import org .apache .kafka .connect .storage .OffsetStorageReader ;
6164import org .checkerframework .checker .nullness .qual .Nullable ;
6265import org .joda .time .DateTime ;
63- import org .joda .time .Duration ;
6466import org .joda .time .Instant ;
6567import org .slf4j .Logger ;
6668import org .slf4j .LoggerFactory ;
public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
  private static final Logger LOG = LoggerFactory.getLogger(KafkaSourceConsumerFn.class);

  /** Property key used to hand this instance's hash to the Debezium database-history callback. */
  public static final String BEAM_INSTANCE_PROPERTY = "beam.parent.instance";

  /** Fallback poll timeout in milliseconds when the Read spec does not configure one. */
  private static final Long DEFAULT_POLLING_TIMEOUT = 1000L;

  // Debezium connector implementation, instantiated reflectively in process().
  private final Class<? extends SourceConnector> connectorClass;
  // Read transform carrying the user configuration (format function, record/time limits,
  // polling timeout).
  private final DebeziumIO.Read<T> spec;
  // Maps Kafka Connect SourceRecords to the output type T.
  private final SourceRecordMapper<T> fn;
  // Poll timeout in milliseconds; falls back to DEFAULT_POLLING_TIMEOUT.
  private final Long pollingTimeOut;

  // Wall-clock start of this DoFn instance, assigned in @Setup; transient because DoFn
  // instances are serialized and the start time must be taken on the worker.
  private transient DateTime startTime;

  // Active restriction trackers keyed by instance hash; shared with the database-history
  // callback via BEAM_INSTANCE_PROPERTY — presumably looked up by getHashCode(), TODO confirm.
  private static final Map<String, RestrictionTracker<OffsetHolder, Map<String, Object>>>
      restrictionTrackers = new ConcurrentHashMap<>();
104- /**
105- * Initializes the SDF with a time limit.
106- *
107- * @param connectorClass Supported Debezium connector class
108- * @param fn a SourceRecordMapper
109- * @param maxRecords Maximum number of records to fetch before finishing.
110- * @param millisecondsToRun Maximum time to run (in milliseconds)
111- */
112- @ SuppressWarnings ("unchecked" )
113- KafkaSourceConsumerFn (
114- Class <?> connectorClass ,
115- SourceRecordMapper <T > fn ,
116- Integer maxRecords ,
117- Long millisecondsToRun ) {
118- this .connectorClass = (Class <? extends SourceConnector >) connectorClass ;
119- this .fn = fn ;
120- this .maxRecords = maxRecords ;
121- this .millisecondsToRun = millisecondsToRun ;
122- }
123-
124106 /**
125107 * Initializes the SDF to be run indefinitely.
126108 *
127109 * @param connectorClass Supported Debezium connector class
128- * @param fn a SourceRecordMapper
129- * @param maxRecords Maximum number of records to fetch before finishing.
110+ * @param spec a DebeziumIO.Read treansform
130111 */
131- KafkaSourceConsumerFn (Class <?> connectorClass , SourceRecordMapper <T > fn , Integer maxRecords ) {
132- this (connectorClass , fn , maxRecords , null );
112+ KafkaSourceConsumerFn (Class <?> connectorClass , DebeziumIO .Read <T > spec ) {
113+ // this(connectorClass, fn, maxRecords, null);
114+ this .connectorClass = (Class <? extends SourceConnector >) connectorClass ;
115+ this .spec = spec ;
116+ this .fn = spec .getFormatFunction ();
117+ this .pollingTimeOut =
118+ MoreObjects .firstNonNull (spec .getPollingTimeout (), DEFAULT_POLLING_TIMEOUT );
133119 }
134120
135121 @ SuppressFBWarnings ("ST_WRITE_TO_STATIC_FROM_INSTANCE_METHOD" )
136122 @ GetInitialRestriction
137123 public OffsetHolder getInitialRestriction (@ Element Map <String , String > unused )
138124 throws IOException {
139- KafkaSourceConsumerFn .startTime = new DateTime ();
140- return new OffsetHolder (null , null , null , this .maxRecords , this .millisecondsToRun );
125+ return new OffsetHolder (null , null , null , spec .getMaxNumberOfRecords (), spec .getMaxTimeToRun ());
141126 }
142127
143128 @ NewTracker
@@ -211,6 +196,11 @@ private static Instant ensureTimestampWithinBounds(Instant timestamp) {
211196 return timestamp ;
212197 }
213198
199+ @ Setup
200+ public void setup () {
201+ startTime = DateTime .now ();
202+ }
203+
214204 /**
215205 * Process the retrieved element and format it for output. Update all pending
216206 *
@@ -222,39 +212,61 @@ private static Instant ensureTimestampWithinBounds(Instant timestamp) {
222212 * continue processing after 1 second. Otherwise, if we've reached a limit of elements, to
223213 * stop processing.
224214 */
225- @ DoFn . ProcessElement
215+ @ ProcessElement
226216 public ProcessContinuation process (
227217 @ Element Map <String , String > element ,
228218 RestrictionTracker <OffsetHolder , Map <String , Object >> tracker ,
229- OutputReceiver <T > receiver )
230- throws Exception {
219+ OutputReceiver <T > receiver ) {
220+
221+ if (spec .getMaxNumberOfRecords () != null
222+ && tracker .currentRestriction ().fetchedRecords != null
223+ && tracker .currentRestriction ().fetchedRecords >= spec .getMaxNumberOfRecords ()) {
224+ return ProcessContinuation .stop ();
225+ }
226+
231227 Map <String , String > configuration = new HashMap <>(element );
232228
233229 // Adding the current restriction to the class object to be found by the database history
234230 register (tracker );
235231 configuration .put (BEAM_INSTANCE_PROPERTY , this .getHashCode ());
236232
237- SourceConnector connector = connectorClass .getDeclaredConstructor ().newInstance ();
238- connector .start (configuration );
239-
240- SourceTask task = (SourceTask ) connector .taskClass ().getDeclaredConstructor ().newInstance ();
233+ SourceConnector connector ;
234+ SourceTask task ;
235+ try {
236+ connector = connectorClass .getDeclaredConstructor ().newInstance ();
237+ connector .start (configuration );
238+ task = (SourceTask ) connector .taskClass ().getDeclaredConstructor ().newInstance ();
239+ } catch (InvocationTargetException
240+ | InstantiationException
241+ | IllegalAccessException
242+ | NoSuchMethodException e ) {
243+ throw new RuntimeException (e );
244+ }
241245
246+ Duration remainingTimeout = Duration .ofMillis (pollingTimeOut );
242247 try {
243248 Map <String , ?> consumerOffset = tracker .currentRestriction ().offset ;
244249 LOG .debug ("--------- Consumer offset from Debezium Tracker: {}" , consumerOffset );
245250
246- task .initialize (new BeamSourceTaskContext (tracker . currentRestriction (). offset ));
251+ task .initialize (new BeamSourceTaskContext (consumerOffset ));
247252 task .start (connector .taskConfigs (1 ).get (0 ));
253+ final Stopwatch pollTimer = Stopwatch .createUnstarted ();
248254
249- List <SourceRecord > records = task .poll ();
255+ while (Duration .ZERO .compareTo (remainingTimeout ) < 0 ) {
256+ pollTimer .reset ().start ();
257+ List <SourceRecord > records = task .poll ();
250258
251- if (records == null ) {
252- LOG .debug ("-------- Pulled records null" );
253- return ProcessContinuation .stop ();
254- }
259+ try {
260+ remainingTimeout = remainingTimeout .minus (pollTimer .elapsed ());
261+ } catch (ArithmeticException e ) {
262+ remainingTimeout = Duration .ZERO ;
263+ }
264+
265+ if (records == null || records .isEmpty ()) {
266+ LOG .debug ("-------- Pulled records null or empty" );
267+ break ;
268+ }
255269
256- LOG .debug ("-------- {} records found" , records .size ());
257- while (records != null && !records .isEmpty ()) {
258270 for (SourceRecord record : records ) {
259271 LOG .debug ("-------- Record found: {}" , record );
260272
@@ -272,7 +284,6 @@ public ProcessContinuation process(
272284 receiver .outputWithTimestamp (json , recordInstant );
273285 }
274286 task .commit ();
275- records = task .poll ();
276287 }
277288 } catch (Exception ex ) {
278289 throw new RuntimeException ("Error occurred when consuming changes from Database. " , ex );
@@ -283,12 +294,14 @@ public ProcessContinuation process(
283294 task .stop ();
284295 }
285296
286- long elapsedTime = System . currentTimeMillis () - KafkaSourceConsumerFn . startTime . getMillis ();
287- if ( millisecondsToRun != null && millisecondsToRun > 0 && elapsedTime >= millisecondsToRun ) {
288- return ProcessContinuation . stop ();
289- } else {
290- return ProcessContinuation . resume (). withResumeDelay ( Duration . standardSeconds ( 1 ));
297+ if ( spec . getMaxTimeToRun () != null && spec . getMaxTimeToRun () > 0 ) {
298+ long elapsedTime = System . currentTimeMillis () - startTime . getMillis ();
299+ if ( elapsedTime >= spec . getMaxTimeToRun ()) {
300+ return ProcessContinuation . stop ();
301+ }
291302 }
303+ return ProcessContinuation .resume ()
304+ .withResumeDelay (org .joda .time .Duration .millis (remainingTimeout .toMillis ()));
292305 }
293306
294307 public String getHashCode () {
@@ -418,41 +431,29 @@ static class OffsetTracker extends RestrictionTracker<OffsetHolder, Map<String,
    /**
     * Overriding {@link #tryClaim} in order to stop fetching records from the database.
     *
     * <p>If a maximum number of records has been set, once that many records have been claimed,
     * it will stop fetching them.
     *
     * @param position Currently not used
     * @return boolean
     */
    @Override
    public boolean tryClaim(Map<String, Object> position) {
      LOG.debug("-------------- Claiming {} used to have: {}", position, restriction.offset);
      // Treat a null counter as zero records claimed so far.
      int fetchedRecords =
          this.restriction.fetchedRecords == null ? 0 : this.restriction.fetchedRecords;
      LOG.debug("------------Fetched records {} / {}", fetchedRecords, this.restriction.maxRecords);
      // NOTE(review): the offset is advanced even when the claim below is rejected — confirm
      // this is intentional for this tracker.
      this.restriction.offset = position;
      LOG.debug("-------------- History: {}", this.restriction.history);

      // If we've reached the maximum number of records, we reject the attempt to claim.
      // Otherwise, we approve the claim.
      boolean claimed =
          (this.restriction.maxRecords == null || fetchedRecords < this.restriction.maxRecords);
      if (claimed) {
        // Count the record only once the claim is approved, so rejected attempts do not
        // inflate the fetched-record total.
        this.restriction.fetchedRecords = fetchedRecords + 1;
      }
      return claimed;
    }
457458
458459 @ Override
0 commit comments