Skip to content

Commit a3cbc4f

Browse files
committed
Ensure watermark updates when position advances
1 parent 275d39a commit a3cbc4f

File tree

2 files changed

+33 additions, -45 deletions

sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java

Lines changed: 30 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
import java.math.BigDecimal;
2323
import java.math.MathContext;
24+
import java.time.Duration;
2425
import java.util.Collections;
2526
import java.util.HashMap;
2627
import java.util.List;
@@ -55,7 +56,6 @@
5556
import org.apache.beam.sdk.values.TupleTag;
5657
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
5758
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
58-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch;
5959
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
6060
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
6161
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
@@ -140,8 +140,8 @@
140140
* {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically
141141
* by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon
142142
* as the {@link TopicPartition} is removed. For example, the removal could happen at the same time
143-
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that
144-
* case, the {@link ReadFromKafkaDoFn} will still output the fetched records.
143+
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the
144+
* {@link ReadFromKafkaDoFn} will still output the fetched records.
145145
*
146146
* <h4>Stop Reading from Stopped {@link TopicPartition}</h4>
147147
*
@@ -199,11 +199,11 @@ private ReadFromKafkaDoFn(
199199
this.checkStopReadingFn = transform.getCheckStopReadingFn();
200200
this.badRecordRouter = transform.getBadRecordRouter();
201201
this.recordTag = recordTag;
202-
if (transform.getConsumerPollingTimeout() > 0) {
203-
this.consumerPollingTimeout = transform.getConsumerPollingTimeout();
204-
} else {
205-
this.consumerPollingTimeout = DEFAULT_KAFKA_POLL_TIMEOUT;
206-
}
202+
this.consumerPollingTimeout =
203+
Duration.ofSeconds(
204+
transform.getConsumerPollingTimeout() > 0
205+
? transform.getConsumerPollingTimeout()
206+
: DEFAULT_KAFKA_POLL_TIMEOUT);
207207
}
208208

209209
private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
@@ -248,7 +248,7 @@ private static final class SharedStateHolder {
248248

249249
private transient @Nullable LoadingCache<KafkaSourceDescriptor, MovingAvg> avgRecordSizeCache;
250250
private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;
251-
@VisibleForTesting final long consumerPollingTimeout;
251+
@VisibleForTesting final Duration consumerPollingTimeout;
252252
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
253253
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
254254
@VisibleForTesting final Map<String, Object> consumerConfig;
@@ -443,15 +443,17 @@ public ProcessContinuation processElement(
443443
consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
444444
long expectedOffset = tracker.currentRestriction().getFrom();
445445
consumer.seek(kafkaSourceDescriptor.getTopicPartition(), expectedOffset);
446-
ConsumerRecords<byte[], byte[]> rawRecords = ConsumerRecords.empty();
447446

448447
while (true) {
449448
// Fetch the record size accumulator.
450449
final MovingAvg avgRecordSize = avgRecordSizeCache.getUnchecked(kafkaSourceDescriptor);
451-
rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition());
452-
// When there are no records available for the current TopicPartition, self-checkpoint
453-
// and move to process the next element.
454-
if (rawRecords.isEmpty()) {
450+
// Fetch the next records.
451+
final ConsumerRecords<byte[], byte[]> rawRecords =
452+
consumer.poll(this.consumerPollingTimeout);
453+
454+
// No progress: the polling timeout expired (poll handed back the ConsumerRecords.empty() instance itself, hence the identity check).
455+
// Self-checkpoint and move to process the next element.
456+
if (rawRecords == ConsumerRecords.<byte[], byte[]>empty()) {
455457
if (!topicPartitionExists(
456458
kafkaSourceDescriptor.getTopicPartition(),
457459
consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) {
@@ -462,6 +464,9 @@ public ProcessContinuation processElement(
462464
}
463465
return ProcessContinuation.resume();
464466
}
467+
468+
// Visible progress within the consumer polling timeout.
469+
// Partially or fully claim and process records in this batch.
465470
for (ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
466471
if (!tracker.tryClaim(rawRecord.offset())) {
467472
return ProcessContinuation.stop();
@@ -512,6 +517,17 @@ public ProcessContinuation processElement(
512517
}
513518
}
514519

520+
// Non-visible progress within the consumer polling timeout.
521+
// Claim up to the current position.
522+
if (expectedOffset < (expectedOffset = consumer.position(topicPartition))) {
523+
if (!tracker.tryClaim(expectedOffset - 1)) {
524+
return ProcessContinuation.stop();
525+
}
526+
if (timestampPolicy != null) {
527+
updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
528+
}
529+
}
530+
515531
backlogBytes.set(
516532
(long)
517533
(BigDecimal.valueOf(
@@ -531,34 +547,6 @@ private boolean topicPartitionExists(
531547
.anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition()));
532548
}
533549

534-
// see https://github.com/apache/beam/issues/25962
535-
private ConsumerRecords<byte[], byte[]> poll(
536-
Consumer<byte[], byte[]> consumer, TopicPartition topicPartition) {
537-
final Stopwatch sw = Stopwatch.createStarted();
538-
long previousPosition = -1;
539-
java.time.Duration elapsed = java.time.Duration.ZERO;
540-
java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout);
541-
while (true) {
542-
final ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(timeout.minus(elapsed));
543-
if (!rawRecords.isEmpty()) {
544-
// return as we have found some entries
545-
return rawRecords;
546-
}
547-
if (previousPosition == (previousPosition = consumer.position(topicPartition))) {
548-
// there was no progress on the offset/position, which indicates end of stream
549-
return rawRecords;
550-
}
551-
elapsed = sw.elapsed();
552-
if (elapsed.toMillis() >= timeout.toMillis()) {
553-
// timeout is over
554-
LOG.warn(
555-
"No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.",
556-
consumerPollingTimeout);
557-
return rawRecords;
558-
}
559-
}
560-
}
561-
562550
private TimestampPolicyContext updateWatermarkManually(
563551
TimestampPolicy<K, V> timestampPolicy,
564552
WatermarkEstimator<Instant> watermarkEstimator,

sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,14 +715,14 @@ public void testUnbounded() {
715715
@Test
716716
public void testConstructorWithPollTimeout() {
717717
ReadSourceDescriptors<String, String> descriptors = makeReadSourceDescriptor(consumer);
718-
// default poll timeout = 1 scond
718+
// default poll timeout = 2 seconds
719719
ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
720-
Assert.assertEquals(2L, dofnInstance.consumerPollingTimeout);
720+
Assert.assertEquals(Duration.ofSeconds(2L), dofnInstance.consumerPollingTimeout);
721721
// updated timeout = 5 seconds
722722
descriptors = descriptors.withConsumerPollingTimeout(5L);
723723
ReadFromKafkaDoFn<String, String> dofnInstanceNew =
724724
ReadFromKafkaDoFn.create(descriptors, RECORDS);
725-
Assert.assertEquals(5L, dofnInstanceNew.consumerPollingTimeout);
725+
Assert.assertEquals(Duration.ofSeconds(5L), dofnInstanceNew.consumerPollingTimeout);
726726
}
727727

728728
private BoundednessVisitor testBoundedness(

0 commit comments

Comments
 (0)