diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index cf46e1f984dc..ed99ae1bbd6f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,13 +17,22 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.PrintWriter; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -37,9 +46,7 @@ import org.apache.beam.vendor.grpc.v1p69p0.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import 
org.checkerframework.checker.nullness.qual.NonNull; -import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -75,11 +82,10 @@ public abstract class AbstractWindmillStream implements Win // shutdown. private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; - private static final String NOT_SHUTDOWN = "not shutdown"; protected final Sleeper sleeper; private final Logger logger; - private final ExecutorService executor; + private final ScheduledExecutorService executor; private final BackOff backoff; private final CountDownLatch finishLatch; private final Set> streamRegistry; @@ -89,6 +95,7 @@ public abstract class AbstractWindmillStream implements Win private final Function, TerminatingStreamObserver> physicalStreamFactory; protected final long physicalStreamDeadlineSeconds; + private final Duration halfClosePhysicalStreamAfter; private final ResettableThrowingStreamObserver requestObserver; private final StreamDebugMetrics debugMetrics; @@ -106,6 +113,17 @@ public abstract class AbstractWindmillStream implements Win @GuardedBy("this") protected @Nullable PhysicalStreamHandler currentPhysicalStream; + @GuardedBy("this") + @Nullable + Future halfCloseFuture = null; + + // Physical streams that have been half-closed and are waiting for responses or stream failure. + @GuardedBy("this") + protected final Set closingPhysicalStreams; + + private final Set closingPhysicalStreamsForDebug = + Collections.newSetFromMap(new ConcurrentHashMap()); + // Generally the same as currentPhysicalStream, set under synchronization of this but can be read // without. 
private final AtomicReference currentPhysicalStreamForDebug = @@ -114,25 +132,33 @@ public abstract class AbstractWindmillStream implements Win @GuardedBy("this") private boolean started; + // If halfClosePhysicalStream is non-zero, substreams created for the logical + // AbstractWindmillStream + // will be half-closed and a new physical stream will be created after this duraction. protected AbstractWindmillStream( Logger logger, - String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, - String backendWorkerToken) { + String backendWorkerToken, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { + checkArgument(!halfClosePhysicalStreamAfter.isNegative()); this.backendWorkerToken = backendWorkerToken; this.physicalStreamFactory = (StreamObserver observer) -> streamObserverFactory.from(clientFactory, observer); this.physicalStreamDeadlineSeconds = streamObserverFactory.getDeadlineSeconds(); - this.executor = - Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat(createThreadName(debugStreamType, backendWorkerToken)) - .build()); + if (!halfClosePhysicalStreamAfter.isZero() + && halfClosePhysicalStreamAfter.compareTo(Duration.ofSeconds(physicalStreamDeadlineSeconds)) + >= 0) { + logger.debug("Not attempting to half-close cleanly as stream deadline is shorter."); + halfClosePhysicalStreamAfter = Duration.ZERO; + } + this.halfClosePhysicalStreamAfter = halfClosePhysicalStreamAfter; + this.closingPhysicalStreams = Collections.newSetFromMap(new IdentityHashMap<>()); + this.executor = executor; this.backoff = backoff; this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; @@ -147,12 +173,6 @@ protected AbstractWindmillStream( this.debugMetrics = StreamDebugMetrics.create(); } - private static String createThreadName(String streamType, String 
backendWorkerToken) { - return !backendWorkerToken.isEmpty() - ? String.format("%s-%s-WindmillStream-thread", streamType, backendWorkerToken) - : String.format("%s-WindmillStream-thread", streamType); - } - /** Represents a physical grpc stream that is part of the logical windmill stream. */ protected abstract class PhysicalStreamHandler { @@ -178,11 +198,23 @@ protected abstract class PhysicalStreamHandler { public abstract void appendHtml(PrintWriter writer); private final StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.create(); + + @Override + public final boolean equals(@Nullable Object obj) { + return this == obj; + } + + @Override + public final int hashCode() { + return System.identityHashCode(this); + } } + /* Constructs and returns a new handler to be associated with a physical stream. */ protected abstract PhysicalStreamHandler newResponseHandler(); - protected abstract void onNewStream() throws WindmillStreamShutdownException; + protected abstract void onFlushPending(boolean isNewStream) + throws WindmillStreamShutdownException; /** Try to send a request to the server. Returns true if the request was successfully sent. */ @CanIgnoreReturnValue @@ -214,54 +246,68 @@ public final void start() { } if (shouldStartStream) { + // Add the stream to the registry after it has been fully constructed. + streamRegistry.add(this); startStream(); } } /** Starts the underlying stream. */ private void startStream() { - // Add the stream to the registry after it has been fully constructed. - streamRegistry.add(this); while (true) { @NonNull PhysicalStreamHandler streamHandler = newResponseHandler(); - try { - synchronized (this) { + synchronized (this) { + try { + checkState(currentPhysicalStream == null, "Overwriting existing physical stream"); + checkState(halfCloseFuture == null, "Unexpected half-close future"); + if (isShutdown) { + // No need to start the stream. shutdown() or onPhysicalStreamCompletion will be + // responsible for completing shutdown. 
+ return; + } debugMetrics.recordStart(); streamHandler.streamDebugMetrics.recordStart(); currentPhysicalStream = streamHandler; currentPhysicalStreamForDebug.set(currentPhysicalStream); requestObserver.reset(physicalStreamFactory.apply(new ResponseObserver(streamHandler))); - onNewStream(); + onFlushPending(true); if (clientClosed) { - halfClose(); + // The logical stream is half-closed so after flushing the remaining requests close the + // physical stream. + streamHandler.streamDebugMetrics.recordHalfClose(); + requestObserver.onCompleted(); + } else if (!halfClosePhysicalStreamAfter.isZero()) { + halfCloseFuture = + executor.schedule( + () -> onHalfClosePhysicalStreamTimeout(streamHandler), + halfClosePhysicalStreamAfter.getSeconds(), + TimeUnit.SECONDS); } return; - } - } catch (WindmillStreamShutdownException e) { - // shutdown() is responsible for cleaning up pending requests. - logger.debug("Stream was shutdown while creating new stream.", e); - break; - } catch (Exception e) { - logger.error("Failed to create new stream, retrying: ", e); - try { - long sleep = backoff.nextBackOffMillis(); - debugMetrics.recordSleep(sleep); - sleeper.sleep(sleep); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - logger.info( - "Interrupted during {} creation backoff. The stream will not be created.", - getClass()); - // Shutdown the stream to clean up any dangling resources and pending requests. - shutdown(); + } catch (WindmillStreamShutdownException e) { + logger.debug("Stream was shutdown while creating new stream.", e); + clearCurrentPhysicalStream(true); break; + } catch (Exception e) { + logger.error("Failed to create new stream, retrying: ", e); + clearCurrentPhysicalStream(true); + debugMetrics.recordRestartReason("Failed to create new stream, retrying: " + e); } } + // Backoff outside the synchronized block. 
+ try { + long sleep = backoff.nextBackOffMillis(); + debugMetrics.recordSleep(sleep); + sleeper.sleep(sleep); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + logger.info( + "Interrupted during {} creation backoff. The stream will not be created.", getClass()); + // Shutdown the stream to clean up any dangling resources and pending requests. + shutdown(); + break; + } } - - // We were never able to start the stream, remove it from the stream registry. Otherwise, it is - // removed when closed. - streamRegistry.remove(this); } /** @@ -317,23 +363,6 @@ public final void maybeScheduleHealthCheck(Instant lastSendThreshold) { */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - - @Nullable PhysicalStreamHandler currentHandler = currentPhysicalStreamForDebug.get(); - if (currentHandler != null) { - writer.format("Physical stream: "); - currentHandler.appendHtml(writer); - StreamDebugMetrics.Snapshot summaryMetrics = - currentHandler.streamDebugMetrics.getSummaryMetrics(); - if (summaryMetrics.isClientClosed()) { - writer.write(" client closed"); - } - writer.format( - " current stream is %dms old, last send %dms, last response %dms\n", - summaryMetrics.streamAge(), - summaryMetrics.timeSinceLastSend(), - summaryMetrics.timeSinceLastResponse()); - } - StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); summaryMetrics .restartMetrics() @@ -356,13 +385,44 @@ public final void appendSummaryHtml(PrintWriter writer) { } writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " - + "shutdown time: %s", + ", stream is %dms old, last send %dms, last response %dms", summaryMetrics.streamAge(), summaryMetrics.timeSinceLastSend(), - summaryMetrics.timeSinceLastResponse(), - requestObserver.isClosed(), - summaryMetrics.shutdownTime().map(DateTime::toString).orElse(NOT_SHUTDOWN)); + summaryMetrics.timeSinceLastResponse()); + if 
(requestObserver.isClosed()) { + writer.append(", observer closed"); + } + summaryMetrics + .shutdownTime() + .ifPresent(dateTime -> writer.format(", shutdown at %s", dateTime)); + + @Nullable PhysicalStreamHandler currentHandler = currentPhysicalStreamForDebug.get(); + if (currentHandler != null) { + writer.format("
current physical stream: "); + appendPhysicalStream(writer, currentHandler); + } + + List closingStreamsSnapshot = + new ArrayList<>(closingPhysicalStreamsForDebug); + for (int i = 0; i < closingStreamsSnapshot.size(); ++i) { + writer.format("
closing physical stream #%d: ", i); + appendPhysicalStream(writer, closingStreamsSnapshot.get(i)); + } + } + + private void appendPhysicalStream( + PrintWriter writer, PhysicalStreamHandler physicalStreamHandler) { + physicalStreamHandler.appendHtml(writer); + StreamDebugMetrics.Snapshot summaryMetrics = + physicalStreamHandler.streamDebugMetrics.getSummaryMetrics(); + if (summaryMetrics.isClientClosed()) { + writer.write(" client closed"); + } + writer.format( + " started %dms ago, last send %dms, last response %dms\n", + summaryMetrics.streamAge(), + summaryMetrics.timeSinceLastSend(), + summaryMetrics.timeSinceLastResponse()); } /** @@ -375,7 +435,12 @@ public final void appendSummaryHtml(PrintWriter writer) { @Override public final synchronized void halfClose() { - // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. + if (clientClosed) { + logger.warn("Stream was previously closed."); + return; + } + // Synchronization of close and onCompleted necessary for correct retry logic in + // onPhysicalStreamCompleted. debugMetrics.recordHalfClose(); clientClosed = true; try { @@ -399,7 +464,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(debugMetrics.getStartTimeMs()); + return Instant.ofEpochMilli(debugMetrics.getStartTimeMs()); } @Override @@ -417,30 +482,22 @@ public final void shutdown() { isShutdown = true; debugMetrics.recordShutdown(); shutdownInternal(); + if (currentPhysicalStream == null && closingPhysicalStreams.isEmpty()) { + completeShutdown(); + } } } } - protected synchronized void shutdownInternal() {} - - /** Returns true if the stream was torn down and should not be restarted internally. 
*/ - private synchronized boolean maybeTearDownStream(PhysicalStreamHandler doneStream) { - if (clientClosed && !doneStream.hasPendingRequests()) { - shutdown(); - } - - if (isShutdown) { - // Once we have background closing physicalStreams we will need to improve this to wait for - // all of the work of the logical stream to be complete. - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - executor.shutdownNow(); - return true; - } - - return false; + private void completeShutdown() { + logger.debug("Completing shutdown of stream after shutdown and all streams terminated."); + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); } + protected synchronized void shutdownInternal() {} + private class ResponseObserver implements StreamObserver { private final PhysicalStreamHandler handler; @@ -467,22 +524,67 @@ public void onCompleted() { } } - @SuppressWarnings("nullness") - private void clearPhysicalStreamForDebug() { - currentPhysicalStreamForDebug.set(null); + @SuppressWarnings("ReferenceEquality") + private void onHalfClosePhysicalStreamTimeout(PhysicalStreamHandler handler) { + synchronized (this) { + if (currentPhysicalStream != handler || clientClosed || isShutdown) { + return; + } + handler.streamDebugMetrics.recordHalfClose(); + closingPhysicalStreams.add(handler); + closingPhysicalStreamsForDebug.add(handler); + clearCurrentPhysicalStream(false); + try { + requestObserver.onCompleted(); + } catch (Exception e) { + logger.debug( + "Exception while half-closing handler, onPhysicalStreamCompletion will be called for the stream", + e); + } + } + startStream(); } + @SuppressWarnings("ReferenceEquality") private void onPhysicalStreamCompletion(Status status, PhysicalStreamHandler handler) { synchronized (this) { - if (currentPhysicalStream == handler) { - clearPhysicalStreamForDebug(); - currentPhysicalStream = null; + final boolean wasActiveStream = currentPhysicalStream == 
handler; + if (wasActiveStream) { + clearCurrentPhysicalStream(true); + } else { + checkState(closingPhysicalStreams.remove(handler)); + closingPhysicalStreamsForDebug.remove(handler); } + boolean doneHandlerHadRequests = handler.hasPendingRequests(); + handler.onDone(status); + if (currentPhysicalStream == null && closingPhysicalStreams.isEmpty()) { + if (clientClosed && !doneHandlerHadRequests && !isShutdown) { + shutdown(); + } + if (isShutdown) { + completeShutdown(); + return; + } + } + if (currentPhysicalStream != null) { + if (!clientClosed) { + // Don't bother attempting to flush the requests if the active stream is closed. + try { + onFlushPending(false); + } catch (WindmillStreamShutdownException e) { + logger.debug( + "Requests will be flushed by onPhysicalStreamCompletion of the current stream.", e); + } + } + return; + } + if (clientClosed && !doneHandlerHadRequests) { + // We didn't have any leftover requests and are closing so we skip restarting a stream. + return; + } + // We're not shutting down and we don't have an active stream, create one. 
} - handler.onDone(status); - if (maybeTearDownStream(handler)) { - return; - } + // Backoff on errors.; if (!status.isOk()) { try { @@ -498,6 +600,16 @@ private void onPhysicalStreamCompletion(Status status, PhysicalStreamHandler han startStream(); } + @SuppressWarnings("nullness") + private synchronized void clearCurrentPhysicalStream(boolean cancelHalfCloseFuture) { + currentPhysicalStream = null; + if (halfCloseFuture != null && cancelHalfCloseFuture) { + halfCloseFuture.cancel(false); + } + halfCloseFuture = null; + currentPhysicalStreamForDebug.set(null); + } + private void recordStreamRestart(Status status) { int currentRestartCount = debugMetrics.incrementAndGetRestarts(); if (status.isOk()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 51bc03e8e0e7..526b67890783 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -64,10 +64,6 @@ public interface WindmillStream { interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ void setBudget(GetWorkBudget newBudget); - - default void setBudget(long newItems, long newBytes) { - setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); - } } /** Interface for streaming GetDataRequests to Windmill. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 7a7b1a5cd27e..531d624a2e99 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -22,12 +22,14 @@ import com.google.auto.value.AutoValue; import java.io.PrintWriter; +import java.time.Duration; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; @@ -85,16 +87,19 @@ private GrpcCommitWorkStream( int logEveryNStreamFailures, JobHeader jobHeader, AtomicLong idGenerator, - int streamingRpcBatchLimit) { + int streamingRpcBatchLimit, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { super( LOG, - "CommitWorkStream", startCommitWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures, - backendWorkerToken); + backendWorkerToken, + halfClosePhysicalStreamAfter, + executor); this.idGenerator = idGenerator; this.jobHeader = jobHeader; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -110,7 +115,9 @@ static GrpcCommitWorkStream create( int logEveryNStreamFailures, JobHeader jobHeader, AtomicLong idGenerator, - int streamingRpcBatchLimit) { + int streamingRpcBatchLimit, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { return 
new GrpcCommitWorkStream( backendWorkerToken, startCommitWorkRpcFn, @@ -120,25 +127,33 @@ static GrpcCommitWorkStream create( logEveryNStreamFailures, jobHeader, idGenerator, - streamingRpcBatchLimit); + streamingRpcBatchLimit, + halfClosePhysicalStreamAfter, + executor); } @Override public void appendSpecificHtml(PrintWriter writer) { - writer.format("CommitWorkStream: %d pending", pending.size()); + writer.format("CommitWorkStream: %d pending ", pending.size()); } @Override - protected synchronized void onNewStream() throws WindmillStreamShutdownException { - trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + @SuppressWarnings("ReferenceEquality") + protected synchronized void onFlushPending(boolean isNewStream) + throws WindmillStreamShutdownException { + if (isNewStream) { + trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + } // Flush all pending requests that are no longer on active streams. try (Batcher resendBatcher = new Batcher()) { for (Map.Entry entry : pending.entrySet()) { CommitWorkPhysicalStreamHandler requestHandler = entry.getValue().handler; checkState(requestHandler != currentPhysicalStream); - // When we have streams closing in the background we should avoid retrying the requests - // active on those streams. 
- + if (requestHandler != null && closingPhysicalStreams.contains(requestHandler)) { + LOG.debug( + "Not resending request that is active on background half-closing physical stream."); + continue; + } long id = entry.getKey(); PendingRequest request = entry.getValue().request; if (!resendBatcher.canAccept(request.getBytes())) { @@ -169,6 +184,7 @@ protected synchronized void sendHealthCheck() throws WindmillStreamShutdownExcep private class CommitWorkPhysicalStreamHandler extends PhysicalStreamHandler { @Override + @SuppressWarnings("ReferenceEquality") public void onResponse(StreamingCommitResponse response) { CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler(); for (int i = 0; i < response.getRequestIdCount(); ++i) { @@ -206,6 +222,7 @@ public void onResponse(StreamingCommitResponse response) { } @Override + @SuppressWarnings("ReferenceEquality") public boolean hasPendingRequests() { return pending.entrySet().stream().anyMatch(e -> e.getValue().handler == this); } @@ -218,6 +235,7 @@ public void onDone(Status status) { } @Override + @SuppressWarnings("ReferenceEquality") public void appendHtml(PrintWriter writer) { writer.format( "CommitWorkStream: %d pending", diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 938ec1c693c7..2712bf1bd33d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -20,9 +20,11 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import 
java.io.PrintWriter; +import java.time.Duration; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import javax.annotation.concurrent.GuardedBy; @@ -98,16 +100,19 @@ private GrpcDirectGetWorkStream( HeartbeatSender heartbeatSender, GetDataClient getDataClient, WorkCommitter workCommitter, - WorkItemScheduler workItemScheduler) { + WorkItemScheduler workItemScheduler, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executorService) { super( LOG, - "GetWorkStream", startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures, - backendWorkerToken); + backendWorkerToken, + halfClosePhysicalStreamAfter, + executorService); this.requestHeader = requestHeader; this.workItemScheduler = workItemScheduler; this.heartbeatSender = heartbeatSender; @@ -138,7 +143,9 @@ static GrpcDirectGetWorkStream create( HeartbeatSender heartbeatSender, GetDataClient getDataClient, WorkCommitter workCommitter, - WorkItemScheduler workItemScheduler) { + WorkItemScheduler workItemScheduler, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { return new GrpcDirectGetWorkStream( backendWorkerToken, startGetWorkRpcFn, @@ -151,7 +158,9 @@ static GrpcDirectGetWorkStream create( heartbeatSender, getDataClient, workCommitter, - workItemScheduler); + workItemScheduler, + halfClosePhysicalStreamAfter, + executor); } private static Watermarks createWatermarks( @@ -230,7 +239,11 @@ protected PhysicalStreamHandler newResponseHandler() { } @Override - protected synchronized void onNewStream() throws WindmillStreamShutdownException { + protected synchronized void onFlushPending(boolean isNewStream) + throws WindmillStreamShutdownException { + if (!isNewStream) { + return; + } budgetTracker.reset(); GetWorkBudget 
initialGetWorkBudget = budgetTracker.computeBudgetExtension(); StreamingGetWorkRequest request = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 7de074122a3c..6d6dcd569e85 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -33,6 +33,7 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; @@ -112,16 +113,19 @@ private GrpcGetDataStream( AtomicLong idGenerator, int streamingRpcBatchLimit, boolean sendKeyedGetDataRequests, - Consumer> processHeartbeatResponses) { + Consumer> processHeartbeatResponses, + java.time.Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executorService) { super( LOG, - "GetDataStream", startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures, - backendWorkerToken); + backendWorkerToken, + halfClosePhysicalStreamAfter, + executorService); this.idGenerator = idGenerator; this.jobHeader = jobHeader; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -146,7 +150,9 @@ static GrpcGetDataStream create( AtomicLong idGenerator, int streamingRpcBatchLimit, boolean sendKeyedGetDataRequests, - Consumer> processHeartbeatResponses) { + Consumer> processHeartbeatResponses, + java.time.Duration halfClosePhysicalStreamAfter, + 
ScheduledExecutorService executor) { return new GrpcGetDataStream( backendWorkerToken, startGetDataRpcFn, @@ -158,7 +164,9 @@ static GrpcGetDataStream create( idGenerator, streamingRpcBatchLimit, sendKeyedGetDataRequests, - processHeartbeatResponses); + processHeartbeatResponses, + halfClosePhysicalStreamAfter, + executor); } private static WindmillStreamShutdownException shutdownExceptionFor(QueuedBatch batch) { @@ -189,7 +197,7 @@ public void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException } if (!trySend(batch.asGetDataRequest())) { - // The stream broke before this call went through; onNewStream will retry the fetch. + // The stream broke before this call went through; onFlushPending will retry the fetch. LOG.debug("GetData stream broke before call started."); } } @@ -260,8 +268,11 @@ protected PhysicalStreamHandler newResponseHandler() { } @Override - protected synchronized void onNewStream() throws WindmillStreamShutdownException { - trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + protected synchronized void onFlushPending(boolean isNewStream) + throws WindmillStreamShutdownException { + if (isNewStream) { + trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + } while (!batches.isEmpty()) { QueuedBatch batch = checkNotNull(batches.peekFirst()); verify(!batch.isEmpty()); @@ -392,6 +403,12 @@ protected synchronized void shutdownInternal() { } currentGetDataStream.pending.clear(); } + for (PhysicalStreamHandler handler : closingPhysicalStreams) { + for (AppendableInputStream ais : ((GetDataPhysicalStreamHandler) handler).pending.values()) { + ais.cancel(); + } + ((GetDataPhysicalStreamHandler) handler).pending.clear(); + } batches.forEach( batch -> { batch.markFinalized(); @@ -402,7 +419,12 @@ protected synchronized void shutdownInternal() { @Override public void appendSpecificHtml(PrintWriter writer) { - writer.format("GetDataStream: %d queued batches", batchesDebugSizeSupplier.get()); + 
int batches = batchesDebugSizeSupplier.get(); + if (batches > 0) { + writer.format("GetDataStream: %d queued batches ", batches); + } else { + writer.append("GetDataStream: no queued batches "); + } } private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) @@ -476,10 +498,11 @@ private void queueRequestAndWait(QueuedRequest request) prevBatch.waitForSendOrFailNotification(); } trySendBatch(batch); - } else { - // Wait for this batch to be sent before parsing the response. - batch.waitForSendOrFailNotification(); + // Since the above send may not succeed, we fall through to block on sending or failure. } + + // Wait for this batch to be sent before parsing the response. + batch.waitForSendOrFailNotification(); } private synchronized void trySendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { @@ -494,8 +517,8 @@ private synchronized void trySendBatch(QueuedBatch batch) throws WindmillStreamS final @Nullable GetDataPhysicalStreamHandler currentGetDataPhysicalStream = (GetDataPhysicalStreamHandler) currentPhysicalStream; if (currentGetDataPhysicalStream == null) { - // Leave the batch finalized but in the batches queue. Finalized batches will be sent on the - // new stream in onNewStream. + // Leave the batch finalized but in the batches queue. Finalized batches will be sent on a + // new stream in onFlushPending. 
return; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index a1c758eac446..ae7ce85e13a8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -18,8 +18,10 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import java.io.PrintWriter; +import java.time.Duration; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; @@ -68,16 +70,19 @@ private GrpcGetWorkStream( Set> streamRegistry, int logEveryNStreamFailures, boolean requestBatchedGetWorkResponse, - WorkItemReceiver receiver) { + WorkItemReceiver receiver, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { super( LOG, - "GetWorkStream", startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures, - backendWorkerToken); + backendWorkerToken, + halfClosePhysicalStreamAfter, + executor); this.request = request; this.receiver = receiver; this.inflightMessages = new AtomicLong(); @@ -97,7 +102,9 @@ public static GrpcGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, boolean requestBatchedGetWorkResponse, - WorkItemReceiver receiver) { + WorkItemReceiver receiver, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executor) { return new GrpcGetWorkStream( backendWorkerToken, 
startGetWorkRpcFn, @@ -107,7 +114,9 @@ public static GrpcGetWorkStream create( streamRegistry, logEveryNStreamFailures, requestBatchedGetWorkResponse, - receiver); + receiver, + halfClosePhysicalStreamAfter, + executor); } private void sendRequestExtension(long moreItems, long moreBytes) { @@ -163,7 +172,11 @@ protected PhysicalStreamHandler newResponseHandler() { } @Override - protected synchronized void onNewStream() throws WindmillStreamShutdownException { + protected synchronized void onFlushPending(boolean isNewStream) + throws WindmillStreamShutdownException { + if (!isNewStream) { + return; + } inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); trySend( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 9b99b3bda909..a05e705b3f50 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -19,8 +19,10 @@ import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.PrintWriter; +import java.time.Duration; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -58,16 +60,19 @@ private GrpcGetWorkerMetadataStream( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - Consumer serverMappingConsumer) { + Consumer serverMappingConsumer, + Duration 
halfClosePhysicalStreamAfter, + ScheduledExecutorService executorService) { super( LOG, - "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures, - ""); + "", + halfClosePhysicalStreamAfter, + executorService); this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); this.serverMappingConsumer = serverMappingConsumer; this.latestResponse = WorkerMetadataResponse.getDefaultInstance(); @@ -82,7 +87,9 @@ public static GrpcGetWorkerMetadataStream create( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - Consumer serverMappingUpdater) { + Consumer serverMappingUpdater, + Duration halfClosePhysicalStreamAfter, + ScheduledExecutorService executorService) { return new GrpcGetWorkerMetadataStream( startGetWorkerMetadataRpcFn, backoff, @@ -90,7 +97,9 @@ public static GrpcGetWorkerMetadataStream create( streamRegistry, logEveryNStreamFailures, jobHeader, - serverMappingUpdater); + serverMappingUpdater, + halfClosePhysicalStreamAfter, + executorService); } /** @@ -141,8 +150,10 @@ public void appendHtml(PrintWriter writer) {} } @Override - protected void onNewStream() throws WindmillStreamShutdownException { - trySend(workerMetadataRequest); + protected void onFlushPending(boolean isNewStream) throws WindmillStreamShutdownException { + if (isNewStream) { + trySend(workerMetadataRequest); + } } @Override @@ -154,7 +165,7 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException { protected void appendSpecificHtml(PrintWriter writer) { synchronized (metadataLock) { writer.format( - "GetWorkerMetadataStream: job_header=[%s], current_metadata=[%s]", + "GetWorkerMetadataStream: job_header=[%s], current_metadata=[%s] ", workerMetadataRequest.getHeader(), latestResponse); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 1f261e59450a..244d2ad3fa14 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -27,6 +27,9 @@ import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; @@ -59,6 +62,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; import org.joda.time.Instant; @@ -69,8 +73,10 @@ @ThreadSafe @Internal public class GrpcWindmillStreamFactory implements StatusDataProvider { - private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; + private static final java.time.Duration + DEFAULT_DIRECT_STREAMING_RPC_PHYSICAL_STREAM_HALF_CLOSE_AFTER = + java.time.Duration.ofMinutes(3); private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -92,6 +98,8 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private final boolean 
sendKeyedGetDataRequests; private final boolean requestBatchedGetWorkResponse; private final Consumer> processHeartbeatResponses; + private final java.time.Duration directStreamingRpcPhysicalStreamHalfCloseAfter; + private final Supplier executorServiceSupplier; private GrpcWindmillStreamFactory( JobHeader jobHeader, @@ -101,7 +109,9 @@ private GrpcWindmillStreamFactory( boolean sendKeyedGetDataRequests, boolean requestBatchedGetWorkResponse, Consumer> processHeartbeatResponses, - Supplier maxBackOffSupplier) { + Supplier maxBackOffSupplier, + java.time.Duration directStreamingRpcPhysicalStreamHalfCloseAfter, + Supplier executorServiceSupplier) { this.jobHeader = jobHeader; this.logEveryNStreamFailures = logEveryNStreamFailures; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -119,6 +129,9 @@ private GrpcWindmillStreamFactory( this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse; this.processHeartbeatResponses = processHeartbeatResponses; this.streamIdGenerator = new AtomicLong(); + this.directStreamingRpcPhysicalStreamHalfCloseAfter = + directStreamingRpcPhysicalStreamHalfCloseAfter; + this.executorServiceSupplier = executorServiceSupplier; } /** @implNote Used for {@link AutoBuilder} {@link Builder} class, do not call directly. 
*/ @@ -131,7 +144,9 @@ static GrpcWindmillStreamFactory create( boolean requestBatchedGetWorkResponse, Consumer> processHeartbeatResponses, Supplier maxBackOffSupplier, - int healthCheckIntervalMillis) { + int healthCheckIntervalMillis, + java.time.Duration directStreamingRpcPhysicalStreamHalfCloseAfter, + Supplier scheduledExecutorServiceSupplier) { GrpcWindmillStreamFactory streamFactory = new GrpcWindmillStreamFactory( jobHeader, @@ -141,7 +156,9 @@ static GrpcWindmillStreamFactory create( sendKeyedGetDataRequests, requestBatchedGetWorkResponse, processHeartbeatResponses, - maxBackOffSupplier); + maxBackOffSupplier, + directStreamingRpcPhysicalStreamHalfCloseAfter, + scheduledExecutorServiceSupplier); if (healthCheckIntervalMillis >= 0) { // Health checks are run on background daemon thread, which will only be cleaned up on JVM @@ -169,6 +186,7 @@ public void run() { * Returns a new {@link Builder} for {@link GrpcWindmillStreamFactory} with default values set for * the given {@link JobHeader}. */ + @SuppressWarnings("nullness") public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { return new AutoBuilder_GrpcWindmillStreamFactory_Builder() .setJobHeader(jobHeader) @@ -179,7 +197,10 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { .setHealthCheckIntervalMillis(NO_HEALTH_CHECKS) .setSendKeyedGetDataRequests(true) .setRequestBatchedGetWorkResponse(false) - .setProcessHeartbeatResponses(ignored -> {}); + .setProcessHeartbeatResponses(ignored -> {}) + .setDirectStreamingRpcPhysicalStreamHalfCloseAfter( + DEFAULT_DIRECT_STREAMING_RPC_PHYSICAL_STREAM_HALF_CLOSE_AFTER) + .setScheduledExecutorServiceSupplier(() -> null); } private static > T withDefaultDeadline(T stub) { @@ -201,6 +222,41 @@ private static void printSummaryHtmlForWorker( writer.write("
"); } + private ScheduledExecutorService executorForDispatchedStreams(String debugStreamTypeName) { + ScheduledExecutorService result = executorServiceSupplier.get(); + if (result != null) { + return result; + } + return Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(String.format("%s-WindmillStream-thread", debugStreamTypeName)) + .build()); + } + + private ScheduledExecutorService executorForDirectStreams( + String backendWorkerToken, String debugStreamTypeName) { + ScheduledExecutorService supplierResult = executorServiceSupplier.get(); + if (supplierResult != null) { + return supplierResult; + } + ScheduledThreadPoolExecutor result = + new ScheduledThreadPoolExecutor( + 0, + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat( + String.join( + "-", + debugStreamTypeName, + backendWorkerToken.substring(0, Math.min(10, backendWorkerToken.length())), + "WindmillStream", + "%d")) + .build()); + result.setKeepAliveTime(1, TimeUnit.MINUTES); + return result; + } + public GetWorkStream createGetWorkStream( CloudWindmillServiceV1Alpha1Stub stub, GetWorkRequest request, @@ -214,7 +270,9 @@ public GetWorkStream createGetWorkStream( streamRegistry, logEveryNStreamFailures, requestBatchedGetWorkResponse, - processWorkItem); + processWorkItem, + java.time.Duration.ZERO, + executorForDispatchedStreams("GetWork")); } public GetWorkStream createDirectGetWorkStream( @@ -226,7 +284,8 @@ public GetWorkStream createDirectGetWorkStream( WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), - responseObserver -> connection.currentStub().getWorkStream(responseObserver), + responseObserver -> + withDefaultDeadline(connection.currentStub()).getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -236,7 +295,9 @@ public GetWorkStream createDirectGetWorkStream( heartbeatSender, getDataClient, workCommitter, - 
workItemScheduler); + workItemScheduler, + directStreamingRpcPhysicalStreamHalfCloseAfter, + executorForDirectStreams(connection.backendWorkerToken(), "GetWork")); } public GetDataStream createGetDataStream(CloudWindmillServiceV1Alpha1Stub stub) { @@ -251,13 +312,16 @@ public GetDataStream createGetDataStream(CloudWindmillServiceV1Alpha1Stub stub) streamIdGenerator, streamingRpcBatchLimit, sendKeyedGetDataRequests, - processHeartbeatResponses); + processHeartbeatResponses, + java.time.Duration.ZERO, + executorForDispatchedStreams("GetWorkerMetadata")); } public GetDataStream createDirectGetDataStream(WindmillConnection connection) { return GrpcGetDataStream.create( connection.backendWorkerToken(), - responseObserver -> connection.currentStub().getDataStream(responseObserver), + responseObserver -> + withDefaultDeadline(connection.currentStub()).getDataStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, @@ -266,7 +330,9 @@ public GetDataStream createDirectGetDataStream(WindmillConnection connection) { streamIdGenerator, streamingRpcBatchLimit, sendKeyedGetDataRequests, - processHeartbeatResponses); + processHeartbeatResponses, + directStreamingRpcPhysicalStreamHalfCloseAfter, + executorForDirectStreams(connection.backendWorkerToken(), "GetData")); } public CommitWorkStream createCommitWorkStream(CloudWindmillServiceV1Alpha1Stub stub) { @@ -279,20 +345,25 @@ public CommitWorkStream createCommitWorkStream(CloudWindmillServiceV1Alpha1Stub logEveryNStreamFailures, jobHeader, streamIdGenerator, - streamingRpcBatchLimit); + streamingRpcBatchLimit, + java.time.Duration.ZERO, + executorForDispatchedStreams("CommitWork")); } public CommitWorkStream createDirectCommitWorkStream(WindmillConnection connection) { return GrpcCommitWorkStream.create( connection.backendWorkerToken(), - responseObserver -> connection.currentStub().commitWorkStream(responseObserver), + responseObserver -> + 
withDefaultDeadline(connection.currentStub()).commitWorkStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, jobHeader, streamIdGenerator, - streamingRpcBatchLimit); + streamingRpcBatchLimit, + directStreamingRpcPhysicalStreamHalfCloseAfter, + executorForDirectStreams(connection.backendWorkerToken(), "CommitWork")); } public GetWorkerMetadataStream createGetWorkerMetadataStream( @@ -305,7 +376,9 @@ public GetWorkerMetadataStream createGetWorkerMetadataStream( streamRegistry, logEveryNStreamFailures, jobHeader, - onNewWindmillEndpoints); + onNewWindmillEndpoints, + directStreamingRpcPhysicalStreamHalfCloseAfter, + executorForDispatchedStreams("GetWorkerMetadataStream")); } private StreamObserverFactory newStreamObserverFactory() { @@ -351,6 +424,11 @@ Builder setProcessHeartbeatResponses( Builder setRequestBatchedGetWorkResponse(boolean enabled); + Builder setDirectStreamingRpcPhysicalStreamHalfCloseAfter(java.time.Duration timeout); + + Builder setScheduledExecutorServiceSupplier( + Supplier scheduledExecutorServiceSupplier); + GrpcWindmillStreamFactory build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 59fd341fab4b..dd13d5b55930 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -253,7 +253,7 @@ public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedExcep Windmill.GetWorkResponse response = workToOffer.get(null); if (response == null) { try { - sleepMillis(500); + sleepMillis(100); } catch (InterruptedException e) { halfClose(); 
Thread.currentThread().interrupt(); @@ -515,9 +515,9 @@ public void clearCommitsReceived() { public ConcurrentHashMap> waitForDroppedCommits( int droppedCommits) { LOG.debug("waitForDroppedCommits: {}", droppedCommits); - int maxTries = 10; + int maxTries = 100; while (maxTries-- > 0 && droppedStreamingCommits.size() < droppedCommits) { - Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); } assertEquals(droppedCommits, droppedStreamingCommits.size()); return droppedStreamingCommits; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index c1696d8a70ab..a60535dfbd69 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -1285,7 +1285,7 @@ public void testKeyCommitTooLargeException() throws Exception { int maxTries = 10; while (--maxTries > 0) { worker.reportPeriodicWorkerUpdatesForTest(); - Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); } // We should see an exception reported for the large commit but not the small one. @@ -1489,9 +1489,9 @@ public void testExceptions() throws Exception { server.waitForEmptyWorkQueue(); // Wait until the worker has given up. 
- int maxTries = 10; + int maxTries = 100; while (maxTries-- > 0 && !worker.workExecutorIsEmpty()) { - Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); } assertTrue(worker.workExecutorIsEmpty()); @@ -1499,7 +1499,7 @@ public void testExceptions() throws Exception { maxTries = 10; while (maxTries-- > 0) { worker.reportPeriodicWorkerUpdatesForTest(); - Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); } // We should see our update only one time with the exceptions we are expecting. @@ -3520,7 +3520,7 @@ public void testActiveWorkFailure() throws Exception { // Release the blocked calls. BlockingFn.blocker().countDown(); Map commits = - server.waitForAndGetCommitsWithTimeout(2, Duration.standardSeconds((5))); + server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5))); assertEquals(1, commits.size()); worker.stop(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index 92c081591c73..80c39e770c3e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -21,12 +21,14 @@ import static org.junit.Assert.assertThrows; import java.io.PrintWriter; +import java.time.temporal.ChronoUnit; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import 
java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -59,7 +61,12 @@ public void setUp() { private TestStream newStream( Function, StreamObserver> clientFactory) { - return new TestStream(clientFactory, streamRegistry, streamObserverFactory); + return new TestStream( + clientFactory, + streamRegistry, + streamObserverFactory, + Duration.ZERO, + Executors.newScheduledThreadPool(0)); } @Test @@ -140,21 +147,25 @@ private static class TestStream extends AbstractWindmillStream private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStreamTest.class); private final AtomicInteger numStarts = new AtomicInteger(); + private final AtomicInteger numFlushPending = new AtomicInteger(); private final AtomicInteger numHealthChecks = new AtomicInteger(); private TestStream( Function, StreamObserver> clientFactory, Set> streamRegistry, - StreamObserverFactory streamObserverFactory) { + StreamObserverFactory streamObserverFactory, + Duration halfCloseAfterTimeout, + ScheduledExecutorService executorService) { super( LoggerFactory.getLogger(AbstractWindmillStreamTest.class), - "Test", clientFactory, FluentBackoff.DEFAULT.backoff(), streamObserverFactory, streamRegistry, 1, - "Test"); + "Test", + java.time.Duration.of(halfCloseAfterTimeout.getMillis(), ChronoUnit.MILLIS), + executorService); } @Override @@ -178,8 +189,11 @@ public void appendHtml(PrintWriter writer) {} } @Override - protected void onNewStream() { - numStarts.incrementAndGet(); + protected void onFlushPending(boolean isNewStream) { + if (isNewStream) { + numStarts.incrementAndGet(); + } + numFlushPending.incrementAndGet(); } private void testSend() throws WindmillStreamShutdownException { diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/TriggeredScheduledExecutorService.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/TriggeredScheduledExecutorService.java new file mode 100644 index 000000000000..a43da93b3680 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/TriggeredScheduledExecutorService.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.time.Duration; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Delayed; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import javax.annotation.Nullable; + +public class TriggeredScheduledExecutorService extends ThreadPoolExecutor + implements ScheduledExecutorService { + private final BlockingQueue futures = new LinkedBlockingQueue<>(); + + public TriggeredScheduledExecutorService() { + super(0, 100, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); + } + + public boolean unblockNextFuture() throws InterruptedException { + @Nullable FakeScheduledFuture f = futures.take(); + if (f == null) { + return false; + } + f.triggerRun(); + return true; + } + + @Override + public ScheduledFuture schedule(Runnable runnable, long l, TimeUnit timeUnit) { + FakeScheduledFuture f = + new FakeScheduledFuture(runnable, Duration.ofMillis(timeUnit.toMillis(l))); + try { + futures.put(f); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return f; + } + + @Override + public ScheduledFuture schedule(Callable callable, long l, TimeUnit timeUnit) { + throw new UnsupportedOperationException("not supported yet"); + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable runnable, long l, long l1, TimeUnit timeUnit) { + throw new UnsupportedOperationException("not supported yet"); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable runnable, long l, long l1, TimeUnit timeUnit) { + throw new UnsupportedOperationException("not supported yet"); + 
} + + private class FakeScheduledFuture implements ScheduledFuture { + private final Runnable r; + private final Duration delay; + private transient boolean cancelled; + private final CompletableFuture delegateFuture = new CompletableFuture<>(); + + private FakeScheduledFuture(Runnable r, Duration delay) { + this.r = r; + this.delay = delay; + } + + void triggerRun() { + TriggeredScheduledExecutorService.this.execute( + () -> { + try { + r.run(); + delegateFuture.complete(null); + } catch (RuntimeException e) { + delegateFuture.completeExceptionally(e); + } + }); + } + + @Override + public long getDelay(TimeUnit timeUnit) { + return timeUnit.convert(delay.toMillis(), TimeUnit.MILLISECONDS); + } + + @Override + public int compareTo(Delayed delayed) { + return 0; + } + + @Override + public boolean cancel(boolean b) { + cancelled = true; + return true; + } + + @Override + public boolean isCancelled() { + return cancelled; + } + + @Override + public boolean isDone() { + return delegateFuture.isDone(); + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return delegateFuture.get(); + } + + @Override + public Void get(long l, TimeUnit timeUnit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegateFuture.get(l, timeUnit); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java index 85c3c71663f1..19f8c1578b46 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java @@ -19,6 +19,7 @@ import 
java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; @@ -32,12 +33,31 @@ class FakeWindmillGrpcService private final ErrorCollector errorCollector; @GuardedBy("this") - private boolean failOnNewStreams = false; + private boolean noMoreStreamsExpected = false; + + @GuardedBy("this") + private int failedStreamConnectsRemaining = 0; public FakeWindmillGrpcService(ErrorCollector errorCollector) { this.errorCollector = errorCollector; } + @SuppressWarnings("BusyWait") + public void waitForFailedConnectAttempts() throws InterruptedException { + while (true) { + Thread.sleep(2); + synchronized (this) { + if (failedStreamConnectsRemaining <= 0) { + break; + } + } + } + } + + public synchronized void setFailedStreamConnectsRemaining(int failedStreamConnectsRemaining) { + this.failedStreamConnectsRemaining = failedStreamConnectsRemaining; + } + public static class StreamInfo { public StreamInfo(StreamObserver responseObserver) { this.responseObserver = responseObserver; @@ -63,6 +83,17 @@ public StreamInfoObserver( @Override public void onNext(RequestT request) { + if (streamInfo.onDone.isDone()) { + try { + if (streamInfo.onDone.get() == null) { + throw new IllegalStateException("Stream already half-closed."); + } else { + throw new IllegalStateException("Stream already closed with error."); + } + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } errorCollector.checkThat(streamInfo.requests.add(request), Matchers.is(true)); } @@ -89,7 +120,11 @@ public StreamObserver commitWorkStream( StreamObserver responseObserver) { CommitStreamInfo info = new CommitStreamInfo(responseObserver); synchronized (this) { - errorCollector.checkThat(failOnNewStreams, Matchers.is(false)); 
+ errorCollector.checkThat(noMoreStreamsExpected, Matchers.is(false)); + if (failedStreamConnectsRemaining-- > 0) { + throw new RuntimeException( + "Injected connection error, remaining failures: " + failedStreamConnectsRemaining); + } errorCollector.checkThat(commitStreams.offer(info), Matchers.is(true)); } return new StreamInfoObserver<>(info, errorCollector); @@ -100,7 +135,7 @@ public CommitStreamInfo waitForConnectedCommitStream() throws InterruptedExcepti } public synchronized void expectNoMoreStreams() { - failOnNewStreams = true; + noMoreStreamsExpected = true; errorCollector.checkThat(commitStreams.isEmpty(), Matchers.is(true)); errorCollector.checkThat(getDataStreams.isEmpty(), Matchers.is(true)); } @@ -117,7 +152,11 @@ public StreamObserver getDataStream( StreamObserver responseObserver) { GetDataStreamInfo info = new GetDataStreamInfo(responseObserver); synchronized (this) { - errorCollector.checkThat(failOnNewStreams, Matchers.is(false)); + errorCollector.checkThat(noMoreStreamsExpected, Matchers.is(false)); + if (failedStreamConnectsRemaining-- > 0) { + throw new RuntimeException( + "Injected connection error, remaining failures: " + failedStreamConnectsRemaining); + } errorCollector.checkThat(getDataStreams.offer(info), Matchers.is(true)); } return new StreamInfoObserver<>(info, errorCollector); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 195e13e84e26..e9fd55fa5668 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -21,21 
+21,28 @@ import static org.hamcrest.Matchers.*; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.time.Duration; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.TriggeredScheduledExecutorService; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; @@ -120,6 +127,32 @@ private GrpcCommitWorkStream createCommitWorkStream() { return commitWorkStream; } + private GrpcCommitWorkStream createCommitWorkStreamWithPhysicalStreamHandover( + ScheduledExecutorService executor) { + GrpcCommitWorkStream commitWorkStream = + (GrpcCommitWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setDirectStreamingRpcPhysicalStreamHalfCloseAfter(Duration.ofMinutes(1)) + .setScheduledExecutorServiceSupplier( + new Supplier() { + private final AtomicBoolean vended = new AtomicBoolean(); + + @Override + public ScheduledExecutorService get() { + 
assertFalse(vended.getAndSet(true)); + return executor; + } + }) + .build() + .createDirectCommitWorkStream( + WindmillConnection.builder() + .setStubSupplier( + () -> CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build()); + commitWorkStream.start(); + return commitWorkStream; + } + @Test public void testShutdown_abortsActiveCommits() throws InterruptedException, ExecutionException { int numCommits = 5; @@ -459,6 +492,647 @@ public void testSend_notCalledAfterShutdown_Multichunk() assertThat(streamInfo.requests).isEmpty(); } + private Windmill.WorkItemCommitRequest createTestCommit(int id) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(id) + .setWorkToken(id * 100L) + .setCacheToken(id * 1000L) + .build(); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams() throws Exception { + // A special executor that allows triggering scheduled futures (of which the handover is the + // only such future). + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // Send a request where the response is captured in a future. 
+ Windmill.WorkItemCommitRequest workItemCommitRequest = createTestCommit(1); + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(workItemCommitRequest); + + // Trigger a new stream to be created by forcing the scheduled halfCloseFuture scheduled within + // AbstractWindmillStream to run. + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + // Previous stream client should be half-closed. 
+ assertNull(streamInfo.onDone.get()); + + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).build()); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Complete server-side half-close of first stream. No new + // stream should be created since the current stream is active. + streamInfo.responseObserver.onCompleted(); + + // Close the stream, the open stream should be client half-closed + // but logical remains not terminated. + commitWorkStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + assertFalse(commitWorkStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + // Complete half-closing from the server and verify shutdown completes. 
+ streamInfo2.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_oldStreamFails() throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + Windmill.WorkItemCommitRequest workItemCommitRequest = createTestCommit(1); + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(workItemCommitRequest); + + // A new stream should be created due to handover. + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + // Previous stream client should be half-closed. 
+ assertNull(streamInfo.onDone.get()); + + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).build()); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Complete first stream with an error. No new + // stream should be created since the current stream is active. The request should have an + // error and the request should be retried on the new stream. + streamInfo.responseObserver.onError(new RuntimeException("test error")); + Windmill.StreamingCommitWorkRequest request3 = streamInfo2.requests.take(); + assertThat(request3.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest3 = + Windmill.WorkItemCommitRequest.parseFrom( + request3.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest3).isEqualTo(workItemCommitRequest); + + // Close the stream, the open stream should be client half-closed + // but logical remains not terminated. 
+ commitWorkStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + assertFalse(commitWorkStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Complete half-closing from the server and verify shutdown completes. + streamInfo2.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_newStreamFailsWhileEmpty() + throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + Windmill.WorkItemCommitRequest workItemCommitRequest = createTestCommit(1); + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(workItemCommitRequest); + + // A new stream should be created due to handover. + assertTrue(triggeredExecutor.unblockNextFuture()); + + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + // Before stream 1 is finished simulate stream 2 failing. 
+ streamInfo2.responseObserver.onError(new IOException("stream 2 failed")); + // A new stream should be created and handle new requests. + FakeWindmillGrpcService.CommitStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + Windmill.StreamingCommitWorkRequest request2 = streamInfo3.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + streamInfo3.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).build()); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Close the stream. 
+ commitWorkStream.halfClose(); + assertNull(streamInfo.onDone.get()); + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_newStreamFailsWithRequests() + throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + Windmill.WorkItemCommitRequest workItemCommitRequest = createTestCommit(1); + CompletableFuture commitStatusFuture = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest, commitStatusFuture::complete)); + } + + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertThat(request.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest).isEqualTo(workItemCommitRequest); + + // A new stream should be created due to handover. 
+ assertTrue(triggeredExecutor.unblockNextFuture()); + + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + // Before stream 1 is finished simulate stream 2 failing. + streamInfo2.responseObserver.onError(new IOException("stream 2 failed")); + // A new stream should be created and receive the pending requests from stream2 but not the + // request from stream1. 
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + Windmill.StreamingCommitWorkRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest3 = + Windmill.WorkItemCommitRequest.parseFrom( + request3.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest3).isEqualTo(workItemCommitRequest2); + + streamInfo3.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).build()); + + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + assertThat(commitStatusFuture.get()).isEqualTo(Windmill.CommitStatus.OK); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Close the stream. + commitWorkStream.halfClose(); + assertNull(streamInfo.onDone.get()); + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_multipleHandovers() throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo1 = waitForConnectionAndConsumeHeader(); + + // Commit request 1 on stream 1 + Windmill.WorkItemCommitRequest workItemCommitRequest1 = createTestCommit(1); + CompletableFuture commitStatusFuture1 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest1, commitStatusFuture1::complete)); 
+ } + + Windmill.StreamingCommitWorkRequest request1 = streamInfo1.requests.take(); + assertThat(request1.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest1 = + Windmill.WorkItemCommitRequest.parseFrom( + request1.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest1).isEqualTo(workItemCommitRequest1); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo1.onDone.get()); + + // Commit request 2 on stream 2 + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + // Trigger handover 2 before streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Commit request 3 on stream 3 + Windmill.WorkItemCommitRequest workItemCommitRequest3 = createTestCommit(3); + CompletableFuture commitStatusFuture3 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest3, commitStatusFuture3::complete)); + } + + 
Windmill.StreamingCommitWorkRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest3 = + Windmill.WorkItemCommitRequest.parseFrom( + request3.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest3).isEqualTo(workItemCommitRequest3); + + // Respond to all requests + streamInfo1.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).build()); + streamInfo3.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(3).build()); + + assertThat(commitStatusFuture1.get()).isEqualTo(Windmill.CommitStatus.OK); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.OK); + assertThat(commitStatusFuture3.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Close the stream + commitWorkStream.halfClose(); + assertNull(streamInfo3.onDone.get()); + + // Verify no more streams + fakeService.expectNoMoreStreams(); + streamInfo1.responseObserver.onCompleted(); + streamInfo2.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_oldStreamFailsWhileNewStreamInBackoff() + throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo1 = waitForConnectionAndConsumeHeader(); + + // Commit request 1 on stream 1 + Windmill.WorkItemCommitRequest workItemCommitRequest1 = createTestCommit(1); + CompletableFuture commitStatusFuture1 = new CompletableFuture<>(); + try 
(WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest1, commitStatusFuture1::complete)); + } + + Windmill.StreamingCommitWorkRequest request1 = streamInfo1.requests.take(); + assertThat(request1.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest1 = + Windmill.WorkItemCommitRequest.parseFrom( + request1.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest1).isEqualTo(workItemCommitRequest1); + + // Trigger handover but fail new connections + assertTrue(triggeredExecutor.unblockNextFuture()); + fakeService.setFailedStreamConnectsRemaining(1); + fakeService.waitForFailedConnectAttempts(); + assertNull(streamInfo1.onDone.get()); + + // Fail first stream + streamInfo1.responseObserver.onError(new RuntimeException("test error")); + + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest1); + + // Respond to the request + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).build()); + assertThat(commitStatusFuture1.get()).isEqualTo(Windmill.CommitStatus.OK); + + // Close the stream + commitWorkStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + + streamInfo2.responseObserver.onCompleted(); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_shutdown() + throws Exception { + TriggeredScheduledExecutorService 
triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo1 = waitForConnectionAndConsumeHeader(); + + // Commit request 1 on stream 1 + Windmill.WorkItemCommitRequest workItemCommitRequest1 = createTestCommit(1); + CompletableFuture commitStatusFuture1 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest1, commitStatusFuture1::complete)); + } + + Windmill.StreamingCommitWorkRequest request1 = streamInfo1.requests.take(); + assertThat(request1.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest1 = + Windmill.WorkItemCommitRequest.parseFrom( + request1.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest1).isEqualTo(workItemCommitRequest1); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo1.onDone.get()); + + // Commit request 2 on stream 2 + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + 
assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + // Trigger handover 2 before streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Commit request 3 on stream 3 + Windmill.WorkItemCommitRequest workItemCommitRequest3 = createTestCommit(3); + CompletableFuture commitStatusFuture3 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest3, commitStatusFuture3::complete)); + } + + Windmill.StreamingCommitWorkRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest3 = + Windmill.WorkItemCommitRequest.parseFrom( + request3.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest3).isEqualTo(workItemCommitRequest3); + + // Shutdown while there are active streams and verify it isn't completed until all the streams + // are done. 
+ fakeService.expectNoMoreStreams(); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.SECONDS)); + commitWorkStream.shutdown(); + assertThat(commitStatusFuture1.isDone()).isTrue(); + assertThat(commitStatusFuture2.isDone()).isTrue(); + assertThat(commitStatusFuture3.isDone()).isTrue(); + assertFalse(commitWorkStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo3.responseObserver.onCompleted(); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo1.responseObserver.onCompleted(); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo2.responseObserver.onError(new RuntimeException("test")); + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testCommitWorkItem_multiplePhysicalStreams_multipleHandovers_halfClose() + throws Exception { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcCommitWorkStream commitWorkStream = + createCommitWorkStreamWithPhysicalStreamHandover(triggeredExecutor); + commitWorkStream.start(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo1 = waitForConnectionAndConsumeHeader(); + + // Commit request 1 on stream 1 + Windmill.WorkItemCommitRequest workItemCommitRequest1 = createTestCommit(1); + CompletableFuture commitStatusFuture1 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest1, commitStatusFuture1::complete)); + } + + Windmill.StreamingCommitWorkRequest request1 = streamInfo1.requests.take(); + assertThat(request1.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest1 = + Windmill.WorkItemCommitRequest.parseFrom( + request1.getCommitChunk(0).getSerializedWorkItemCommit()); + 
assertThat(parsedRequest1).isEqualTo(workItemCommitRequest1); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo1.onDone.get()); + + // Commit request 2 on stream 2 + Windmill.WorkItemCommitRequest workItemCommitRequest2 = createTestCommit(2); + CompletableFuture commitStatusFuture2 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest2, commitStatusFuture2::complete)); + } + + Windmill.StreamingCommitWorkRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest2 = + Windmill.WorkItemCommitRequest.parseFrom( + request2.getCommitChunk(0).getSerializedWorkItemCommit()); + assertThat(parsedRequest2).isEqualTo(workItemCommitRequest2); + + // Trigger handover 2 before streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.CommitStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Commit request 3 on stream 3 + Windmill.WorkItemCommitRequest workItemCommitRequest3 = createTestCommit(3); + CompletableFuture commitStatusFuture3 = new CompletableFuture<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, workItemCommitRequest3, commitStatusFuture3::complete)); + } + + Windmill.StreamingCommitWorkRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getCommitChunkList()).hasSize(1); + Windmill.WorkItemCommitRequest parsedRequest3 = + Windmill.WorkItemCommitRequest.parseFrom( + request3.getCommitChunk(0).getSerializedWorkItemCommit()); + 
assertThat(parsedRequest3).isEqualTo(workItemCommitRequest3); + + // Half-close while there are active streams and verify termination isn't completed until all the streams + are done. + fakeService.expectNoMoreStreams(); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.SECONDS)); + commitWorkStream.halfClose(); + + assertFalse(commitWorkStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + assertThat(streamInfo3.onDone.get()).isNull(); + + assertThat(commitStatusFuture1.isDone()).isFalse(); + assertThat(commitStatusFuture2.isDone()).isFalse(); + assertThat(commitStatusFuture3.isDone()).isFalse(); + + streamInfo3.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(3).build()); + streamInfo3.responseObserver.onCompleted(); + assertThat(commitStatusFuture3.get()).isEqualTo(Windmill.CommitStatus.OK); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + + streamInfo1.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder() + .addRequestId(1) + .addStatus(Windmill.CommitStatus.ABORTED) + .build()); + streamInfo1.responseObserver.onCompleted(); + assertThat(commitStatusFuture1.get()).isEqualTo(Windmill.CommitStatus.ABORTED); + assertFalse(commitWorkStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + + streamInfo2.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder() + .addRequestId(2) + .addStatus(Windmill.CommitStatus.ALREADY_IN_COMMIT) + .build()); + streamInfo2.responseObserver.onCompleted(); + assertThat(commitStatusFuture2.get()).isEqualTo(Windmill.CommitStatus.ALREADY_IN_COMMIT); + + assertTrue(commitWorkStream.awaitTermination(10, TimeUnit.SECONDS)); + } + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { try { FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index e954f2cc7105..4f584022c8a5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -19,26 +19,42 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.io.IOException; +import java.time.Duration; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import java.util.logging.Level; +import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; +import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import 
org.apache.beam.runners.dataflow.worker.windmill.client.TriggeredScheduledExecutorService; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.CallOptions; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Channel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientCall; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ClientInterceptor; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.MethodDescriptor; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder; @@ -86,6 +102,8 @@ public void setUp() throws IOException { inProcessChannel = grpcCleanup.register( InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + Logger.getLogger(GrpcGetDataStream.class.getName()).setLevel(Level.ALL); + Logger.getLogger(AbstractMethodError.class.getName()).setLevel(Level.ALL); } @After @@ -105,20 +123,39 @@ private GrpcGetDataStream createGetDataStream() { return getDataStream; } + private GrpcGetDataStream createGetDataStreamWithPhysicalStreamHandover( + Duration handover, @Nullable ScheduledExecutorService executor) { + GrpcGetDataStream getDataStream = + (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setDirectStreamingRpcPhysicalStreamHalfCloseAfter(handover) + .setScheduledExecutorServiceSupplier( + new Supplier() { + private final AtomicBoolean vended = new AtomicBoolean(); + + @Override + public ScheduledExecutorService get() { + assertFalse(vended.getAndSet(true)); + return executor; + } + }) + .build() + .createDirectGetDataStream( + WindmillConnection.builder() + .setStubSupplier( + () -> 
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build()); + getDataStream.start(); + return getDataStream; + } + @Test public void testRequestKeyedData() throws InterruptedException { GrpcGetDataStream getDataStream = createGetDataStream(); FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); // These will block until they are successfully sent. - Windmill.KeyedGetDataRequest keyedGetDataRequest = - Windmill.KeyedGetDataRequest.newBuilder() - .setKey(ByteString.EMPTY) - .setShardingKey(1) - .setCacheToken(1) - .setWorkToken(1) - .build(); - + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); CompletableFuture sendFuture = CompletableFuture.supplyAsync( () -> { @@ -133,12 +170,7 @@ public void testRequestKeyedData() throws InterruptedException { assertThat(request.getRequestIdList()).containsExactly(1L); assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); - Windmill.KeyedGetDataResponse keyedGetDataResponse = - Windmill.KeyedGetDataResponse.newBuilder() - .setShardingKey(1) - .setKey(ByteString.EMPTY) - .build(); - + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); streamInfo.responseObserver.onNext( Windmill.StreamingGetDataResponse.newBuilder() .addRequestId(1) @@ -171,14 +203,7 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow } } try { - getDataStream.requestKeyedData( - "computationId", - Windmill.KeyedGetDataRequest.newBuilder() - .setKey(ByteString.EMPTY) - .setShardingKey(i) - .setCacheToken(i) - .setWorkToken(i) - .build()); + getDataStream.requestKeyedData("computationId", createTestRequest(i)); } catch (WindmillStreamShutdownException e) { throw new RuntimeException(e); } @@ -290,14 +315,766 @@ public void testRequestKeyedData_reconnectOnStreamErrorAfterHalfClose() getDataStream.halfClose(); assertNull(streamInfo.onDone.get()); - // Simulate an error on the grpc stream, this should trigger an 
error on all - // existing requests but no new connection since we half-closed and nothing left after - // responding with errors. - fakeService.expectNoMoreStreams(); + // Simulate an error on the grpc stream, this should trigger retrying the requests on a new + // stream + // which is half-closed. streamInfo.responseObserver.onError(new IOException("test error")); - assertThrows(RuntimeException.class, sendFuture::join); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request2.getStateRequest(0).getRequests(0)); + assertNull(streamInfo2.onDone.get()); + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + + // Sending an error this time shouldn't result in a new stream since there were no requests. 
+ fakeService.expectNoMoreStreams(); + streamInfo2.responseObserver.onError(new IOException("test error")); + + getDataStream.awaitTermination(60, TimeUnit.MINUTES); + } + + private Windmill.KeyedGetDataRequest createTestRequest(long id) { + return Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(id) + .setCacheToken(id * 100) + .setWorkToken(id * 1000) + .build(); + } + + private Windmill.KeyedGetDataResponse createTestResponse(long id) { + return Windmill.KeyedGetDataResponse.newBuilder() + .setShardingKey(id) + .setKey(ByteString.EMPTY) + .build(); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // A new stream should be created due to handover. + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + // Previous stream client should be half-closed. 
+ assertNull(streamInfo.onDone.get()); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + + // Complete server-side half-close of first stream. No new + // stream should be created since the current stream is active. + streamInfo.responseObserver.onCompleted(); + + // Close the stream, the open stream should be client half-closed + // but logical remains not terminated. + getDataStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + assertFalse(getDataStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + // Complete half-closing from the server and verify shutdown completes. 
+ streamInfo2.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_oldStreamFails() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // A new stream should be created due to handover. + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + // Previous stream client should be half-closed. 
+ assertNull(streamInfo.onDone.get()); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + + // Complete first stream with an error. No new + // stream should be created since the current stream is active. The request should have an + // error and the request should be retried on the new stream. + streamInfo.responseObserver.onError(new RuntimeException("test error")); + Windmill.StreamingGetDataRequest request3 = streamInfo2.requests.take(); + assertThat(request3.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request3.getStateRequest(0).getRequests(0)); + + // Close the stream, the open stream should be client half-closed + // but logical remains not terminated. 
+ getDataStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + assertFalse(getDataStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + + // Complete half-closing from the server and verify shutdown completes. + streamInfo2.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_newStreamFailsWhileEmpty() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // A new stream should be created due to handover. 
+ assertTrue(triggeredExecutor.unblockNextFuture()); + + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + // Before stream 1 is finished simulate stream 2 failing. + streamInfo2.responseObserver.onError(new IOException("stream 2 failed")); + // A new stream should be created and handle new requests. + FakeWindmillGrpcService.GetDataStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo3.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo3.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + + // Close the stream. 
+ getDataStream.halfClose(); + assertNull(streamInfo.onDone.get()); + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_newStreamFailsWithRequests() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // A new stream should be created due to handover. 
+ assertTrue(triggeredExecutor.unblockNextFuture()); + + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + // Before stream 1 is finished simulate stream 2 failing. + streamInfo2.responseObserver.onError(new IOException("stream 2 failed")); + // A new stream should be created and receive the pending requests from stream2 but not the + // request from stream1. 
+ FakeWindmillGrpcService.GetDataStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + Windmill.StreamingGetDataRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request3.getStateRequest(0).getRequests(0)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo3.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + + // Close the stream. 
+ getDataStream.halfClose(); + assertNull(streamInfo.onDone.get()); + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_multipleHandovers_allResponsesReceived() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // Request 1, Stream 1 + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = createTestRequest(1); + CompletableFuture sendFuture1 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest1); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request1 = streamInfo.requests.take(); + assertThat(request1.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest1, request1.getStateRequest(0).getRequests(0)); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + // Request 2, Stream 2 + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + 
assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + // Trigger handover 2 before streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Request 3, Stream 3 + Windmill.KeyedGetDataRequest keyedGetDataRequest3 = createTestRequest(3); + CompletableFuture sendFuture3 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest3); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getRequestIdList()).containsExactly(3L); + assertEquals(keyedGetDataRequest3, request3.getStateRequest(0).getRequests(0)); + + // Respond to all requests + Windmill.KeyedGetDataResponse keyedGetDataResponse1 = createTestResponse(1); + streamInfo.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse1.toByteString()) + .build()); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + streamInfo2.responseObserver.onCompleted(); + + Windmill.KeyedGetDataResponse keyedGetDataResponse3 = createTestResponse(3); + streamInfo3.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(3) + .addSerializedResponse(keyedGetDataResponse3.toByteString()) + .build()); + + assertThat(sendFuture1.join()).isEqualTo(keyedGetDataResponse1); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + 
assertThat(sendFuture3.join()).isEqualTo(keyedGetDataResponse3); + + // Close the stream. + getDataStream.halfClose(); + assertNull(streamInfo3.onDone.get()); + + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_oldStreamFailsWhileNewStreamInBackoff() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // A new stream should be created due to handover. However we configure the server to have + // errors. + assertTrue(triggeredExecutor.unblockNextFuture()); + fakeService.setFailedStreamConnectsRemaining(1); + fakeService.waitForFailedConnectAttempts(); + // Previous stream client should be half-closed. + assertNull(streamInfo.onDone.get()); + // Complete first stream with an error. No new + // stream should be created since the current stream is being created or created. The request + // should have an + // error and the request should be retried on the new stream. 
+ streamInfo.responseObserver.onError(new RuntimeException("test error")); + + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + fakeService.expectNoMoreStreams(); + + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request2.getStateRequest(0).getRequests(0)); + + // Close the stream, the open stream should be client half-closed + // but logical remains not terminated. + getDataStream.halfClose(); + assertNull(streamInfo2.onDone.get()); + assertFalse(getDataStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); + + // Complete half-closing from the server and verify shutdown completes. 
+ streamInfo2.responseObserver.onCompleted(); + + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_multipleHandovers_shutdown() + throws InterruptedException, ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // Request 1, Stream 1 + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = createTestRequest(1); + CompletableFuture sendFuture1 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest1); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request1 = streamInfo.requests.take(); + assertThat(request1.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest1, request1.getStateRequest(0).getRequests(0)); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + // Request 2, Stream 2 + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + // Trigger handover 2 before 
streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Request 3, Stream 3 + Windmill.KeyedGetDataRequest keyedGetDataRequest3 = createTestRequest(3); + CompletableFuture sendFuture3 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest3); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getRequestIdList()).containsExactly(3L); + assertEquals(keyedGetDataRequest3, request3.getStateRequest(0).getRequests(0)); + + // Shutdown while there are active streams and verify it isn't completed until all the streams + // are done. + fakeService.expectNoMoreStreams(); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.SECONDS)); + getDataStream.shutdown(); + assertThrows("WindmillStreamShutdownException", CompletionException.class, sendFuture1::join); + assertThrows("WindmillStreamShutdownException", CompletionException.class, sendFuture2::join); + assertThrows("WindmillStreamShutdownException", CompletionException.class, sendFuture3::join); + assertFalse(getDataStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo3.responseObserver.onCompleted(); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo.responseObserver.onCompleted(); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + streamInfo2.responseObserver.onError(new RuntimeException("test")); + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_multiplePhysicalStreams_multipleHandovers_halfClose() + throws InterruptedException, 
ExecutionException { + TriggeredScheduledExecutorService triggeredExecutor = new TriggeredScheduledExecutorService(); + GrpcGetDataStream getDataStream = + createGetDataStreamWithPhysicalStreamHandover(Duration.ofSeconds(60), triggeredExecutor); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // Request 1, Stream 1 + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = createTestRequest(1); + CompletableFuture sendFuture1 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest1); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request1 = streamInfo.requests.take(); + assertThat(request1.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest1, request1.getStateRequest(0).getRequests(0)); + + // Trigger handover 1 + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo2 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo.onDone.get()); + + // Request 2, Stream 2 + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = createTestRequest(2); + CompletableFuture sendFuture2 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest2); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request2 = streamInfo2.requests.take(); + assertThat(request2.getRequestIdList()).containsExactly(2L); + assertEquals(keyedGetDataRequest2, request2.getStateRequest(0).getRequests(0)); + + // Trigger handover 2 before streamInfo2 completes + assertTrue(triggeredExecutor.unblockNextFuture()); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo3 = waitForConnectionAndConsumeHeader(); + assertNull(streamInfo2.onDone.get()); + + // Request 3, Stream 3 + 
Windmill.KeyedGetDataRequest keyedGetDataRequest3 = createTestRequest(3); + CompletableFuture sendFuture3 = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest3); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + Windmill.StreamingGetDataRequest request3 = streamInfo3.requests.take(); + assertThat(request3.getRequestIdList()).containsExactly(3L); + assertEquals(keyedGetDataRequest3, request3.getStateRequest(0).getRequests(0)); + + // Half-close while there are active streams and verify it isn't completed until all the streams + // are done. Streams with requests should have requests resent. + fakeService.expectNoMoreStreams(); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.SECONDS)); + getDataStream.halfClose(); + assertNull(streamInfo.onDone.get()); + assertFalse(getDataStream.awaitTermination(10, TimeUnit.MILLISECONDS)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse3 = createTestResponse(3); + streamInfo3.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(3) + .addSerializedResponse(keyedGetDataResponse3.toByteString()) + .build()); + assertThat(sendFuture3.join()).isEqualTo(keyedGetDataResponse3); + + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + Windmill.KeyedGetDataResponse keyedGetDataResponse = createTestResponse(1); + streamInfo.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(keyedGetDataResponse.toByteString()) + .build()); + assertThat(sendFuture1.join()).isEqualTo(keyedGetDataResponse); + + streamInfo.responseObserver.onCompleted(); + assertFalse(getDataStream.awaitTermination(0, TimeUnit.MILLISECONDS)); + + Windmill.KeyedGetDataResponse keyedGetDataResponse2 = createTestResponse(2); + streamInfo2.responseObserver.onNext( + Windmill.StreamingGetDataResponse.newBuilder() + 
// NOTE(review): collapsed unified-diff fragment continues here (inline "+" tokens are diff
// add-markers). First the half-close test finishes: response 2 is delivered, then all three
// response observers complete, after which awaitTermination(10s) must return true.
//
// testRequestKeyedData_raceShutdownDuringTrySendBatch:
// Reproduces a race between shutdown() and trySendBatch(). A ClientInterceptor allows exactly
// one connection to succeed (connectedOnce) and throws "test error" on every subsequent attempt,
// counting down failedConnects (latch of 2). The first physical stream is then failed with
// onError, driving the stream into its reconnect/backoff path where every new connection fails.
// A keyed request is issued while the stream is in that state, then shutdown() is called; the
// test only requires that the pending future terminates exceptionally — the request itself
// "may or may not get there".
.addRequestId(2) + .addSerializedResponse(keyedGetDataResponse2.toByteString()) + .build()); + assertThat(sendFuture2.join()).isEqualTo(keyedGetDataResponse2); + streamInfo2.responseObserver.onCompleted(); + streamInfo3.responseObserver.onCompleted(); + assertTrue(getDataStream.awaitTermination(10, TimeUnit.SECONDS)); + } + + @Test + public void testRequestKeyedData_raceShutdownDuringTrySendBatch() throws Exception { + AtomicBoolean connectedOnce = new AtomicBoolean(false); + CountDownLatch failedConnects = new CountDownLatch(2); + GrpcGetDataStream getDataStream = + (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setSendKeyedGetDataRequests(false) + .build() + .createGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel) + .withInterceptors( + new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor methodDescriptor, + CallOptions callOptions, + Channel channel) { + if (connectedOnce.getAndSet(true)) { + failedConnects.countDown(); + throw new RuntimeException("test error"); + } + return channel.newCall(methodDescriptor, callOptions); + } + })); + getDataStream.start(); + // Wait for the first stream to succeed and cause it to fail, the rest should fail. + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + streamInfo.responseObserver.onError(new RuntimeException("fake error")); + + failedConnects.await(); + + // Send while we're in this state. + // Create a request + Windmill.KeyedGetDataRequest keyedGetDataRequest = createTestRequest(1); + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + // The shutdown should work if it occurs either before or after the above request is sent. 
// NOTE(review): Thread.sleep(100) deliberately leaves the shutdown-vs-send ordering to timing,
// so the race is only probabilistically exercised per run — acceptable since both orderings are
// valid outcomes, but confirm this cannot flake (the asserts must hold in either ordering).
// The method ends after the exceptional-completion asserts; the helper
// waitForConnectionAndConsumeHeader() opening at the end of this chunk continues past this view.
+ Thread.sleep(100); getDataStream.shutdown(); + + // The request should complete with an exception, it may or may not get there. + assertThrows(CompletionException.class, sendFuture::join); + assertTrue(sendFuture.isCompletedExceptionally()); } private FakeWindmillGrpcService.GetDataStreamInfo waitForConnectionAndConsumeHeader() {