diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java index d62d294ed2a7..0f00598983fc 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.pubsub; import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.client.util.Clock; @@ -90,6 +91,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -834,6 +836,9 @@ public abstract static class Read extends PTransform> /** The name of the message attribute to read unique message IDs from. */ abstract @Nullable String getIdAttribute(); + /** The maximum time to read from Pub/Sub. If not specified, will read indefinitely. */ + abstract @Nullable Duration getMaxReadTime(); + /** The coder used to decode each record. */ abstract Coder getCoder(); @@ -896,6 +901,8 @@ abstract static class Builder { abstract Builder setIdAttribute(String idAttribute); + abstract Builder setMaxReadTime(@Nullable Duration maxReadTime); + abstract Builder setCoder(Coder coder); abstract Builder setParseFn(SerializableFunction parseFn); @@ -1079,6 +1086,17 @@ public Read withIdAttribute(String idAttribute) { return toBuilder().setIdAttribute(idAttribute).build(); } + /** + * Sets a maximum amount of time to read from the source. + * + *

If this is set, the source will be bounded and will stop reading after this much time has + * passed. + */ + public Read withMaxReadTime(Duration maxReadTime) { + checkArgument(maxReadTime != null, "maxReadTime can not be null"); + return toBuilder().setMaxReadTime(maxReadTime).build(); + } + /** * Causes the source to return a PubsubMessage that includes Pubsub attributes, and uses the * given parsing function to transform the PubsubMessage into an output type. A Coder for the @@ -1155,7 +1173,8 @@ public PCollection expand(PBegin input) { getIdAttribute(), getNeedsAttributes(), getNeedsMessageId(), - getNeedsOrderingKey()); + getNeedsOrderingKey(), + getMaxReadTime()); PCollection preParse = input.apply(source); return expandReadContinued(preParse, topicPath, subscriptionPath); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java index 6e665baaf6b1..78103d0fe7d7 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformConfiguration.java @@ -98,6 +98,10 @@ public abstract class PubsubReadSchemaTransformConfiguration { // Used for testing only. public abstract @Nullable Clock getClock(); + @SchemaFieldDescription( + "The maximum time to read from Pub/Sub, in seconds. If not specified, will read indefinitely.") + public abstract @Nullable Long getMaxReadTimeSeconds(); + @AutoValue public abstract static class ErrorHandling { @SchemaFieldDescription("The name of the output PCollection containing failed reads.") @@ -146,6 +150,8 @@ public abstract static class Builder { // Used for testing only. public abstract Builder setClock(@Nullable Clock clock); + public abstract Builder setMaxReadTimeSeconds(@Nullable Long maxReadTimeSeconds); + public abstract PubsubReadSchemaTransformConfiguration build(); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java index 8a628817fe27..36a76d67393c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java @@ -44,6 +44,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -271,6 +272,11 @@ PubsubIO.Read buildPubsubRead() { if (!Strings.isNullOrEmpty(configuration.getTimestampAttribute())) { pubsubRead = pubsubRead.withTimestampAttribute(configuration.getTimestampAttribute()); } + if (configuration.getMaxReadTimeSeconds() != null) { + pubsubRead = + pubsubRead.withMaxReadTime( + Duration.standardSeconds(configuration.getMaxReadTimeSeconds())); + } return pubsubRead; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java index 22fcaae20cad..f179233ce87a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java @@ -555,14 +555,8 @@ public void acknowledge(SubscriptionPath subscription, List ackIds) thro STATE.expectedSubscription); for (String ackId : ackIds) { - checkState( - STATE.ackDeadline.remove(ackId) != null, - "No message with ACK id %s is waiting for an ACK", - ackId); - checkState( - STATE.pendingAckIncomingMessages.remove(ackId) != null, - "No message with ACK id %s is waiting for an ACK", - ackId); + STATE.ackDeadline.remove(ackId); + STATE.pendingAckIncomingMessages.remove(ackId); } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java index b9a554d54ade..e8d0236da3b6 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSource.java @@ -512,6 +512,15 @@ public InFlightState(long requestTimeMsSinceEpoch, long ackDeadlineMsSinceEpoch) /** Stats only: Maximum number of checkpoints in flight at any time. */ private int maxInFlightCheckpoints; + /** Stats only: Maximum read time before ending process. */ + private @Nullable Duration maxReadTime; + + /** Stats only: Start time of reading process from source. */ + private long startTime; + + /** Stats only: Process is finished or not. */ + private boolean done; + private static MovingFunction newFun(Combine.BinaryCombineLongFn function) { return new MovingFunction( SAMPLE_PERIOD.getMillis(), @@ -559,6 +568,9 @@ public PubsubReader(PubsubOptions options, PubsubSource outer, SubscriptionPath numLateMessages = newFun(SUM); numInFlightCheckpoints = new AtomicInteger(); maxInFlightCheckpoints = 0; + maxReadTime = outer.outer.maxReadTime; + startTime = now(); + done = false; } @VisibleForTesting @@ -839,6 +851,15 @@ public boolean start() throws IOException { */ @Override public boolean advance() throws IOException { + if (done && notYetRead.isEmpty()) { + return false; + } + + long now = now(); + if (maxReadTime != null && now > startTime + maxReadTime.getMillis()) { + done = true; + } + // Emit stats. stats(); @@ -862,7 +883,9 @@ public boolean advance() throws IOException { if (notYetRead.isEmpty()) { // Pull another batch. // Will BLOCK until fetch returns, but will not block until a message is available. - pull(); + if (maxReadTime == null || now() <= startTime + maxReadTime.getMillis()) { + pull(); + } } // Take one message from queue. @@ -950,7 +973,7 @@ public PubsubSource getCurrentSource() { @Override public Instant getWatermark() { - if (pubsubClient.get().isEOF() && notYetRead.isEmpty()) { + if (done || (pubsubClient.get().isEOF() && notYetRead.isEmpty())) { // For testing only: Advance the watermark to the end of time to signal // the test is complete. return BoundedWindow.TIMESTAMP_MAX_VALUE; @@ -1203,6 +1226,9 @@ public void populateDisplayData(DisplayData.Builder builder) { /** Whether this source should include the orderingKey from PubSub. */ private final boolean needsOrderingKey; + /** The maximum time to read from Pub/Sub. If not specified, will read indefinitely. */ + private final @Nullable Duration maxReadTime; + @VisibleForTesting PubsubUnboundedSource( Clock clock, @@ -1214,7 +1240,8 @@ public void populateDisplayData(DisplayData.Builder builder) { @Nullable String idAttribute, boolean needsAttributes, boolean needsMessageId, - boolean needsOrderingKey) { + boolean needsOrderingKey, + @Nullable Duration maxReadTime) { checkArgument( (topic == null) != (subscription == null), "Exactly one of topic and subscription must be given"); @@ -1228,6 +1255,7 @@ public void populateDisplayData(DisplayData.Builder builder) { this.needsAttributes = needsAttributes; this.needsMessageId = needsMessageId; this.needsOrderingKey = needsOrderingKey; + this.maxReadTime = maxReadTime; } /** Construct an unbounded source to consume from the Pubsub {@code subscription}. */ @@ -1249,7 +1277,8 @@ public PubsubUnboundedSource( idAttribute, needsAttributes, false, - false); + false, + null); } /** Construct an unbounded source to consume from the Pubsub {@code subscription}. */ @@ -1272,7 +1301,8 @@ public PubsubUnboundedSource( idAttribute, needsAttributes, false, - false); + false, + null); } /** Construct an unbounded source to consume from the Pubsub {@code subscription}. */ @@ -1295,7 +1325,8 @@ public PubsubUnboundedSource( idAttribute, needsAttributes, needsMessageId, - false); + false, + null); } /** Get the project path. */ diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java index 3d9c65aa1376..853c8089d9ba 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java @@ -87,6 +87,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.After; import org.junit.Rule; @@ -989,6 +990,45 @@ public void testReadWithoutValidation() throws IOException { read.validate(options); } + @Test + public void testReadWithMaxReadTime() { + // Create 1000 messages + List messages = + IntStream.range(0, 1000) + .mapToObj( + i -> + IncomingMessage.of( + com.google.pubsub.v1.PubsubMessage.newBuilder() + .setData(ByteString.copyFromUtf8("message-" + i)) + .build(), + 1234L, + 0, + UUID.randomUUID().toString(), + UUID.randomUUID().toString())) + .collect(Collectors.toList()); + + // Create Pubsub client factory + clientFactory = PubsubTestClient.createFactoryForPull(CLOCK, SUBSCRIPTION, 60, messages); + + // Read messages + PCollection read = + pipeline.apply( + PubsubIO.readStrings() + .fromSubscription(SUBSCRIPTION.getPath()) + .withMaxReadTime(Duration.standardSeconds(2)) + .withClock(CLOCK) + .withClientFactory(clientFactory)); + + // Check that some messages are read and appropriately stops. + PAssert.that(read) + .satisfies( + input -> { + assertThat(input.iterator().hasNext(), is(true)); + return null; + }); + pipeline.run(); + } + @Test public void testWriteTopicValidationSuccess() throws Exception { PubsubIO.writeStrings().to("projects/my-project/topics/abc"); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java index 98aade888a33..49ab5b54a7ae 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProviderTest.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.extensions.avro.schemas.io.payloads.AvroPayloadSerializerProvider; @@ -46,18 +47,30 @@ import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.AfterWatermark; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Tests for {@link org.apache.beam.sdk.io.gcp.pubsub.PubsubReadSchemaTransformProvider}. */ @RunWith(JUnit4.class) public class PubsubReadSchemaTransformProviderTest { + private static final Logger LOG = + LoggerFactory.getLogger(PubsubReadSchemaTransformProviderTest.class); private static final Schema BEAM_SCHEMA = Schema.of( @@ -319,6 +332,156 @@ public void testReadAvroWithError() throws IOException { } } + @Test + public void testReadWithMaxReadTimeDoesNotExpire() throws IOException { + PCollectionRowTuple begin = PCollectionRowTuple.empty(p); + TestClock clock = TestClock.create(Instant.ofEpochMilli(1678988970000L)); + Long maxReadTime = 5L; + Long advanceSeconds = 1L; + + try (PubsubTestClientFactory clientFactory = + PubsubTestClient.createFactoryForPull( + clock, + PubsubClient.subscriptionPathFromPath(SUBSCRIPTION), + 60, + // Provide one message to trigger the read. + ImmutableList.of(incomingMessageOf(new byte[] {1}, clock.now().getMillis())))) { + PubsubReadSchemaTransformConfiguration config = + PubsubReadSchemaTransformConfiguration.builder() + .setFormat("RAW") + .setSchema("") + .setSubscription(SUBSCRIPTION) + .setClientFactory(clientFactory) + .setClock(clock) + .setMaxReadTimeSeconds(maxReadTime) // Shorter time + .build(); + SchemaTransform transform = new PubsubReadSchemaTransformProvider().from(config); + PCollection reads = begin.apply(transform).get("output"); + + // This DoFn advances the clock when the first message is received. + // This ensures the maxReadTime (2 seconds) expires. + PCollection delayedReads = + reads.apply( + "AdvanceClock", ParDo.of(new AdvanceClockFn(clock, maxReadTime, advanceSeconds))); + delayedReads.setRowSchema(reads.getSchema()); + PCollection windowed = + delayedReads.apply( + "Window", + Window.into(new GlobalWindows()) + .triggering(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + PCollection count = windowed.apply(org.apache.beam.sdk.transforms.Count.globally()); + // We expect to process no messages. + PAssert.that(count).containsInAnyOrder(1L); + + p.run().waitUntilFinish(); + } catch (Exception e) { + throw e; + } + } + + @Test + public void testReadWithMaxReadTimeExpires() throws IOException { + PCollectionRowTuple begin = PCollectionRowTuple.empty(p); + TestClock clock = TestClock.create(Instant.ofEpochMilli(1678988970000L)); + Long maxReadTime = 2L; + Long advanceSeconds = 10L; + + try (PubsubTestClientFactory clientFactory = + PubsubTestClient.createFactoryForPull( + clock, + PubsubClient.subscriptionPathFromPath(SUBSCRIPTION), + 60, + // Provide one message to trigger the read. + ImmutableList.of(incomingMessageOf(new byte[] {1}, clock.now().getMillis())))) { + PubsubReadSchemaTransformConfiguration config = + PubsubReadSchemaTransformConfiguration.builder() + .setFormat("RAW") + .setSchema("") + .setSubscription(SUBSCRIPTION) + .setClientFactory(clientFactory) + .setClock(clock) + .setMaxReadTimeSeconds(maxReadTime) // Shorter time + .build(); + SchemaTransform transform = new PubsubReadSchemaTransformProvider().from(config); + PCollection reads = begin.apply(transform).get("output"); + + // This DoFn advances the clock when the first message is received. + PCollection delayedReads = + reads.apply( + "AdvanceClock", ParDo.of(new AdvanceClockFn(clock, maxReadTime, advanceSeconds))); + delayedReads.setRowSchema(reads.getSchema()); + PCollection windowed = + delayedReads.apply( + "Window", + Window.into(new GlobalWindows()) + .triggering(AfterWatermark.pastEndOfWindow()) + .withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + PCollection count = windowed.apply(org.apache.beam.sdk.transforms.Count.globally()); + // We expect to process no messages. + PAssert.that(count).containsInAnyOrder(0L); + + p.run().waitUntilFinish(); + } catch (Exception e) { + throw e; + } + } + + /** A mock clock for testing that allows for manual time advancement. */ + private static class TestClock implements Clock, Serializable { + private Instant currentTime; + + private TestClock(Instant currentTime) { + this.currentTime = currentTime; + } + + public static TestClock create(Instant time) { + return new TestClock(time); + } + + public synchronized void advance(Duration amount) { + currentTime = currentTime.plus(amount); + } + + @Override + public synchronized long currentTimeMillis() { + return currentTime.getMillis(); + } + + public synchronized Instant now() { + return currentTime; + } + } + + /** + * A {@link DoFn} that advances a {@link TestClock} and is used in tests to simulate the passage + * of time. + */ + private static class AdvanceClockFn extends DoFn { + private final TestClock clock; + private final TestClock clockStart; + private final Long advanceSeconds; + private final Long maxReadTime; + + public AdvanceClockFn(TestClock clock, Long maxReadTime, Long advanceSeconds) { + this.clock = clock; + this.clockStart = TestClock.create(clock.now()); + this.advanceSeconds = advanceSeconds; + this.maxReadTime = maxReadTime; + } + + @ProcessElement + public void processElement(ProcessContext c) { + clock.advance(Duration.standardSeconds(advanceSeconds)); + if (clock.currentTimeMillis() + <= clockStart.currentTimeMillis() + TimeUnit.SECONDS.toMillis(maxReadTime)) { + c.output(c.element()); + } + } + } + private static List beamRowToMessage() { long timestamp = CLOCK.currentTimeMillis(); return ROWS.stream() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSourceTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSourceTest.java index d3087df92386..493e46dfaae1 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSourceTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSourceTest.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.After; import org.junit.Rule; @@ -184,6 +185,84 @@ public void timeoutAckAndRereadOneMessage() throws Exception { reader.close(); } + @Test + public void maxReadTimeIsRespected() throws Exception { + // Setup + setupOneMessage(); + PubsubUnboundedSource sourceWithMaxReadTime = + new PubsubUnboundedSource( + clock, + factory, + null, + null, + StaticValueProvider.of(SUBSCRIPTION), + TIMESTAMP_ATTRIBUTE, + ID_ATTRIBUTE, + true, /* needsAttributes */ + false, /* needsMessageId */ + false, /* needsOrderingKey */ + Duration.standardSeconds(20L)); + primSource = new PubsubSource(sourceWithMaxReadTime); + PubsubReader reader = primSource.createReader(p.getOptions(), null); + PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient(); + + // Assert reader has started + assertTrue(reader.start()); + + // Assert received data matches expected + assertEquals( + DATA, + data( + reader.getCurrent(), + !(primSource.outer.getNeedsAttributes() || primSource.outer.getNeedsMessageId()))); + + // Let the maxReadTime for the above expire. + now.addAndGet(25 * 1000); // Advance by 25 seconds + pubsubClient.advance(); + + // Don't receive any new messages. + assertFalse(reader.advance()); + + // ACK the message. + PubsubCheckpoint checkpoint = reader.getCheckpointMark(); + checkpoint.finalizeCheckpoint(); + reader.close(); + } + + @Test + public void maxReadTimeCausesTimeout() throws Exception { + // No messages are waiting. + setupOneMessage(ImmutableList.of()); + + // Create a source with a very short read time. + PubsubUnboundedSource sourceWithMaxReadTime = + new PubsubUnboundedSource( + clock, + factory, + null, + null, + StaticValueProvider.of(SUBSCRIPTION), + TIMESTAMP_ATTRIBUTE, + ID_ATTRIBUTE, + true, /* needsAttributes */ + false, /* needsMessageId */ + false, /* needsOrderingKey */ + Duration.standardSeconds(1L)); // 1 second max read time + primSource = new PubsubSource(sourceWithMaxReadTime); + PubsubReader reader = primSource.createReader(p.getOptions(), null); + PubsubTestClient pubsubClient = (PubsubTestClient) reader.getPubsubClient(); + + // Advance the clock *before* trying to read, to cause an immediate timeout. + now.addAndGet(2 * 1000); // Advance by 2 seconds + pubsubClient.advance(); + + // Now, try to start the reader. It should fail because the maxReadTime has passed. + assertFalse(reader.start()); + + // No message was read, so no need to ACK. + reader.close(); + } + @Test public void extendAck() throws Exception { setupOneMessage(); diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 59eadee5538e..31bbda3f4217 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -219,7 +219,8 @@ def __init__( subscription: Optional[str] = None, id_label: Optional[str] = None, with_attributes: bool = False, - timestamp_attribute: Optional[str] = None) -> None: + timestamp_attribute: Optional[str] = None, + max_read_time_seconds: Optional[int] = None) -> None: """Initializes ``ReadFromPubSub``. Args: @@ -251,6 +252,7 @@ def __init__( ``2015-10-29T23:41:41.123Z``. The sub-second component of the timestamp is optional, and digits beyond the first three (i.e., time units smaller than milliseconds) may be ignored. + max_read_time_seconds: maximum time to read the stream. Default is forever. """ super().__init__() self.with_attributes = with_attributes @@ -259,7 +261,8 @@ def __init__( subscription=subscription, id_label=id_label, with_attributes=self.with_attributes, - timestamp_attribute=timestamp_attribute) + timestamp_attribute=timestamp_attribute, + max_read_time_seconds=max_read_time_seconds) def expand(self, pvalue): # TODO(BEAM-27443): Apply a proper transform rather than Read. @@ -507,7 +510,8 @@ def __init__( subscription: Optional[str] = None, id_label: Optional[str] = None, with_attributes: bool = False, - timestamp_attribute: Optional[str] = None): + timestamp_attribute: Optional[str] = None, + max_read_time_seconds: Optional[int] = None): self.coder = coders.BytesCoder() self.full_topic = topic self.full_subscription = subscription @@ -516,6 +520,7 @@ def __init__( self.id_label = id_label self.with_attributes = with_attributes self.timestamp_attribute = timestamp_attribute + self.max_read_time_seconds = max_read_time_seconds # Perform some validation on the topic and subscription. if not (topic or subscription): diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 3443a519e54c..82f9b93436c4 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -588,6 +588,7 @@ class _PubSubReadEvaluator(_TransformEvaluator): # TODO(https://github.com/apache/beam/issues/19751): Prevents garbage # collection of pipeline instances. _subscription_cache: Dict[AppliedPTransform, str] = {} + _start_times: Dict[AppliedPTransform, float] = {} def __init__( self, @@ -607,6 +608,9 @@ def __init__( raise NotImplementedError( 'DirectRunner: id_label is not supported for PubSub reads') + if self._applied_ptransform not in _PubSubReadEvaluator._start_times: + _PubSubReadEvaluator._start_times[self._applied_ptransform] = time.time() + sub_project = None if hasattr(self._evaluation_context, 'pipeline_options'): from apache_beam.options.pipeline_options import GoogleCloudOptions @@ -654,6 +658,10 @@ def _read_from_pubsub( self, timestamp_attribute) -> List[Tuple[Timestamp, 'PubsubMessage']]: from apache_beam.io.gcp.pubsub import PubsubMessage from google.cloud import pubsub + try: + from google.api_core import exceptions + except ImportError: + exceptions = None def _get_element(message): parsed_message = PubsubMessage._from_message(message) @@ -677,6 +685,16 @@ def _get_element(message): return timestamp, parsed_message + # Check if timeout has elasped. + timeout = 30 + max_read_time_seconds = getattr(self.source, 'max_read_time_seconds', None) + if max_read_time_seconds: + elapsed = time.time() - _PubSubReadEvaluator._start_times[ + self._applied_ptransform] + timeout = max(0, min(timeout, max_read_time_seconds - elapsed)) + if timeout <= 0: + return [] + # Because of the AutoAck, we are not able to reread messages if this # evaluator fails with an exception before emitting a bundle. However, # the DirectRunner currently doesn't retry work items anyway, so the @@ -684,17 +702,31 @@ def _get_element(message): sub_client = pubsub.SubscriberClient() try: response = sub_client.pull( - subscription=self._sub_name, max_messages=10, timeout=30) + subscription=self._sub_name, max_messages=10, timeout=timeout) results = [_get_element(rm.message) for rm in response.received_messages] ack_ids = [rm.ack_id for rm in response.received_messages] if ack_ids: sub_client.acknowledge(subscription=self._sub_name, ack_ids=ack_ids) + except Exception as e: + if exceptions and isinstance(e, exceptions.DeadlineExceeded): + results = [] + else: + raise finally: sub_client.close() return results def finish_bundle(self) -> TransformResult: + # Check if timeout has elasped. + max_read_time_seconds = getattr(self.source, 'max_read_time_seconds', None) + if max_read_time_seconds: + elapsed = time.time() - _PubSubReadEvaluator._start_times[ + self._applied_ptransform] + if elapsed > max_read_time_seconds: + return TransformResult( + self, [], [], None, {None: WatermarkManager.WATERMARK_POS_INF}) + data = self._read_from_pubsub(self.source.timestamp_attribute) if data: output_pcollection = list(self._outputs)[0] diff --git a/sdks/python/apache_beam/yaml/extended_tests/messaging/pubsub.yaml b/sdks/python/apache_beam/yaml/extended_tests/messaging/pubsub.yaml index 41f739ac77e0..194b67ed246b 100644 --- a/sdks/python/apache_beam/yaml/extended_tests/messaging/pubsub.yaml +++ b/sdks/python/apache_beam/yaml/extended_tests/messaging/pubsub.yaml @@ -21,41 +21,45 @@ fixtures: config: project_id: "apache-beam-testing" +# One pipeline is required due to the lifespan of the pubsub messages in this +# test environment. pipelines: - # Pubsub write pipeline - pipeline: - type: chain + type: composite transforms: - type: Create + name: elements config: elements: - {value: "11a"} - {value: "37a"} - {value: "389a"} - type: WriteToPubSub + input: elements config: topic: "{PS_TOPIC}" format: "RAW" - + - type: ReadFromPubSub + name: read + config: + topic: "{PS_TOPIC}" + format: "RAW" + max_read_time_seconds: 30 # Need a minimum of this for test_pubsub_emulator to work correctly + - type: MapToFields + name: convertBytesToString + input: read + config: + language: python + fields: + value: + callable: "lambda row: row.payload.decode('utf-8')" + - type: AssertEqual + input: convertBytesToString + config: + elements: + - {value: "11a"} + - {value: "37a"} + - {value: "389a"} options: streaming: true - -# TODO: Current PubSubIO doesn't have a max_read_time_seconds parameter like -# Kafka does. Without it, the ReadFromPubSub will run forever. This is not a -# trival change. For now, we will live with the mocked tests located -# [here](https://github.com/apache/beam/blob/bea04446b41c86856c24d0a9761622092ed9936f/sdks/python/apache_beam/yaml/yaml_io_test.py#L83). - - # - pipeline: - # type: chain - # transforms: - # - type: ReadFromPubSub - # config: - # topic: "{PS_TOPIC}" - # format: "RAW" - # # ... - - - # options: - # streaming: true - diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index ddf39935ebdf..b61445bbe915 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -314,7 +314,8 @@ def read_from_pubsub( attributes: Optional[Iterable[str]] = None, attributes_map: Optional[str] = None, id_attribute: Optional[str] = None, - timestamp_attribute: Optional[str] = None): + timestamp_attribute: Optional[str] = None, + max_read_time_seconds: Optional[int] = None): """Reads messages from Cloud Pub/Sub. Args: @@ -364,6 +365,7 @@ def read_from_pubsub( ``2015-10-29T23:41:41.123Z``. The sub-second component of the timestamp is optional, and digits beyond the first three (i.e., time units smaller than milliseconds) may be ignored. + max_read_time_seconds: maximum time to read the stream. Default is forever. """ if topic and subscription: raise TypeError('Only one of topic and subscription may be specified.') @@ -400,7 +402,8 @@ def mapper(msg): subscription=subscription, with_attributes=bool(attributes or attributes_map), id_label=id_attribute, - timestamp_attribute=timestamp_attribute) + timestamp_attribute=timestamp_attribute, + max_read_time_seconds=max_read_time_seconds) | 'ParseMessage' >> beam.Map(mapper)) output.element_type = schemas.named_tuple_from_schema( schema_pb2.Schema(fields=list(payload_schema.fields) + extra_fields)) diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py index a19dfd694a85..b952afa880fa 100644 --- a/sdks/python/apache_beam/yaml/yaml_io_test.py +++ b/sdks/python/apache_beam/yaml/yaml_io_test.py @@ -40,12 +40,14 @@ def __init__( messages, subscription=None, id_attribute=None, - timestamp_attribute=None): + timestamp_attribute=None, + max_read_time_seconds=None): self._topic = topic self._subscription = subscription self._messages = messages self._id_attribute = id_attribute self._timestamp_attribute = timestamp_attribute + self._max_read_time_seconds = max_read_time_seconds def __call__( self, @@ -54,11 +56,13 @@ def __call__( subscription, with_attributes, id_label, - timestamp_attribute): + timestamp_attribute, + max_read_time_seconds=None): assert topic == self._topic assert id_label == self._id_attribute assert timestamp_attribute == self._timestamp_attribute assert subscription == self._subscription + assert max_read_time_seconds == self._max_read_time_seconds if with_attributes: data = self._messages else: @@ -536,6 +540,27 @@ def test_read_proto(self): ''') assert_that(result, equal_to(data)) + def test_read_with_max_read_time(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + with mock.patch('apache_beam.io.ReadFromPubSub', + FakeReadFromPubSub( + topic='my_topic', + messages=[PubsubMessage(b'msg1', {'attr': 'value1'}), + PubsubMessage(b'msg2', {'attr': 'value2'})], + max_read_time_seconds=60)): + result = p | YamlTransform( + ''' + type: ReadFromPubSub + config: + topic: my_topic + format: RAW + max_read_time_seconds: 60 + ''') + assert_that( + result, + equal_to([beam.Row(payload=b'msg1'), beam.Row(payload=b'msg2')])) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)