diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 968664a9b874..cf46e1f984dc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -26,8 +26,9 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -37,6 +38,7 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -48,12 +50,12 @@ * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a * broken stream. * - *

Subclasses should override {@link #onResponse(ResponseT)} to handle responses from the server, - * and {@link #onNewStream()} to perform any work that must be done when a new stream is created, - * such as sending headers or retrying requests. + *

Subclasses should override {@link #newResponseHandler()} to implement a handler for physical + * stream connection. {@link #onNewStream()} to perform any work that must be done when a new stream + * is created, such as sending headers or retrying requests. * - *

{@link #trySend(RequestT)} and {@link #startStream()} should not be called from {@link - * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead. + *

{@link #trySend(RequestT)} and {@link #startStream()} should not be called when handling + * responses; use {@link #executeSafely(Runnable)} instead. * *

Synchronization on this is used to synchronize the gRpc stream state and internal data * structures. Since grpc channel operations may block, synchronization on this stream may also @@ -83,9 +85,12 @@ public abstract class AbstractWindmillStream implements Win private final Set> streamRegistry; private final int logEveryNStreamFailures; private final String backendWorkerToken; + + private final Function, TerminatingStreamObserver> + physicalStreamFactory; + protected final long physicalStreamDeadlineSeconds; private final ResettableThrowingStreamObserver requestObserver; - private final Supplier> requestObserverFactory; private final StreamDebugMetrics debugMetrics; private final AtomicBoolean isHealthCheckScheduled; @@ -95,6 +100,17 @@ public abstract class AbstractWindmillStream implements Win @GuardedBy("this") protected boolean isShutdown; + // The active physical grpc stream. trySend will send messages on the bi-directional stream + // associated with this handler. The instances are created by subclasses via newResponseHandler. + // Subclasses may wish to store additional per-physical stream state within the handler. + @GuardedBy("this") + protected @Nullable PhysicalStreamHandler currentPhysicalStream; + + // Generally the same as currentPhysicalStream, set under synchronization of this but can be read + // without. + private final AtomicReference currentPhysicalStreamForDebug = + new AtomicReference<>(); + @GuardedBy("this") private boolean started; @@ -108,6 +124,9 @@ protected AbstractWindmillStream( int logEveryNStreamFailures, String backendWorkerToken) { this.backendWorkerToken = backendWorkerToken; + this.physicalStreamFactory = + (StreamObserver observer) -> streamObserverFactory.from(clientFactory, observer); + this.physicalStreamDeadlineSeconds = streamObserverFactory.getDeadlineSeconds(); this.executor = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() @@ -124,10 +143,6 @@ protected AbstractWindmillStream( this.finishLatch = new CountDownLatch(1); this.logger = logger; this.requestObserver = new ResettableThrowingStreamObserver<>(logger); - this.requestObserverFactory = - () -> - streamObserverFactory.from( - clientFactory, new AbstractWindmillStream.ResponseObserver()); this.sleeper = Sleeper.DEFAULT; this.debugMetrics = StreamDebugMetrics.create(); } @@ -138,19 +153,45 @@ private static String createThreadName(String streamType, String backendWorkerTo : String.format("%s-WindmillStream-thread", streamType); } - /** Called on each response from the server. */ - protected abstract void onResponse(ResponseT response); + /** Represents a physical grpc stream that is part of the logical windmill stream. */ + protected abstract class PhysicalStreamHandler { - /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream() throws WindmillStreamShutdownException; + /** Called on each response from the server. */ + public abstract void onResponse(ResponseT response); + + /** Returns whether there are any pending requests that should be retried on a stream break. */ + public abstract boolean hasPendingRequests(); - /** Returns whether there are any pending requests that should be retried on a stream break. */ - protected abstract boolean hasPendingRequests(); + /** + * Called when the physical stream has finished. For streams with requests that should be + * retried, requests should be moved to parent state so that it is captured by the next + * flushPendingToStream call. + */ + public abstract void onDone(Status status); + + /** + * Renders information useful for debugging as html. + * + * @implNote Don't require synchronization on AbstractWindmillStream.this, see the {@link + * #appendSummaryHtml(PrintWriter)} comment. + */ + public abstract void appendHtml(PrintWriter writer); + + private final StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.create(); + } + + protected abstract PhysicalStreamHandler newResponseHandler(); + + protected abstract void onNewStream() throws WindmillStreamShutdownException; /** Try to send a request to the server. Returns true if the request was successfully sent. */ @CanIgnoreReturnValue protected final synchronized boolean trySend(RequestT request) throws WindmillStreamShutdownException { + if (currentPhysicalStream == null) { + return false; + } + currentPhysicalStream.streamDebugMetrics.recordSend(); debugMetrics.recordSend(); try { requestObserver.onNext(request); @@ -182,10 +223,14 @@ private void startStream() { // Add the stream to the registry after it has been fully constructed. streamRegistry.add(this); while (true) { + @NonNull PhysicalStreamHandler streamHandler = newResponseHandler(); try { synchronized (this) { debugMetrics.recordStart(); - requestObserver.reset(requestObserverFactory.get()); + streamHandler.streamDebugMetrics.recordStart(); + currentPhysicalStream = streamHandler; + currentPhysicalStreamForDebug.set(currentPhysicalStream); + requestObserver.reset(physicalStreamFactory.apply(new ResponseObserver(streamHandler))); onNewStream(); if (clientClosed) { halfClose(); @@ -272,6 +317,23 @@ public final void maybeScheduleHealthCheck(Instant lastSendThreshold) { */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); + + @Nullable PhysicalStreamHandler currentHandler = currentPhysicalStreamForDebug.get(); + if (currentHandler != null) { + writer.format("Physical stream: "); + currentHandler.appendHtml(writer); + StreamDebugMetrics.Snapshot summaryMetrics = + currentHandler.streamDebugMetrics.getSummaryMetrics(); + if (summaryMetrics.isClientClosed()) { + writer.write(" client closed"); + } + writer.format( + " current stream is %dms old, last send %dms, last response %dms\n", + summaryMetrics.streamAge(), + summaryMetrics.timeSinceLastSend(), + summaryMetrics.timeSinceLastResponse()); + } + StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); summaryMetrics .restartMetrics() @@ -304,6 +366,8 @@ public final void appendSummaryHtml(PrintWriter writer) { } /** + * Add specific debug state for the logical stream. + * * @implNote Don't require synchronization on stream, see the {@link * #appendSummaryHtml(PrintWriter)} comment. */ @@ -315,6 +379,9 @@ public final synchronized void halfClose() { debugMetrics.recordHalfClose(); clientClosed = true; try { + if (currentPhysicalStream != null) { + currentPhysicalStream.streamDebugMetrics.recordHalfClose(); + } requestObserver.onCompleted(); } catch (ResettableThrowingStreamObserver.StreamClosedException e) { logger.warn("Stream was previously closed."); @@ -354,11 +421,17 @@ public final void shutdown() { } } - protected abstract void shutdownInternal(); + protected synchronized void shutdownInternal() {} /** Returns true if the stream was torn down and should not be restarted internally. */ - private synchronized boolean maybeTearDownStream() { - if (isShutdown || (clientClosed && !hasPendingRequests())) { + private synchronized boolean maybeTearDownStream(PhysicalStreamHandler doneStream) { + if (clientClosed && !doneStream.hasPendingRequests()) { + shutdown(); + } + + if (isShutdown) { + // Once we have background closing physicalStreams we will need to improve this to wait for + // all of the work of the logical stream to be complete. streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); @@ -369,23 +442,49 @@ private synchronized boolean maybeTearDownStream() { } private class ResponseObserver implements StreamObserver { + private final PhysicalStreamHandler handler; + + ResponseObserver(PhysicalStreamHandler handler) { + this.handler = handler; + } @Override public void onNext(ResponseT response) { backoff.reset(); debugMetrics.recordResponse(); - onResponse(response); + handler.streamDebugMetrics.recordResponse(); + handler.onResponse(response); } @Override public void onError(Throwable t) { - if (maybeTearDownStream()) { - return; - } + executeSafely(() -> onPhysicalStreamCompletion(Status.fromThrowable(t), handler)); + } - Status errorStatus = Status.fromThrowable(t); - recordStreamStatus(errorStatus); + @Override + public void onCompleted() { + executeSafely(() -> onPhysicalStreamCompletion(OK_STATUS, handler)); + } + } + + @SuppressWarnings("nullness") + private void clearPhysicalStreamForDebug() { + currentPhysicalStreamForDebug.set(null); + } + private void onPhysicalStreamCompletion(Status status, PhysicalStreamHandler handler) { + synchronized (this) { + if (currentPhysicalStream == handler) { + clearPhysicalStreamForDebug(); + currentPhysicalStream = null; + } + } + handler.onDone(status); + if (maybeTearDownStream(handler)) { + return; + } + // Backoff on errors.; + if (!status.isOk()) { try { long sleep = backoff.nextBackOffMillis(); debugMetrics.recordSleep(sleep); @@ -394,54 +493,43 @@ public void onError(Throwable t) { Thread.currentThread().interrupt(); return; } - - executeSafely(AbstractWindmillStream.this::startStream); - } - - @Override - public void onCompleted() { - if (maybeTearDownStream()) { - return; - } - recordStreamStatus(OK_STATUS); - executeSafely(AbstractWindmillStream.this::startStream); } + recordStreamRestart(status); + startStream(); + } - private void recordStreamStatus(Status status) { - int currentRestartCount = debugMetrics.incrementAndGetRestarts(); - if (status.isOk()) { - String restartReason = - "Stream completed successfully but did not complete requested operations, " - + "recreating"; - logger.warn(restartReason); - debugMetrics.recordRestartReason(restartReason); - } else { - int currentErrorCount = debugMetrics.incrementAndGetErrors(); - debugMetrics.recordRestartReason(status.toString()); - Throwable t = status.getCause(); - if (t instanceof StreamObserverCancelledException) { - logger.error( - "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}", - getClass(), - backendWorkerToken, - t.getStackTrace(), - t); - } else if (currentRestartCount % logEveryNStreamFailures == 0) { - // Don't log every restart since it will get noisy, and many errors transient. - long nowMillis = Instant.now().getMillis(); - logger.debug( - "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" - + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", - AbstractWindmillStream.this.getClass(), - currentRestartCount, - currentErrorCount, - t, - status, - nowMillis - debugMetrics.getStartTimeMs(), - debugMetrics - .responseDebugString(nowMillis) - .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); - } + private void recordStreamRestart(Status status) { + int currentRestartCount = debugMetrics.incrementAndGetRestarts(); + if (status.isOk()) { + String restartReason = + "Stream completed successfully but did not complete requested operations, " + + "recreating"; + logger.warn(restartReason); + debugMetrics.recordRestartReason(restartReason); + } else { + int currentErrorCount = debugMetrics.incrementAndGetErrors(); + debugMetrics.recordRestartReason(status.toString()); + Throwable t = status.getCause(); + if (t instanceof StreamObserverCancelledException) { + logger.error( + "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}", + getClass(), + backendWorkerToken, + t.getStackTrace(), + t); + } else if (currentRestartCount % logEveryNStreamFailures == 0) { + // Don't log every restart since it will get noisy, and many errors transient. + long nowMillis = Instant.now().getMillis(); + logger.debug( + "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", + AbstractWindmillStream.this.getClass(), + currentRestartCount, + currentErrorCount, + t, + status, + nowMillis - debugMetrics.getStartTimeMs(), + debugMetrics.responseDebugString(nowMillis).orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java index 35d6fb5ea100..11690ecf279f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java @@ -46,13 +46,15 @@ final class AppendableInputStream extends InputStream { private final AtomicLong blockedStartMs; private final BlockingDeque queue; private final InputStream stream; + private final long deadlineSeconds; - AppendableInputStream() { + AppendableInputStream(long deadlineSeconds) { this.cancelled = new AtomicBoolean(false); this.complete = new AtomicBoolean(false); this.blockedStartMs = new AtomicLong(); this.queue = new LinkedBlockingDeque<>(QUEUE_MAX_CAPACITY); this.stream = new SequenceInputStream(new InputStreamEnumeration()); + this.deadlineSeconds = deadlineSeconds; } long getBlockedStartMs() { @@ -71,6 +73,10 @@ int size() { return queue.size(); } + long getDeadlineSeconds() { + return deadlineSeconds; + } + /** Appends a new InputStream to the tail of this stream. */ synchronized void append(InputStream chunk) { try { @@ -155,7 +161,7 @@ public boolean hasMoreElements() { try { blockedStartMs.set(Instant.now().getMillis()); - current = queue.poll(180, TimeUnit.SECONDS); + current = queue.poll(deadlineSeconds, TimeUnit.SECONDS); if (current != null && current != POISON_PILL) { return true; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 5d6f965e2d22..7a7b1a5cd27e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import java.io.PrintWriter; @@ -44,6 +45,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; @@ -57,7 +59,18 @@ final class GrpcCommitWorkStream private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; - private final ConcurrentMap pending; + private static class StreamAndRequest { + StreamAndRequest(@Nullable CommitWorkPhysicalStreamHandler handler, PendingRequest request) { + this.handler = handler; + this.request = request; + } + + final @Nullable CommitWorkPhysicalStreamHandler handler; + final PendingRequest request; + } + + private final ConcurrentMap pending = new ConcurrentHashMap<>(); + private final AtomicLong idGenerator; private final JobHeader jobHeader; private final int streamingRpcBatchLimit; @@ -82,7 +95,6 @@ private GrpcCommitWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - pending = new ConcurrentHashMap<>(); this.idGenerator = idGenerator; this.jobHeader = jobHeader; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -119,12 +131,20 @@ public void appendSpecificHtml(PrintWriter writer) { @Override protected synchronized void onNewStream() throws WindmillStreamShutdownException { trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + // Flush all pending requests that are no longer on active streams. try (Batcher resendBatcher = new Batcher()) { - for (Map.Entry entry : pending.entrySet()) { - if (!resendBatcher.canAccept(entry.getValue().getBytes())) { + for (Map.Entry entry : pending.entrySet()) { + CommitWorkPhysicalStreamHandler requestHandler = entry.getValue().handler; + checkState(requestHandler != currentPhysicalStream); + // When we have streams closing in the background we should avoid retrying the requests + // active on those streams. + + long id = entry.getKey(); + PendingRequest request = entry.getValue().request; + if (!resendBatcher.canAccept(request.getBytes())) { resendBatcher.flush(); } - resendBatcher.add(entry.getKey(), entry.getValue()); + resendBatcher.add(id, request); } } } @@ -139,42 +159,38 @@ public CommitWorkStream.RequestBatcher batcher() { } @Override - protected boolean hasPendingRequests() { - return !pending.isEmpty(); - } - - @Override - protected void sendHealthCheck() throws WindmillStreamShutdownException { - if (hasPendingRequests()) { + protected synchronized void sendHealthCheck() throws WindmillStreamShutdownException { + if (currentPhysicalStream != null && currentPhysicalStream.hasPendingRequests()) { StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); trySend(builder.build()); } } - @Override - protected void onResponse(StreamingCommitResponse response) { - CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler(); - for (int i = 0; i < response.getRequestIdCount(); ++i) { - long requestId = response.getRequestId(i); - if (requestId == HEARTBEAT_REQUEST_ID) { - continue; - } + private class CommitWorkPhysicalStreamHandler extends PhysicalStreamHandler { + @Override + public void onResponse(StreamingCommitResponse response) { + CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler(); + for (int i = 0; i < response.getRequestIdCount(); ++i) { + long requestId = response.getRequestId(i); + if (requestId == HEARTBEAT_REQUEST_ID) { + continue; + } + + // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may + // be omitted. + CommitStatus commitStatus = + i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; - // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may - // be omitted. - CommitStatus commitStatus = - i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; - - @Nullable PendingRequest pendingRequest = pending.remove(requestId); - if (pendingRequest == null) { - synchronized (this) { - if (!isShutdown) { - // Missing responses are expected after shutdown() because it removes them. - LOG.error("Got unknown commit request ID: {}", requestId); - } + @Nullable StreamAndRequest entry = pending.remove(requestId); + if (entry == null) { + LOG.error("Got unknown commit request ID: {}", requestId); + continue; } - } else { + if (entry.handler != this) { + LOG.error("Got commit request id {} on unexpected stream", requestId); + } + PendingRequest pendingRequest = entry.request; try { pendingRequest.completeWithStatus(commitStatus); } catch (RuntimeException e) { @@ -185,16 +201,40 @@ protected void onResponse(StreamingCommitResponse response) { failureHandler.addError(commitStatus, e); } } + + failureHandler.throwIfNonEmpty(); + } + + @Override + public boolean hasPendingRequests() { + return pending.entrySet().stream().anyMatch(e -> e.getValue().handler == this); } - failureHandler.throwIfNonEmpty(); + @Override + public void onDone(Status status) { + if (status.isOk() && hasPendingRequests()) { + LOG.warn("Unexpected requests without responses on drained physical stream, retrying."); + } + } + + @Override + public void appendHtml(PrintWriter writer) { + writer.format( + "CommitWorkStream: %d pending", + pending.entrySet().stream().filter(e -> e.getValue().handler == this).count()); + } + } + + @Override + protected PhysicalStreamHandler newResponseHandler() { + return new CommitWorkPhysicalStreamHandler(); } @Override - protected void shutdownInternal() { - Iterator pendingRequests = pending.values().iterator(); + protected synchronized void shutdownInternal() { + Iterator pendingRequests = pending.values().iterator(); while (pendingRequests.hasNext()) { - PendingRequest pendingRequest = pendingRequests.next(); + PendingRequest pendingRequest = pendingRequests.next().request; pendingRequest.abort(); pendingRequests.remove(); } @@ -230,10 +270,14 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { - if (!prepareForSend(id, pendingRequest)) { + if (isShutdown) { pendingRequest.abort(); return; } + pending.put( + id, + new StreamAndRequest( + (CommitWorkPhysicalStreamHandler) currentPhysicalStream, pendingRequest)); trySend(chunk); } } @@ -256,10 +300,16 @@ private void issueBatchedRequest(Map requests) } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { - if (!prepareForSend(requests)) { + if (isShutdown) { requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); return; } + for (Map.Entry entry : requests.entrySet()) { + pending.put( + entry.getKey(), + new StreamAndRequest( + (CommitWorkPhysicalStreamHandler) currentPhysicalStream, entry.getValue())); + } trySend(request); } } @@ -269,11 +319,15 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { - if (!prepareForSend(id, pendingRequest)) { + if (isShutdown) { pendingRequest.abort(); return; } + pending.put( + id, + new StreamAndRequest( + (CommitWorkPhysicalStreamHandler) currentPhysicalStream, pendingRequest)); for (int i = 0; i < serializedCommit.size(); i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { @@ -300,24 +354,6 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) } } - /** Returns true if prepare for send succeeded. */ - private synchronized boolean prepareForSend(long id, PendingRequest request) { - if (!isShutdown) { - pending.put(id, request); - return true; - } - return false; - } - - /** Returns true if prepare for send succeeded. */ - private synchronized boolean prepareForSend(Map requests) { - if (!isShutdown) { - pending.putAll(requests); - return true; - } - return false; - } - @AutoValue abstract static class PendingRequest { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index f7960b7f68dc..938ec1c693c7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -46,6 +46,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,15 +81,6 @@ final class GrpcDirectGetWorkStream private final GetDataClient getDataClient; private final AtomicReference lastRequest; - /** - * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they - * come in. Once all chunks for a response has been received, the chunk is processed and the - * buffer is cleared. - * - * @implNote Buffers are not persisted across stream restarts. - */ - private final ConcurrentMap workItemAssemblers; - private final boolean requestBatchedGetWorkResponse; private GrpcDirectGetWorkStream( @@ -118,7 +110,6 @@ private GrpcDirectGetWorkStream( backendWorkerToken); this.requestHeader = requestHeader; this.workItemScheduler = workItemScheduler; - this.workItemAssemblers = new ConcurrentHashMap<>(); this.heartbeatSender = heartbeatSender; this.workCommitter = workCommitter; this.getDataClient = getDataClient; @@ -199,9 +190,47 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { } } + private class DirectGetWorkPhysicalStreamHandler extends PhysicalStreamHandler { + /** + * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they + * come in. Once all chunks for a response has been received, the chunk is processed and the + * buffer is cleared. + * + * @implNote Buffers are not persisted across stream restarts. + */ + final ConcurrentMap workItemAssemblers = + new ConcurrentHashMap<>(); + + @Override + public void onResponse(StreamingGetWorkResponseChunk response) { + workItemAssemblers + .computeIfAbsent(response.getStreamId(), unused -> new GetWorkResponseChunkAssembler()) + .append(response) + .forEach(GrpcDirectGetWorkStream.this::consumeAssembledWorkItem); + } + + @Override + public boolean hasPendingRequests() { + return false; + } + + @Override + public void onDone(Status status) {} + + @Override + public void appendHtml(PrintWriter writer) { + // Number of buffers is same as distinct workers that sent work on this stream. + writer.format("%d buffers", workItemAssemblers.size()); + } + } + + @Override + protected PhysicalStreamHandler newResponseHandler() { + return new DirectGetWorkPhysicalStreamHandler(); + } + @Override protected synchronized void onNewStream() throws WindmillStreamShutdownException { - workItemAssemblers.clear(); budgetTracker.reset(); GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); StreamingGetWorkRequest request = @@ -219,17 +248,9 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException trySend(request); } - @Override - protected boolean hasPendingRequests() { - return false; - } - @Override public void appendSpecificHtml(PrintWriter writer) { - // Number of buffers is same as distinct workers that sent work on this stream. - writer.format( - "GetWorkStream: %d buffers, " + "last sent request: %s; ", - workItemAssemblers.size(), lastRequest.get()); + writer.format("GetWorkStream: last sent request: %s; ", lastRequest.get()); writer.print(budgetTracker.debugString()); } @@ -238,17 +259,6 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException { trySend(HEALTH_CHECK_REQUEST); } - @Override - protected void shutdownInternal() {} - - @Override - protected void onResponse(StreamingGetWorkResponseChunk chunk) { - workItemAssemblers - .computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler()) - .append(chunk) - .forEach(this::consumeAssembledWorkItem); - } - private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 2ae58a60e7b1..7de074122a3c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -18,7 +18,9 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verifyNotNull; import java.io.IOException; import java.io.InputStream; @@ -34,7 +36,9 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; @@ -55,7 +59,12 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.BackOffUtils; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; +import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,10 +77,21 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} */ + static final FluentBackoff BACK_OFF_FACTORY = + FluentBackoff.DEFAULT + .withInitialBackoff(Duration.millis(10)) + .withMaxBackoff(Duration.standardSeconds(10)); + + /** + * @implNote {@link QueuedBatch} objects in the queue should also be guarded by {@code this}. + * Batches should be sent from the front of the queue and only removed from the queue once + * added to the pending set of a physical stream. + */ + @GuardedBy("this") private final Deque batches; - private final Map pending; + private final Supplier batchesDebugSizeSupplier; + private final AtomicLong idGenerator; private final JobHeader jobHeader; private final int streamingRpcBatchLimit; @@ -105,8 +125,11 @@ private GrpcGetDataStream( this.idGenerator = idGenerator; this.jobHeader = jobHeader; this.streamingRpcBatchLimit = streamingRpcBatchLimit; - this.batches = new ConcurrentLinkedDeque<>(); - this.pending = new ConcurrentHashMap<>(); + // A concurrent deque is used so that we can observe the size without synchronization on "this". + // Otherwise the deque is accessed via batches which has a guardedby annotation. + ConcurrentLinkedDeque batches = new ConcurrentLinkedDeque<>(); + this.batches = batches; + this.batchesDebugSizeSupplier = batches::size; this.sendKeyedGetDataRequests = sendKeyedGetDataRequests; this.processHeartbeatResponses = processHeartbeatResponses; } @@ -153,45 +176,109 @@ private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest) trySend(getDataRequest); } - @Override - protected synchronized void onNewStream() throws WindmillStreamShutdownException { - trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed) { - // We rely on close only occurring after all methods on the stream have returned. - // Since the requestKeyedData and requestGlobalData methods are blocking this - // means there should be no pending requests. - verify(!hasPendingRequests(), "Pending requests not expected if we've half-closed."); - } else { + class GetDataPhysicalStreamHandler extends PhysicalStreamHandler { + private final ConcurrentHashMap pending = + new ConcurrentHashMap<>(); + + public void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { + // Synchronization of pending inserts is necessary with send to ensure duplicates are not + // sent on stream reconnect. + for (QueuedRequest request : batch.requestsReadOnly()) { + boolean alreadyPresent = pending.put(request.id(), request.getResponseStream()) != null; + verify(!alreadyPresent, "Request already sent, id: %s", request.id()); + } + + if (!trySend(batch.asGetDataRequest())) { + // The stream broke before this call went through; onNewStream will retry the fetch. + LOG.debug("GetData stream broke before call started."); + } + } + + @Override + public void onResponse(StreamingGetDataResponse chunk) { + checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); + checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); + onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); + + for (int i = 0; i < chunk.getRequestIdCount(); ++i) { + long requestId = chunk.getRequestId(i); + boolean completeResponse = chunk.getRemainingBytesForResponse() == 0; + AppendableInputStream responseStream = + verifyNotNull( + completeResponse ? pending.remove(requestId) : pending.get(requestId), + "No pending response stream"); + responseStream.append(chunk.getSerializedResponse(i).newInput()); + if (completeResponse) { + responseStream.complete(); + } + } + } + + @Override + public boolean hasPendingRequests() { + return !pending.isEmpty(); + } + + @Override + public void onDone(Status status) { + if (status.isOk() && hasPendingRequests()) { + LOG.warn("Pending requests not expected on successful GetData stream flushing."); + } for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); } + pending.clear(); + } + + @Override + public void appendHtml(PrintWriter writer) { + writer.format("%d pending requests [", pending.size()); + for (Map.Entry entry : pending.entrySet()) { + writer.format("Stream %d ", entry.getKey()); + if (entry.getValue().isCancelled()) { + writer.append("cancelled "); + } + if (entry.getValue().isComplete()) { + writer.append("complete "); + } + int queueSize = entry.getValue().size(); + if (queueSize > 0) { + writer.format("%d queued responses ", queueSize); + } + long blockedMs = entry.getValue().getBlockedStartMs(); + if (blockedMs > 0) { + writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs); + } + } + writer.append("]"); } } @Override - protected boolean hasPendingRequests() { - return !pending.isEmpty() || !batches.isEmpty(); + protected PhysicalStreamHandler newResponseHandler() { + return new GetDataPhysicalStreamHandler(); } @Override - protected void onResponse(StreamingGetDataResponse chunk) { - checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); - checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); - onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); - - for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - if (responseStream == null) { - synchronized (this) { - // shutdown()/shutdownInternal() cleans up pending, else we expect a pending - // responseStream for every response. - verify(isShutdown, "No pending response stream"); - } - continue; - } - responseStream.append(chunk.getSerializedResponse(i).newInput()); - if (chunk.getRemainingBytesForResponse() == 0) { - responseStream.complete(); + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + while (!batches.isEmpty()) { + QueuedBatch batch = checkNotNull(batches.peekFirst()); + verify(!batch.isEmpty()); + if (!batch.isFinalized()) break; + try { + verify( + batch == batches.pollFirst(), + "Sent GetDataStream request batch removed before send() was complete."); + checkNotNull((GetDataPhysicalStreamHandler) currentPhysicalStream).sendBatch(batch); + // Notify all waiters with requests in this batch as well as the sender + // of the next batch (if one exists). + batch.notifySent(); + } catch (Exception e) { + LOG.debug("Batch failed to send on new stream", e); + // Free waiters if the send() failed. + batch.notifyFailed(); + throw e; } } } @@ -204,14 +291,17 @@ private long uniqueId() { public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) throws WindmillStreamShutdownException { return issueRequest( - QueuedRequest.forComputation(uniqueId(), computation, request), + QueuedRequest.forComputation( + uniqueId(), computation, request, physicalStreamDeadlineSeconds), KeyedGetDataResponse::parseFrom); } @Override public GlobalData requestGlobalData(GlobalDataRequest request) throws WindmillStreamShutdownException { - return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); + return issueRequest( + QueuedRequest.global(uniqueId(), request, physicalStreamDeadlineSeconds), + GlobalData::parseFrom); } @Override @@ -284,8 +374,8 @@ public void onHeartbeatResponse(List resp } @Override - protected void sendHealthCheck() throws WindmillStreamShutdownException { - if (hasPendingRequests()) { + protected synchronized void sendHealthCheck() throws WindmillStreamShutdownException { + if (currentPhysicalStream != null && currentPhysicalStream.hasPendingRequests()) { trySend(HEALTH_CHECK_REQUEST); } } @@ -294,8 +384,14 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException { protected synchronized void shutdownInternal() { // Stream has been explicitly closed. Drain pending input streams and request batches. // Future calls to send RPCs will fail. - pending.values().forEach(AppendableInputStream::cancel); - pending.clear(); + final @Nullable GetDataPhysicalStreamHandler currentGetDataStream = + (GetDataPhysicalStreamHandler) currentPhysicalStream; + if (currentGetDataStream != null) { + for (AppendableInputStream ais : currentGetDataStream.pending.values()) { + ais.cancel(); + } + currentGetDataStream.pending.clear(); + } batches.forEach( batch -> { batch.markFinalized(); @@ -306,30 +402,12 @@ protected synchronized void shutdownInternal() { @Override public void appendSpecificHtml(PrintWriter writer) { - writer.format( - "GetDataStream: %d queued batches, %d pending requests [", batches.size(), pending.size()); - for (Map.Entry entry : pending.entrySet()) { - writer.format("Stream %d ", entry.getKey()); - if (entry.getValue().isCancelled()) { - writer.append("cancelled "); - } - if (entry.getValue().isComplete()) { - writer.append("complete "); - } - int queueSize = entry.getValue().size(); - if (queueSize > 0) { - writer.format("%d queued responses ", queueSize); - } - long blockedMs = entry.getValue().getBlockedStartMs(); - if (blockedMs > 0) { - writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs); - } - } - writer.append("]"); + writer.format("GetDataStream: %d queued batches", batchesDebugSizeSupplier.get()); } private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) throws WindmillStreamShutdownException { + final BackOff backoff = BACK_OFF_FACTORY.backoff(); while (true) { request.resetResponseStream(); try { @@ -342,12 +420,15 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn { ResponseT parse(InputStream input) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index ef7f5b20bb07..318738893f0d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -51,25 +51,33 @@ static class QueuedRequest { private final ComputationOrGlobalDataRequest dataRequest; private AppendableInputStream responseStream; - private QueuedRequest(long id, ComputationOrGlobalDataRequest dataRequest) { + private QueuedRequest( + long id, ComputationOrGlobalDataRequest dataRequest, long deadlineSeconds) { this.id = id; this.dataRequest = dataRequest; - responseStream = new AppendableInputStream(); + responseStream = new AppendableInputStream(deadlineSeconds); } static QueuedRequest forComputation( - long id, String computation, KeyedGetDataRequest keyedGetDataRequest) { + long id, + String computation, + KeyedGetDataRequest keyedGetDataRequest, + long deadlineSeconds) { ComputationGetDataRequest computationGetDataRequest = ComputationGetDataRequest.newBuilder() .setComputationId(computation) .addRequests(keyedGetDataRequest) .build(); return new QueuedRequest( - id, ComputationOrGlobalDataRequest.computation(computationGetDataRequest)); + id, + ComputationOrGlobalDataRequest.computation(computationGetDataRequest), + deadlineSeconds); } - static QueuedRequest global(long id, GlobalDataRequest globalDataRequest) { - return new QueuedRequest(id, ComputationOrGlobalDataRequest.global(globalDataRequest)); + static QueuedRequest global( + long id, GlobalDataRequest globalDataRequest, long deadlineSeconds) { + return new QueuedRequest( + id, ComputationOrGlobalDataRequest.global(globalDataRequest), deadlineSeconds); } static Comparator globalRequestsFirst() { @@ -93,7 +101,7 @@ AppendableInputStream getResponseStream() { } void resetResponseStream() { - this.responseStream = new AppendableInputStream(); + this.responseStream = new AppendableInputStream(responseStream.getDeadlineSeconds()); } public ComputationOrGlobalDataRequest getDataRequest() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 63342b1f6124..a1c758eac446 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import java.io.PrintWriter; -import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; @@ -35,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,7 +52,6 @@ final class GrpcGetWorkStream private final GetWorkRequest request; private final WorkItemReceiver receiver; - private final Map workItemAssemblers; private final AtomicLong inflightMessages; private final AtomicLong inflightBytes; private final boolean requestBatchedGetWorkResponse; @@ -81,7 +80,6 @@ private GrpcGetWorkStream( backendWorkerToken); this.request = request; this.receiver = receiver; - this.workItemAssemblers = new ConcurrentHashMap<>(); this.inflightMessages = new AtomicLong(); this.inflightBytes = new AtomicLong(); this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse; @@ -131,9 +129,41 @@ private void sendRequestExtension(long moreItems, long moreBytes) { }); } + private class GetWorkPhysicalStreamHandler extends PhysicalStreamHandler { + + private final ConcurrentHashMap workItemAssemblers = + new ConcurrentHashMap<>(); + + @Override + public void onResponse(StreamingGetWorkResponseChunk response) { + workItemAssemblers + .computeIfAbsent(response.getStreamId(), unused -> new GetWorkResponseChunkAssembler()) + .append(response) + .forEach(GrpcGetWorkStream.this::consumeAssembledWorkItem); + } + + @Override + public boolean hasPendingRequests() { + return false; + } + + @Override + public void onDone(Status status) {} + + @Override + public void appendHtml(PrintWriter writer) { + // Number of buffers is same as distinct workers that sent work on this stream. + writer.format("%d buffers", workItemAssemblers.size()); + } + } + + @Override + protected PhysicalStreamHandler newResponseHandler() { + return new GetWorkPhysicalStreamHandler(); + } + @Override protected synchronized void onNewStream() throws WindmillStreamShutdownException { - workItemAssemblers.clear(); inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); trySend( @@ -143,20 +173,11 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException .build()); } - @Override - protected void shutdownInternal() {} - - @Override - protected boolean hasPendingRequests() { - return false; - } - @Override public void appendSpecificHtml(PrintWriter writer) { - // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed", - workItemAssemblers.size(), inflightMessages.intValue(), inflightBytes.intValue()); + "GetWorkStream: %d inflight messages allowed, %d inflight bytes allowed", + inflightMessages.intValue(), inflightBytes.intValue()); } @Override @@ -164,14 +185,6 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException { trySend(HEALTH_CHECK); } - @Override - protected void onResponse(StreamingGetWorkResponseChunk chunk) { - workItemAssemblers - .computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler()) - .append(chunk) - .forEach(this::consumeAssembledWorkItem); - } - private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { receiver.receiveWork( assembledWorkItem.computationMetadata().computationId(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index cb30a3e6a500..9b99b3bda909 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -32,6 +32,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,15 +93,6 @@ public static GrpcGetWorkerMetadataStream create( serverMappingUpdater); } - /** - * Each instance of {@link AbstractWindmillStream} owns its own responseObserver that calls - * onResponse(). - */ - @Override - protected void onResponse(WorkerMetadataResponse response) { - extractWindmillEndpointsFrom(response).ifPresent(serverMappingConsumer); - } - /** * Acquires the {@link #metadataLock} Returns {@link Optional} if the * metadataVersion in the response is not stale (older or equal to current {@link @@ -127,16 +119,30 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected void onNewStream() throws WindmillStreamShutdownException { - trySend(workerMetadataRequest); - } + protected PhysicalStreamHandler newResponseHandler() { + return new PhysicalStreamHandler() { - @Override - protected void shutdownInternal() {} + @Override + public void onResponse(WorkerMetadataResponse response) { + extractWindmillEndpointsFrom(response).ifPresent(serverMappingConsumer); + } + + @Override + public boolean hasPendingRequests() { + return false; + } + + @Override + public void onDone(Status status) {} + + @Override + public void appendHtml(PrintWriter writer) {} + }; + } @Override - protected boolean hasPendingRequests() { - return false; + protected void onNewStream() throws WindmillStreamShutdownException { + trySend(workerMetadataRequest); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java index b4f3e854bf19..e97416c004fa 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java @@ -37,6 +37,8 @@ public abstract TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver responseObserver); + public abstract long getDeadlineSeconds(); + private static class Direct extends StreamObserverFactory { private final long deadlineSeconds; private final int messagesBetweenIsReadyChecks; @@ -59,5 +61,10 @@ public TerminatingStreamObserver from( return new DirectStreamObserver<>( phaser, outboundObserver, deadlineSeconds, messagesBetweenIsReadyChecks); } + + @Override + public long getDeadlineSeconds() { + return deadlineSeconds; + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index bead1ffd5b24..c1696d8a70ab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -4190,7 +4190,9 @@ public void processElement(ProcessContext c) { static class TestExceptionFn extends DoFn { - boolean firstTime = true; + // Note that the use of static works because this DoFn is only used in a single test. We need + // to use static as the DoFn is not cached after user-code exceptions. + static boolean firstTime = true; @ProcessElement public void processElement(ProcessContext c) throws Exception { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index d9a1397e1c7c..92c081591c73 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -32,6 +32,7 @@ import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; @@ -157,16 +158,28 @@ private TestStream( } @Override - protected void onResponse(Integer response) {} + protected PhysicalStreamHandler newResponseHandler() { + return new PhysicalStreamHandler() { - @Override - protected void onNewStream() { - numStarts.incrementAndGet(); + @Override + public void onResponse(Integer response) {} + + @Override + public boolean hasPendingRequests() { + return false; + } + + @Override + public void onDone(Status status) {} + + @Override + public void appendHtml(PrintWriter writer) {} + }; } @Override - protected boolean hasPendingRequests() { - return false; + protected void onNewStream() { + numStarts.incrementAndGet(); } private void testSend() throws WindmillStreamShutdownException { @@ -192,9 +205,6 @@ private void waitForHealthChecks(int expectedHealthChecks) { @Override protected void appendSpecificHtml(PrintWriter writer) {} - - @Override - protected void shutdownInternal() {} } private static class TestCallStreamObserver extends CallStreamObserver { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index fdd213223987..8cb831f15f62 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -43,7 +43,11 @@ public class WindmillStreamPoolTest { private final ConcurrentHashMap< TestWindmillStream, WindmillStreamPool.StreamData> holds = new ConcurrentHashMap<>(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + + @Rule + public transient Timeout globalTimeout = + Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build(); + private List> streams; @Before diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java new file mode 100644 index 000000000000..85c3c71663f1 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; +import org.hamcrest.Matchers; +import org.junit.rules.ErrorCollector; + +class FakeWindmillGrpcService + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final ErrorCollector errorCollector; + + @GuardedBy("this") + private boolean failOnNewStreams = false; + + public FakeWindmillGrpcService(ErrorCollector errorCollector) { + this.errorCollector = errorCollector; + } + + public static class StreamInfo { + public StreamInfo(StreamObserver responseObserver) { + this.responseObserver = responseObserver; + } + + public final StreamObserver responseObserver; + public final BlockingQueue requests = new LinkedBlockingQueue<>(1000); + public final CompletableFuture onDone = new CompletableFuture<>(); + }; + + private final BlockingQueue commitStreams = new LinkedBlockingQueue<>(1000); + private final BlockingQueue getDataStreams = new LinkedBlockingQueue<>(1000); + + private static class StreamInfoObserver implements StreamObserver { + private final StreamInfo streamInfo; + private final ErrorCollector errorCollector; + + public StreamInfoObserver( + StreamInfo streamInfo, ErrorCollector errorCollector) { + this.streamInfo = streamInfo; + this.errorCollector = errorCollector; + } + + @Override + public void onNext(RequestT request) { + errorCollector.checkThat(streamInfo.requests.add(request), Matchers.is(true)); + } + + @Override + public void onError(Throwable throwable) { + streamInfo.onDone.complete(throwable); + } + + @Override + public void onCompleted() { + streamInfo.onDone.complete(null); + } + } + + public static class CommitStreamInfo + extends StreamInfo { + CommitStreamInfo(StreamObserver responseObserver) { + super(responseObserver); + } + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + CommitStreamInfo info = new CommitStreamInfo(responseObserver); + synchronized (this) { + errorCollector.checkThat(failOnNewStreams, Matchers.is(false)); + errorCollector.checkThat(commitStreams.offer(info), Matchers.is(true)); + } + return new StreamInfoObserver<>(info, errorCollector); + } + + public CommitStreamInfo waitForConnectedCommitStream() throws InterruptedException { + return commitStreams.take(); + } + + public synchronized void expectNoMoreStreams() { + failOnNewStreams = true; + errorCollector.checkThat(commitStreams.isEmpty(), Matchers.is(true)); + errorCollector.checkThat(getDataStreams.isEmpty(), Matchers.is(true)); + } + + public static class GetDataStreamInfo + extends StreamInfo { + GetDataStreamInfo(StreamObserver responseObserver) { + super(responseObserver); + } + } + + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + GetDataStreamInfo info = new GetDataStreamInfo(responseObserver); + synchronized (this) { + errorCollector.checkThat(failOnNewStreams, Matchers.is(false)); + errorCollector.checkThat(getDataStreams.offer(info), Matchers.is(true)); + } + return new StreamInfoObserver<>(info, errorCollector); + } + + public GetDataStreamInfo waitForConnectedGetDataStream() throws InterruptedException { + return getDataStreams.take(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 31a851a702ee..7619668220d4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -18,18 +18,20 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.spy; import java.io.IOException; import java.util.HashSet; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; @@ -39,21 +41,20 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.ServerCallStreamObserver; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ErrorCollector; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.InOrder; @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() @@ -63,10 +64,20 @@ public class GrpcCommitWorkStreamTest { .build(); private static final String COMPUTATION_ID = "computationId"; + @SuppressWarnings("InlineMeInliner") // inline `Strings.repeat()` - Java 11+ API only + private static final ByteString LARGE_BYTE_STRING = + ByteString.copyFromUtf8(Strings.repeat("a", 2 * 1024 * 1024)); + + @Rule public final ErrorCollector errorCollector = new ErrorCollector(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + + @Rule + public transient Timeout globalTimeout = + Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build(); + + private final FakeWindmillGrpcService fakeService = new FakeWindmillGrpcService(errorCollector); private ManagedChannel inProcessChannel; + private Server inProcessServer; private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) { return Windmill.WorkItemCommitRequest.newBuilder() @@ -79,27 +90,25 @@ private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) @Before public void setUp() throws IOException { - Server server = - InProcessServerBuilder.forName(FAKE_SERVER_NAME) - .fallbackHandlerRegistry(serviceRegistry) - .directExecutor() - .build() - .start(); - + inProcessServer = + grpcCleanup.register( + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .addService(fakeService) + .directExecutor() + .build() + .start()); inProcessChannel = grpcCleanup.register( InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); - grpcCleanup.register(server); - grpcCleanup.register(inProcessChannel); } @After public void cleanUp() { + inProcessServer.shutdownNow(); inProcessChannel.shutdownNow(); } - private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) { - serviceRegistry.addService(testStub); + private GrpcCommitWorkStream createCommitWorkStream() { GrpcCommitWorkStream commitWorkStream = (GrpcCommitWorkStream) GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) @@ -110,16 +119,12 @@ private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub tes } @Test - public void testShutdown_abortsQueuedCommits() throws InterruptedException { + public void testShutdown_abortsActiveCommits() throws InterruptedException, ExecutionException { int numCommits = 5; CountDownLatch commitProcessed = new CountDownLatch(numCommits); Set onDone = new HashSet<>(); - TestCommitWorkStreamRequestObserver requestObserver = - spy(new TestCommitWorkStreamRequestObserver()); - CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); - GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); - InOrder requestObserverVerifier = inOrder(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { batcher.commitWorkItem( @@ -133,120 +138,298 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { } catch (StreamObserverCancelledException ignored) { } - // Verify that we sent the commits above in a request + the initial header. - requestObserverVerifier - .verify(requestObserver) - .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); - requestObserverVerifier - .verify(requestObserver) - .onNext(argThat(request -> !request.getCommitChunkList().isEmpty())); - requestObserverVerifier.verifyNoMoreInteractions(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + // The next request should have some chunks. + assertThat(streamInfo.requests.take().getCommitChunkList()).isNotEmpty(); // We won't get responses so we will have some pending requests. - assertTrue(commitWorkStream.hasPendingRequests()); + assertThat(commitProcessed.getCount()).isGreaterThan(0); commitWorkStream.shutdown(); + streamInfo.onDone.get(); + commitProcessed.await(); assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED); } @Test - public void testCommitWorkItem_afterShutdown() { + public void testCommitWorkItem_abortsCommitsSentAfterShutdown() + throws InterruptedException, ExecutionException { int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); - CommitWorkStreamTestStub testStub = - new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver()); - GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + commitWorkStream.shutdown(); + assertNotNull(streamInfo.onDone.get()); + AtomicBoolean allAborted = new AtomicBoolean(true); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { - assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + (status) -> { + if (status != Windmill.CommitStatus.ABORTED) { + allAborted.set(false); + } + commitProcessed.countDown(); + })); } } - commitWorkStream.shutdown(); + commitProcessed.await(); + assertTrue(allAborted.get()); + } + + @Test + public void testCommitWorkItem_retryOnNewStream() throws Exception { + int numCommits = 5; + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); - AtomicReference commitStatus = new AtomicReference<>(); + final AtomicBoolean allOk = new AtomicBoolean(true); + final CountDownLatch firstResponsesDone = new CountDownLatch(2); + final CountDownLatch secondResponsesDone = new CountDownLatch(3); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { + int finalI = i; assertTrue( - batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set)); + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + (status) -> { + if (status != Windmill.CommitStatus.OK) { + allOk.set(false); + } + if (finalI == 0 || finalI == 4) { + firstResponsesDone.countDown(); + } else { + secondResponsesDone.countDown(); + } + })); } } + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertEquals(5, request.getCommitChunkCount()); + for (int i = 0; i < 5; ++i) { + assertEquals(i + 1, request.getCommitChunk(i).getRequestId()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(i).getSerializedWorkItemCommit()); + assertEquals(parsedRequest.getWorkToken(), i); + } + // Send back that 1 and 5 finished. + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).addRequestId(5).build()); + firstResponsesDone.await(); + + // Simulate that the server breaks. + streamInfo.responseObserver.onError(new IOException("test error")); + + // The stream should reconnect and retry the requests. + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take(); + assertEquals(3, reconnectRequest.getCommitChunkCount()); + for (int i = 0; i < 3; ++i) { + assertEquals(i + 2, reconnectRequest.getCommitChunk(i).getRequestId()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + reconnectRequest.getCommitChunk(i).getSerializedWorkItemCommit()); + assertEquals(i + 1, parsedRequest.getWorkToken()); + } + // Send back that 2 and 3 finished. + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).addRequestId(3).build()); + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(4).build()); + secondResponsesDone.await(); - assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED); + assertThat(reconnectStreamInfo.requests).isEmpty(); + assertThat(streamInfo.requests).isEmpty(); + assertTrue(allOk.get()); } @Test - public void testSend_notCalledAfterShutdown() { + public void testCommitWorkItem_retryOnNewStreamHalfClose() throws Exception { int numCommits = 5; - CountDownLatch commitProcessed = new CountDownLatch(numCommits); - - TestCommitWorkStreamRequestObserver requestObserver = - spy(new TestCommitWorkStreamRequestObserver()); - InOrder requestObserverVerifier = inOrder(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); - CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); - GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + final AtomicBoolean allOk = new AtomicBoolean(true); + final CountDownLatch firstResponsesDone = new CountDownLatch(2); + final CountDownLatch secondResponsesDone = new CountDownLatch(3); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { + int finalI = i; assertTrue( batcher.commitWorkItem( COMPUTATION_ID, workItemCommitRequest(i), - commitStatus -> commitProcessed.countDown())); + (status) -> { + if (status != Windmill.CommitStatus.OK) { + allOk.set(false); + } + if (finalI == 0 || finalI == 4) { + firstResponsesDone.countDown(); + } else { + secondResponsesDone.countDown(); + } + })); } + } + Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take(); + assertEquals(5, request.getCommitChunkCount()); + for (int i = 0; i < 5; ++i) { + assertEquals(i + 1, request.getCommitChunk(i).getRequestId()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + request.getCommitChunk(i).getSerializedWorkItemCommit()); + assertEquals(parsedRequest.getWorkToken(), i); + } + // Half-close the logical stream. This shouldn't prevent reconnection of the physical stream + // from succeeding. + commitWorkStream.halfClose(); + assertNull(streamInfo.onDone.get()); + + // Send back that 1 and 5 finished. + streamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).addRequestId(5).build()); + firstResponsesDone.await(); + + // Simulate that the server breaks. + streamInfo.responseObserver.onError(new IOException("test error")); + + // The stream should reconnect and retry the requests. + FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo = + waitForConnectionAndConsumeHeader(); + + // We don't expect any more streams since we finish successfully below. + fakeService.expectNoMoreStreams(); + + Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take(); + assertEquals(3, reconnectRequest.getCommitChunkCount()); + for (int i = 0; i < 3; ++i) { + assertEquals(i + 2, reconnectRequest.getCommitChunk(i).getRequestId()); + Windmill.WorkItemCommitRequest parsedRequest = + Windmill.WorkItemCommitRequest.parseFrom( + reconnectRequest.getCommitChunk(i).getSerializedWorkItemCommit()); + assertEquals(i + 1, parsedRequest.getWorkToken()); + } + assertNull(streamInfo.onDone.get()); + + // Send back that 2 and 3 finished and then 4 finishes. + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).addRequestId(3).build()); + reconnectStreamInfo.responseObserver.onNext( + Windmill.StreamingCommitResponse.newBuilder().addRequestId(4).build()); + // The half-close completes + reconnectStreamInfo.responseObserver.onCompleted(); + secondResponsesDone.await(); + + assertThat(reconnectStreamInfo.requests).isEmpty(); + assertThat(streamInfo.requests).isEmpty(); + assertTrue(allOk.get()); + } + + @Test + public void testSend_notCalledAfterShutdown_Single() + throws ExecutionException, InterruptedException { + int numCommits = 1; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(0), + commitStatus -> { + errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED)); + errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L)); + commitProcessed.countDown(); + })); // Shutdown the stream before we exit the try-with-resources block which will try to send() // the batched request. commitWorkStream.shutdown(); } + commitProcessed.await(); - // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends - // the header, which happens before we shutdown. - requestObserverVerifier - .verify(requestObserver) - .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); - requestObserverVerifier.verify(requestObserver).onError(any()); - requestObserverVerifier.verifyNoMoreInteractions(); + assertNotNull(streamInfo.onDone.get()); + assertThat(streamInfo.requests).isEmpty(); } - private static class TestCommitWorkStreamRequestObserver - implements StreamObserver { - private @Nullable StreamObserver responseObserver; - - @Override - public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} - - @Override - public void onError(Throwable throwable) {} + @Test + public void testSend_notCalledAfterShutdown_Batch() + throws ExecutionException, InterruptedException { + int numCommits = 2; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); - @Override - public void onCompleted() { - if (responseObserver != null) { - responseObserver.onCompleted(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> { + errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED)); + errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L)); + commitProcessed.countDown(); + })); } + // Shutdown the stream before we exit the try-with-resources block which will try to send() + // the batched request. + commitWorkStream.shutdown(); } + commitProcessed.await(); + + assertNotNull(streamInfo.onDone.get()); + assertThat(streamInfo.requests).isEmpty(); } - private static class CommitWorkStreamTestStub - extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { - private final TestCommitWorkStreamRequestObserver requestObserver; - private @Nullable StreamObserver responseObserver; + @Test + public void testSend_notCalledAfterShutdown_Multichunk() + throws ExecutionException, InterruptedException { + int numCommits = 1; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(); + FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); - private CommitWorkStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { - this.requestObserver = requestObserver; + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(0) + .toBuilder() + .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build()) + .build(), + commitStatus -> { + errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED)); + errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L)); + commitProcessed.countDown(); + })); + // Shutdown the stream before we exit the try-with-resources block which will try to send() + // the batched request. + commitWorkStream.shutdown(); } + commitProcessed.await(); + assertNotNull(streamInfo.onDone.get()); + assertThat(streamInfo.requests).isEmpty(); + } - @Override - public StreamObserver commitWorkStream( - StreamObserver responseObserver) { - if (this.responseObserver == null) { - ((ServerCallStreamObserver) responseObserver) - .setOnCancelHandler(() -> {}); - this.responseObserver = responseObserver; - requestObserver.responseObserver = this.responseObserver; - } - - return requestObserver; + private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() { + try { + FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream(); + Windmill.StreamingCommitWorkRequest request = info.requests.take(); + errorCollector.checkThat(request.getHeader(), is(TEST_JOB_HEADER)); + assertEquals(0, request.getCommitChunkCount()); + return info; + } catch (InterruptedException e) { + throw new RuntimeException(e); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java index 9b0cc006464f..1014242317de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -81,7 +81,11 @@ public class GrpcDirectGetWorkStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + + @Rule + public transient Timeout globalTimeout = + Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build(); + private ManagedChannel inProcessChannel; private GrpcDirectGetWorkStream stream; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java index 8a23b4f51b5a..150db4ed4815 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -36,6 +36,8 @@ @RunWith(JUnit4.class) public class GrpcGetDataStreamRequestsTest { + private static final int DEADLINE_SECONDS = 10; + @Test public void testQueuedRequest_globalRequestsFirstComparator() { List requests = new ArrayList<>(); @@ -49,7 +51,7 @@ public void testQueuedRequest_globalRequestsFirstComparator() { .build(); requests.add( GrpcGetDataStreamRequests.QueuedRequest.forComputation( - 1, "computation1", keyedGetDataRequest1)); + 1, "computation1", keyedGetDataRequest1, DEADLINE_SECONDS)); Windmill.KeyedGetDataRequest keyedGetDataRequest2 = Windmill.KeyedGetDataRequest.newBuilder() @@ -61,7 +63,7 @@ public void testQueuedRequest_globalRequestsFirstComparator() { .build(); requests.add( GrpcGetDataStreamRequests.QueuedRequest.forComputation( - 2, "computation2", keyedGetDataRequest2)); + 2, "computation2", keyedGetDataRequest2, DEADLINE_SECONDS)); Windmill.GlobalDataRequest globalDataRequest = Windmill.GlobalDataRequest.newBuilder() @@ -72,7 +74,8 @@ public void testQueuedRequest_globalRequestsFirstComparator() { .build()) .setComputationId("computation1") .build(); - requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest, DEADLINE_SECONDS)); requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst()); @@ -94,7 +97,7 @@ public void testQueuedBatch_asGetDataRequest() { .build(); queuedBatch.addRequest( GrpcGetDataStreamRequests.QueuedRequest.forComputation( - 1, "computation1", keyedGetDataRequest1)); + 1, "computation1", keyedGetDataRequest1, DEADLINE_SECONDS)); Windmill.KeyedGetDataRequest keyedGetDataRequest2 = Windmill.KeyedGetDataRequest.newBuilder() @@ -106,7 +109,7 @@ public void testQueuedBatch_asGetDataRequest() { .build(); queuedBatch.addRequest( GrpcGetDataStreamRequests.QueuedRequest.forComputation( - 2, "computation2", keyedGetDataRequest2)); + 2, "computation2", keyedGetDataRequest2, DEADLINE_SECONDS)); Windmill.GlobalDataRequest globalDataRequest = Windmill.GlobalDataRequest.newBuilder() @@ -117,7 +120,8 @@ public void testQueuedBatch_asGetDataRequest() { .build()) .setComputationId("computation1") .build(); - queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest, DEADLINE_SECONDS)); Windmill.StreamingGetDataRequest getDataRequest = queuedBatch.asGetDataRequest(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index 4cf21c2adfdf..e954f2cc7105 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -18,9 +18,11 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.IOException; import java.util.List; @@ -32,7 +34,6 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; -import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; @@ -41,16 +42,13 @@ import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.ServerCallStreamObserver; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.util.MutableHandlerRegistry; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ErrorCollector; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -65,34 +63,38 @@ public class GrpcGetDataStreamTest { .setProjectId("test_project") .build(); + @Rule public final ErrorCollector errorCollector = new ErrorCollector(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + + @Rule + public transient Timeout globalTimeout = + Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build(); + + private final FakeWindmillGrpcService fakeService = new FakeWindmillGrpcService(errorCollector); private ManagedChannel inProcessChannel; + private Server inProcessServer; @Before public void setUp() throws IOException { - Server server = - InProcessServerBuilder.forName(FAKE_SERVER_NAME) - .fallbackHandlerRegistry(serviceRegistry) - .directExecutor() - .build() - .start(); - + inProcessServer = + grpcCleanup.register( + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .addService(fakeService) + .directExecutor() + .build() + .start()); inProcessChannel = grpcCleanup.register( InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); - grpcCleanup.register(server); - grpcCleanup.register(inProcessChannel); } @After public void cleanUp() { + inProcessServer.shutdownNow(); inProcessChannel.shutdownNow(); } - private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { - serviceRegistry.addService(testStub); + private GrpcGetDataStream createGetDataStream() { GrpcGetDataStream getDataStream = (GrpcGetDataStream) GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) @@ -104,53 +106,51 @@ private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { } @Test - public void testRequestKeyedData() { - GetDataStreamTestStub testStub = - new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); - GrpcGetDataStream getDataStream = createGetDataStream(testStub); + public void testRequestKeyedData() throws InterruptedException { + GrpcGetDataStream getDataStream = createGetDataStream(); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build(); + CompletableFuture sendFuture = CompletableFuture.supplyAsync( () -> { try { - return getDataStream.requestKeyedData( - "computationId", - Windmill.KeyedGetDataRequest.newBuilder() - .setKey(ByteString.EMPTY) - .setShardingKey(1) - .setCacheToken(1) - .setWorkToken(1) - .build()); + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); } catch (WindmillStreamShutdownException e) { throw new RuntimeException(e); } }); - // Sleep a bit to allow future to run. - Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); - Windmill.KeyedGetDataResponse response = + Windmill.KeyedGetDataResponse keyedGetDataResponse = Windmill.KeyedGetDataResponse.newBuilder() .setShardingKey(1) .setKey(ByteString.EMPTY) .build(); - testStub.injectResponse( + streamInfo.responseObserver.onNext( Windmill.StreamingGetDataResponse.newBuilder() .addRequestId(1) - .addSerializedResponse(response.toByteString()) - .setRemainingBytesForResponse(0) + .addSerializedResponse(keyedGetDataResponse.toByteString()) .build()); - assertThat(sendFuture.join()).isEqualTo(response); + assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse); } @Test - @Ignore("https://github.com/apache/beam/issues/28957") public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() { - GetDataStreamTestStub testStub = - new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); - GrpcGetDataStream getDataStream = createGetDataStream(testStub); + GrpcGetDataStream getDataStream = createGetDataStream(); int numSendThreads = 5; ExecutorService getDataStreamSenders = Executors.newFixedThreadPool(numSendThreads); CountDownLatch waitForSendAttempt = new CountDownLatch(1); @@ -210,48 +210,106 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow } } - private static class TestGetDataStreamRequestObserver - implements StreamObserver { - private @Nullable StreamObserver responseObserver; + @Test + public void testRequestKeyedData_reconnectOnStreamError() throws InterruptedException { + GrpcGetDataStream getDataStream = createGetDataStream(); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); + + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build(); + + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); - @Override - public void onNext(Windmill.StreamingGetDataRequest streamingGetDataRequest) {} + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); - @Override - public void onError(Throwable throwable) {} + // Simulate an error on the grpc stream, this should trigger a retry of the request internal to + // the stream. + streamInfo.responseObserver.onError(new IOException("test error")); - @Override - public void onCompleted() { - if (responseObserver != null) { - responseObserver.onCompleted(); + streamInfo = waitForConnectionAndConsumeHeader(); + while (true) { + request = streamInfo.requests.poll(5, TimeUnit.SECONDS); + if (request != null) break; + if (sendFuture.isDone()) { + fail("Unexpected send completion " + sendFuture); } } + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + getDataStream.shutdown(); + assertThrows(RuntimeException.class, sendFuture::join); } - private static class GetDataStreamTestStub - extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { - private final TestGetDataStreamRequestObserver requestObserver; - private @Nullable StreamObserver responseObserver; + @Test + public void testRequestKeyedData_reconnectOnStreamErrorAfterHalfClose() + throws InterruptedException, ExecutionException { + GrpcGetDataStream getDataStream = createGetDataStream(); + FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader(); - private GetDataStreamTestStub(TestGetDataStreamRequestObserver requestObserver) { - this.requestObserver = requestObserver; - } + // These will block until they are successfully sent. + Windmill.KeyedGetDataRequest keyedGetDataRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build(); - @Override - public StreamObserver getDataStream( - StreamObserver responseObserver) { - if (this.responseObserver == null) { - ((ServerCallStreamObserver) responseObserver) - .setOnCancelHandler(() -> {}); - this.responseObserver = responseObserver; - requestObserver.responseObserver = this.responseObserver; - } + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData("computationId", keyedGetDataRequest); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); - return requestObserver; - } + Windmill.StreamingGetDataRequest request = streamInfo.requests.take(); + assertThat(request.getRequestIdList()).containsExactly(1L); + assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0)); + + // Close the stream. + getDataStream.halfClose(); + assertNull(streamInfo.onDone.get()); + + // Simulate an error on the grpc stream, this should trigger an error on all + // existing requests but no new connection since we half-closed and nothing left after + // responding with errors. + fakeService.expectNoMoreStreams(); + streamInfo.responseObserver.onError(new IOException("test error")); + assertThrows(RuntimeException.class, sendFuture::join); + + getDataStream.shutdown(); + } - private void injectResponse(Windmill.StreamingGetDataResponse getDataResponse) { - checkNotNull(responseObserver).onNext(getDataResponse); + private FakeWindmillGrpcService.GetDataStreamInfo waitForConnectionAndConsumeHeader() { + try { + FakeWindmillGrpcService.GetDataStreamInfo info = fakeService.waitForConnectedGetDataStream(); + Windmill.StreamingGetDataRequest request = info.requests.take(); + errorCollector.checkThat(request.getHeader(), Matchers.is(TEST_JOB_HEADER)); + assertEquals(0, request.getRequestIdCount()); + assertEquals(0, request.getComputationHeartbeatRequestCount()); + return info; + } catch (InterruptedException e) { + throw new RuntimeException(e); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 9a9f93d5c6c9..3c3e2a6579f6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -237,7 +236,8 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { } @Test - public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { + public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() + throws InterruptedException { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); @@ -250,7 +250,9 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { assertTrue(streamFactory.streamRegistry().contains(stream)); stream.halfClose(); - assertFalse(streamFactory.streamRegistry().contains(stream)); + while (streamFactory.streamRegistry().contains(stream)) { + Thread.sleep(100); + } } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 899d99b50352..e52b6e8de4bf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -117,7 +117,11 @@ public class GrpcWindmillServerTest { private final long clientId = 10L; private final Set openedChannels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + + @Rule + public transient Timeout globalTimeout = + Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build(); + @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); private Server server; @@ -482,7 +486,6 @@ public void onCompleted() { } @Test - @SuppressWarnings("FutureReturnValueIgnored") public void testStreamingGetData() throws Exception { // This server responds to GetDataRequests with responses that mirror the requests. serviceRegistry.addService( @@ -623,7 +626,7 @@ private void flushResponse() { for (int i = 0; i < 100; ++i) { final String key = "key" + i; final String s = i % 5 == 0 ? largeString(i) : "tag"; - executor.submit( + executor.execute( () -> { try { errorCollector.checkThat(