diff --git a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java index 319dd31f5..5c9f9e171 100644 --- a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java +++ b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java @@ -63,6 +63,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -95,6 +96,7 @@ final class StreamingSubscriberConnection extends AbstractApiService implements private final SubscriberStub subscriberStub; private final int channelAffinity; + private final long protocolVersion; private final String subscription; private final SubscriptionName subscriptionNameObject; private final ScheduledExecutorService systemExecutor; @@ -127,6 +129,17 @@ final class StreamingSubscriberConnection extends AbstractApiService implements private OpenTelemetryPubsubTracer tracer = new OpenTelemetryPubsubTracer(null, false); private final SubscriberShutdownSettings subscriberShutdownSettings; + private final boolean enableKeepalive; + private static final long KEEP_ALIVE_SUPPORT_VERSION = 1; + private static final Duration CLIENT_PING_INTERVAL = Duration.ofSeconds(30); + private ScheduledFuture pingSchedulerHandle; + + private static final Duration SERVER_MONITOR_INTERVAL = Duration.ofSeconds(10); + private static final Duration SERVER_PING_TIMEOUT_DURATION = Duration.ofSeconds(15); + private final AtomicLong lastServerResponseTime; + private final AtomicLong lastClientPingTime; + private ScheduledFuture serverMonitorHandle; + private StreamingSubscriberConnection(Builder builder) { subscription = builder.subscription; subscriptionNameObject = SubscriptionName.parse(builder.subscription); @@ -154,6 +167,7 @@ private StreamingSubscriberConnection(Builder builder) { subscriberStub = builder.subscriberStub; channelAffinity = builder.channelAffinity; + protocolVersion = builder.protocolVersion; MessageDispatcher.Builder messageDispatcherBuilder; if (builder.receiver != null) { @@ -190,6 +204,9 @@ private StreamingSubscriberConnection(Builder builder) { flowControlSettings = builder.flowControlSettings; useLegacyFlowControl = builder.useLegacyFlowControl; + enableKeepalive = protocolVersion >= KEEP_ALIVE_SUPPORT_VERSION; + lastServerResponseTime = new AtomicLong(clock.nanoTime()); + lastClientPingTime = new AtomicLong(-1L); } public StreamingSubscriberConnection setExactlyOnceDeliveryEnabled( @@ -218,6 +235,12 @@ protected void doStop() { } finally { lock.unlock(); } + + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } + runShutdown(); notifyStopped(); } @@ -266,6 +289,10 @@ public void onStart(StreamController controller) { @Override public void onResponse(StreamingPullResponse response) { + if (enableKeepalive) { + lastServerResponseTime.set(clock.nanoTime()); + } + channelReconnectBackoffMillis.set(INITIAL_CHANNEL_RECONNECT_BACKOFF.toMillis()); boolean exactlyOnceDeliveryEnabledResponse = @@ -295,11 +322,19 @@ public void onResponse(StreamingPullResponse response) { @Override public void onError(Throwable t) { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } errorFuture.setException(t); } @Override public void onComplete() { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } logger.fine("Streaming pull terminated successfully!"); errorFuture.set(null); } @@ -336,6 +371,7 @@ private void initialize() { this.useLegacyFlowControl ? 0 : valueOrZero(flowControlSettings.getMaxOutstandingRequestBytes())) + .setProtocolVersion(protocolVersion) .build()); /** @@ -350,6 +386,13 @@ private void initialize() { lock.unlock(); } + if (enableKeepalive) { + lastServerResponseTime.set(clock.nanoTime()); + lastClientPingTime.set(-1L); + startClientPinger(); + startServerMonitor(); + } + ApiFutures.addCallback( errorFuture, new ApiFutureCallback() { @@ -366,6 +409,10 @@ public void onSuccess(@Nullable Void result) { @Override public void onFailure(Throwable cause) { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } if (!isAlive()) { // we don't care about subscription failures when we're no longer running. logger.log(Level.FINE, "pull failure after service no longer running", cause); @@ -410,6 +457,100 @@ private boolean isAlive() { return state == State.RUNNING || state == State.STARTING; } + private void startClientPinger() { + if (pingSchedulerHandle != null) { + pingSchedulerHandle.cancel(false); + } + + pingSchedulerHandle = + systemExecutor.scheduleAtFixedRate( + () -> { + try { + lock.lock(); + try { + if (clientStream != null && isAlive()) { + clientStream.send(StreamingPullRequest.newBuilder().build()); + lastClientPingTime.set(clock.nanoTime()); + logger.log(Level.FINEST, "Sent client keepalive ping"); + } + } finally { + lock.unlock(); + } + } catch (Exception e) { + logger.log(Level.FINE, "Error sending client keepalive ping", e); + } + }, + 0, + CLIENT_PING_INTERVAL.getSeconds(), + TimeUnit.SECONDS); + } + + private void stopClientPinger() { + if (pingSchedulerHandle != null) { + pingSchedulerHandle.cancel(false); + pingSchedulerHandle = null; + } + } + + private void startServerMonitor() { + if (serverMonitorHandle != null) { + serverMonitorHandle.cancel(false); + } + + serverMonitorHandle = + systemExecutor.scheduleAtFixedRate( + () -> { + try { + if (!isAlive()) { + return; + } + + long now = clock.nanoTime(); + long lastResponse = lastServerResponseTime.get(); + long lastPing = lastClientPingTime.get(); + + if (lastPing <= lastResponse) { + return; + } + + Duration elapsedSincePing = Duration.ofNanos(now - lastPing); + if (elapsedSincePing.compareTo(SERVER_PING_TIMEOUT_DURATION) < 0) { + return; + } + + logger.log( + Level.WARNING, + "No response from server for {0} seconds since last ping. Closing stream.", + elapsedSincePing.getSeconds()); + + lock.lock(); + try { + if (clientStream != null) { + clientStream.closeSendWithError( + Status.UNAVAILABLE + .withDescription("Keepalive timeout with server") + .asException()); + } + } finally { + lock.unlock(); + } + stopServerMonitor(); + } catch (Exception e) { + logger.log(Level.FINE, "Error in server keepalive monitor", e); + } + }, + SERVER_MONITOR_INTERVAL.getSeconds(), + SERVER_MONITOR_INTERVAL.getSeconds(), + TimeUnit.SECONDS); + } + + private void stopServerMonitor() { + if (serverMonitorHandle != null) { + serverMonitorHandle.cancel(false); + serverMonitorHandle = null; + } + } + public void setResponseOutstandingMessages(AckResponse ackResponse) { // We will close the futures with ackResponse - if there are multiple references to the same // future they will be handled appropriately @@ -769,6 +910,7 @@ public static final class Builder { private Distribution ackLatencyDistribution; private SubscriberStub subscriberStub; private int channelAffinity; + private long protocolVersion; private FlowController flowController; private FlowControlSettings flowControlSettings; private boolean useLegacyFlowControl; @@ -840,6 +982,11 @@ public Builder setChannelAffinity(int channelAffinity) { return this; } + public Builder setProtocolVersion(long protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + public Builder setFlowController(FlowController flowController) { this.flowController = flowController; return this; diff --git a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/Subscriber.java b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/Subscriber.java index c0779ff29..ce9bc6f15 100644 --- a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/Subscriber.java +++ b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/Subscriber.java @@ -144,6 +144,7 @@ public class Subscriber extends AbstractApiService implements SubscriberInterfac private final boolean maxDurationPerAckExtensionDefaultUsed; private final java.time.Duration minDurationPerAckExtension; private final boolean minDurationPerAckExtensionDefaultUsed; + private final long protocolVersion; // The ExecutorProvider used to generate executors for processing messages. private final ExecutorProvider executorProvider; @@ -182,6 +183,7 @@ private Subscriber(Builder builder) { maxDurationPerAckExtensionDefaultUsed = builder.maxDurationPerAckExtensionDefaultUsed; minDurationPerAckExtension = builder.minDurationPerAckExtension; minDurationPerAckExtensionDefaultUsed = builder.minDurationPerAckExtensionDefaultUsed; + protocolVersion = builder.protocolVersion; clock = builder.clock.isPresent() ? builder.clock.get() : CurrentMillisClock.getDefaultClock(); @@ -428,6 +430,7 @@ private void startStreamingConnections() { .setEnableOpenTelemetryTracing(enableOpenTelemetryTracing) .setTracer(tracer) .setSubscriberShutdownSettings(subscriberShutdownSettings) + .setProtocolVersion(protocolVersion) .build(); streamingSubscriberConnections.add(streamingSubscriberConnection); @@ -548,6 +551,8 @@ public static final class Builder { private boolean enableOpenTelemetryTracing = false; private OpenTelemetry openTelemetry = null; + private long protocolVersion = 0L; + private SubscriberShutdownSettings subscriberShutdownSettings = SubscriberShutdownSettings.newBuilder().build(); @@ -771,6 +776,12 @@ Builder setClock(ApiClock clock) { return this; } + /** Gives the ability to override the protocol version */ + public Builder setProtocolVersion(long protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + /** * OpenTelemetry will be enabled if setEnableOpenTelemetry is true and and instance of * OpenTelemetry has been provied. Warning: traces are subject to change. The name and diff --git a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/FakeScheduledExecutorService.java b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/FakeScheduledExecutorService.java index 65e199e92..b17eaddb0 100644 --- a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/FakeScheduledExecutorService.java +++ b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/FakeScheduledExecutorService.java @@ -55,14 +55,14 @@ public FakeClock getClock() { public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { return schedulePendingCallable( new PendingCallable<>( - Duration.ofMillis(unit.toMillis(delay)), command, PendingCallableType.NORMAL)); + Duration.ofMillis(unit.toMillis(delay)), command, null, PendingCallableType.NORMAL)); } @Override public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { return schedulePendingCallable( new PendingCallable<>( - Duration.ofMillis(unit.toMillis(delay)), callable, PendingCallableType.NORMAL)); + Duration.ofMillis(unit.toMillis(delay)), callable, null, PendingCallableType.NORMAL)); } @Override @@ -72,6 +72,7 @@ public ScheduledFuture scheduleAtFixedRate( new PendingCallable<>( Duration.ofMillis(unit.toMillis(initialDelay)), command, + Duration.ofMillis(unit.toMillis(period)), PendingCallableType.FIXED_RATE)); } @@ -82,6 +83,7 @@ public ScheduledFuture scheduleWithFixedDelay( new PendingCallable<>( Duration.ofMillis(unit.toMillis(initialDelay)), command, + Duration.ofMillis(unit.toMillis(delay)), PendingCallableType.FIXED_DELAY)); } @@ -212,13 +214,15 @@ enum PendingCallableType { class PendingCallable implements Comparable> { Instant creationTime = Instant.ofEpochMilli(clock.millisTime()); Duration delay; + Duration period; Callable pendingCallable; SettableFuture future = SettableFuture.create(); AtomicBoolean cancelled = new AtomicBoolean(false); AtomicBoolean done = new AtomicBoolean(false); PendingCallableType type; - PendingCallable(Duration delay, final Runnable runnable, PendingCallableType type) { + PendingCallable( + Duration delay, final Runnable runnable, Duration period, PendingCallableType type) { pendingCallable = new Callable() { @Override @@ -229,12 +233,15 @@ public T call() { }; this.type = type; this.delay = delay; + this.period = period; } - PendingCallable(Duration delay, Callable callable, PendingCallableType type) { + PendingCallable( + Duration delay, Callable callable, Duration period, PendingCallableType type) { pendingCallable = callable; this.type = type; this.delay = delay; + this.period = period; } private Instant getScheduledTime() { @@ -305,10 +312,12 @@ T call() { break; case FIXED_DELAY: this.creationTime = Instant.ofEpochMilli(clock.millisTime()); + this.delay = period; schedulePendingCallable(this); break; case FIXED_RATE: this.creationTime = this.creationTime.plus(delay); + this.delay = period; schedulePendingCallable(this); break; default: diff --git a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java index f79825d85..6979963a5 100644 --- a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java +++ b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java @@ -28,12 +28,18 @@ import com.google.api.gax.core.Distribution; import com.google.api.gax.grpc.GrpcStatusCode; import com.google.api.gax.rpc.ApiException; +import com.google.api.gax.rpc.BidiStreamingCallable; +import com.google.api.gax.rpc.ClientStream; +import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.StatusCode; +import com.google.api.gax.rpc.StreamController; import com.google.cloud.pubsub.v1.stub.SubscriberStub; import com.google.common.collect.Lists; import com.google.protobuf.Any; import com.google.pubsub.v1.AcknowledgeRequest; import com.google.pubsub.v1.ModifyAckDeadlineRequest; +import com.google.pubsub.v1.StreamingPullRequest; +import com.google.pubsub.v1.StreamingPullResponse; import com.google.rpc.ErrorInfo; import com.google.rpc.Status; import io.grpc.Status.Code; @@ -44,11 +50,13 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; +import org.mockito.ArgumentCaptor; /** Tests for {@link StreamingSubscriberConnection}. */ public class StreamingSubscriberConnectionTest { @@ -86,6 +94,10 @@ public class StreamingSubscriberConnectionTest { private static Duration ACK_EXPIRATION_PADDING_DEFAULT_DURATION = Duration.ofSeconds(10); private static int MAX_DURATION_PER_ACK_EXTENSION_DEFAULT_SECONDS = 10; + private static final long KEEP_ALIVE_SUPPORT_VERSION = 1; + private static final Duration CLIENT_PING_INTERVAL = Duration.ofSeconds(30); + private static final Duration MAX_ACK_EXTENSION_PERIOD = Duration.ofMinutes(60); + @Before public void setUp() { systemExecutor = new FakeScheduledExecutorService(); @@ -670,6 +682,155 @@ public void testMaxPerRequestChanges() { } } + @Test + public void testClientPinger_pingSent() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(any(ResponseObserver.class), any())) + .thenReturn(mockClientStream); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(StreamingPullRequest.class); + // 1 initial request + 3 pings + verify(mockClientStream, times(4)).send(requestCaptor.capture()); + List requests = requestCaptor.getAllValues(); + + StreamingPullRequest initialRequest = requests.get(0); + assertEquals(MOCK_SUBSCRIPTION_NAME, initialRequest.getSubscription()); + assertEquals(KEEP_ALIVE_SUPPORT_VERSION, initialRequest.getProtocolVersion()); + assertEquals(0, initialRequest.getMaxOutstandingMessages()); + + StreamingPullRequest firstPing = requests.get(1); + assertEquals(StreamingPullRequest.getDefaultInstance(), firstPing); + + StreamingPullRequest secondPing = requests.get(2); + assertEquals(StreamingPullRequest.getDefaultInstance(), secondPing); + + streamingSubscriberConnection.stopAsync(); + streamingSubscriberConnection.awaitTerminated(); + + // No more pings + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + verify(mockClientStream, times(4)).send(any(StreamingPullRequest.class)); + } + + @Test + public void testClientPinger_pingsNotSentWhenDisabled() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(any(ResponseObserver.class), any())) + .thenReturn(mockClientStream); + + StreamingSubscriberConnection streamingSubscriberConnection = + getStreamingSubscriberConnection(false); // keepalive disabled + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + // Initial request. + verify(mockClientStream, times(1)).send(any(StreamingPullRequest.class)); + + // No pings + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + + verify(mockClientStream, times(1)).send(any(StreamingPullRequest.class)); + } + + @Test + public void testServerMonitor_timesOut() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + ArgumentCaptor> observerCaptor = + ArgumentCaptor.forClass(ResponseObserver.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(observerCaptor.capture(), any())) + .thenReturn(mockClientStream); + + // fail pings after the first one to ensure timeout occurs + AtomicInteger pingCount = new AtomicInteger(0); + doAnswer( + (invocation) -> { + StreamingPullRequest req = invocation.getArgument(0); + // Pings are empty requests + if (req.getSubscription().isEmpty()) { + if (pingCount.incrementAndGet() > 2) { // allow first 2 pings + throw new RuntimeException("ping failed"); + } + } + return null; + }) + .when(mockClientStream) + .send(any(StreamingPullRequest.class)); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + ResponseObserver observer = observerCaptor.getValue(); + StreamController mockController = mock(StreamController.class); + observer.onStart(mockController); + + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + verify(mockClientStream, never()).closeSendWithError(any(Exception.class)); + + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(mockClientStream, times(1)).closeSendWithError(exceptionCaptor.capture()); + StatusException exception = (StatusException) exceptionCaptor.getValue(); + assertEquals(Code.UNAVAILABLE, exception.getStatus().getCode()); + assertEquals("Keepalive timeout with server", exception.getStatus().getDescription()); + } + + @Test + public void testServerMonitor_doesNotTimeOutIfResponseReceived() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + ArgumentCaptor> observerCaptor = + ArgumentCaptor.forClass(ResponseObserver.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(observerCaptor.capture(), any())) + .thenReturn(mockClientStream); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + ResponseObserver observer = observerCaptor.getValue(); + StreamController mockController = mock(StreamController.class); + observer.onStart(mockController); + + // t=30s: ping sent. + // t=40s: response received. + // t=45s: monitor check. lastPing=30, lastResponse=40. lastPing>lastResponse is false -> no + // timeout. + systemExecutor.advanceTime(Duration.ofSeconds(40)); + observer.onResponse(StreamingPullResponse.getDefaultInstance()); + systemExecutor.advanceTime(Duration.ofSeconds(20)); // to t=60s + observer.onResponse(StreamingPullResponse.getDefaultInstance()); + + verify(mockClientStream, never()).closeSendWithError(any(Exception.class)); + } + private StreamingSubscriberConnection getStreamingSubscriberConnection( boolean exactlyOnceDeliveryEnabled) { StreamingSubscriberConnection streamingSubscriberConnection = @@ -682,11 +843,21 @@ private StreamingSubscriberConnection getStreamingSubscriberConnection( return streamingSubscriberConnection; } + private StreamingSubscriberConnection getKeepaliveStreamingSubscriberConnection() { + StreamingSubscriberConnection streamingSubscriberConnection = + getStreamingSubscriberConnectionFromBuilder( + StreamingSubscriberConnection.newBuilder(mock(MessageReceiverWithAckResponse.class)) + .setProtocolVersion(KEEP_ALIVE_SUPPORT_VERSION)); + + return streamingSubscriberConnection; + } + private StreamingSubscriberConnection getStreamingSubscriberConnectionFromBuilder( StreamingSubscriberConnection.Builder builder) { return builder .setSubscription(MOCK_SUBSCRIPTION_NAME) .setAckExpirationPadding(ACK_EXPIRATION_PADDING_DEFAULT_DURATION) + .setMaxAckExtensionPeriod(MAX_ACK_EXTENSION_PERIOD) .setAckLatencyDistribution(mock(Distribution.class)) .setSubscriberStub(mockSubscriberStub) .setChannelAffinity(0)