diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
index 968664a9b874..cf46e1f984dc 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java
@@ -26,8 +26,9 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
-import java.util.function.Supplier;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
@@ -37,6 +38,7 @@
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.checkerframework.checker.nullness.qual.NonNull;
import org.joda.time.DateTime;
import org.joda.time.Instant;
import org.slf4j.Logger;
@@ -48,12 +50,12 @@
* stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a
* broken stream.
*
- * <p>Subclasses should override {@link #onResponse(ResponseT)} to handle responses from the server,
- * and {@link #onNewStream()} to perform any work that must be done when a new stream is created,
- * such as sending headers or retrying requests.
+ * <p>Subclasses should override {@link #newResponseHandler()} to implement a handler for the
+ * physical stream connection, and {@link #onNewStream()} to perform any work that must be done
+ * when a new stream is created, such as sending headers or retrying requests.
*
- * <p>{@link #trySend(RequestT)} and {@link #startStream()} should not be called from {@link
- * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead.
+ * <p>{@link #trySend(RequestT)} and {@link #startStream()} should not be called when handling
+ * responses; use {@link #executeSafely(Runnable)} instead.
*
* <p>Synchronization on this is used to synchronize the gRpc stream state and internal data
* structures. Since grpc channel operations may block, synchronization on this stream may also
@@ -83,9 +85,12 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
private final Set<AbstractWindmillStream<?, ?>> streamRegistry;
private final int logEveryNStreamFailures;
private final String backendWorkerToken;
+
+ private final Function<StreamObserver<ResponseT>, TerminatingStreamObserver<RequestT>>
+ physicalStreamFactory;
+ protected final long physicalStreamDeadlineSeconds;
private final ResettableThrowingStreamObserver<RequestT> requestObserver;
- private final Supplier<TerminatingStreamObserver<RequestT>> requestObserverFactory;
private final StreamDebugMetrics debugMetrics;
private final AtomicBoolean isHealthCheckScheduled;
@@ -95,6 +100,17 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
@GuardedBy("this")
protected boolean isShutdown;
+ // The active physical grpc stream. trySend will send messages on the bi-directional stream
+ // associated with this handler. The instances are created by subclasses via newResponseHandler.
+ // Subclasses may wish to store additional per-physical stream state within the handler.
+ @GuardedBy("this")
+ protected @Nullable PhysicalStreamHandler currentPhysicalStream;
+
+ // Generally the same as currentPhysicalStream, set under synchronization of this but can be read
+ // without.
+ private final AtomicReference<PhysicalStreamHandler> currentPhysicalStreamForDebug =
+ new AtomicReference<>();
+
@GuardedBy("this")
private boolean started;
@@ -108,6 +124,9 @@ protected AbstractWindmillStream(
int logEveryNStreamFailures,
String backendWorkerToken) {
this.backendWorkerToken = backendWorkerToken;
+ this.physicalStreamFactory =
+ (StreamObserver<ResponseT> observer) -> streamObserverFactory.from(clientFactory, observer);
+ this.physicalStreamDeadlineSeconds = streamObserverFactory.getDeadlineSeconds();
this.executor =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
@@ -124,10 +143,6 @@ protected AbstractWindmillStream(
this.finishLatch = new CountDownLatch(1);
this.logger = logger;
this.requestObserver = new ResettableThrowingStreamObserver<>(logger);
- this.requestObserverFactory =
- () ->
- streamObserverFactory.from(
- clientFactory, new AbstractWindmillStream.ResponseObserver());
this.sleeper = Sleeper.DEFAULT;
this.debugMetrics = StreamDebugMetrics.create();
}
@@ -138,19 +153,45 @@ private static String createThreadName(String streamType, String backendWorkerTo
: String.format("%s-WindmillStream-thread", streamType);
}
- /** Called on each response from the server. */
- protected abstract void onResponse(ResponseT response);
+ /** Represents a physical grpc stream that is part of the logical windmill stream. */
+ protected abstract class PhysicalStreamHandler {
- /** Called when a new underlying stream to the server has been opened. */
- protected abstract void onNewStream() throws WindmillStreamShutdownException;
+ /** Called on each response from the server. */
+ public abstract void onResponse(ResponseT response);
+
+ /** Returns whether there are any pending requests that should be retried on a stream break. */
+ public abstract boolean hasPendingRequests();
- /** Returns whether there are any pending requests that should be retried on a stream break. */
- protected abstract boolean hasPendingRequests();
+ /**
+ * Called when the physical stream has finished. For streams with requests that should be
+ * retried, requests should be moved to parent state so that it is captured by the next
+ * flushPendingToStream call.
+ */
+ public abstract void onDone(Status status);
+
+ /**
+ * Renders information useful for debugging as html.
+ *
+ * @implNote Don't require synchronization on AbstractWindmillStream.this, see the {@link
+ * #appendSummaryHtml(PrintWriter)} comment.
+ */
+ public abstract void appendHtml(PrintWriter writer);
+
+ private final StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.create();
+ }
+
+ protected abstract PhysicalStreamHandler newResponseHandler();
+
+ protected abstract void onNewStream() throws WindmillStreamShutdownException;
/** Try to send a request to the server. Returns true if the request was successfully sent. */
@CanIgnoreReturnValue
protected final synchronized boolean trySend(RequestT request)
throws WindmillStreamShutdownException {
+ if (currentPhysicalStream == null) {
+ return false;
+ }
+ currentPhysicalStream.streamDebugMetrics.recordSend();
debugMetrics.recordSend();
try {
requestObserver.onNext(request);
@@ -182,10 +223,14 @@ private void startStream() {
// Add the stream to the registry after it has been fully constructed.
streamRegistry.add(this);
while (true) {
+ @NonNull PhysicalStreamHandler streamHandler = newResponseHandler();
try {
synchronized (this) {
debugMetrics.recordStart();
- requestObserver.reset(requestObserverFactory.get());
+ streamHandler.streamDebugMetrics.recordStart();
+ currentPhysicalStream = streamHandler;
+ currentPhysicalStreamForDebug.set(currentPhysicalStream);
+ requestObserver.reset(physicalStreamFactory.apply(new ResponseObserver(streamHandler)));
onNewStream();
if (clientClosed) {
halfClose();
@@ -272,6 +317,23 @@ public final void maybeScheduleHealthCheck(Instant lastSendThreshold) {
*/
public final void appendSummaryHtml(PrintWriter writer) {
appendSpecificHtml(writer);
+
+ @Nullable PhysicalStreamHandler currentHandler = currentPhysicalStreamForDebug.get();
+ if (currentHandler != null) {
+ writer.format("Physical stream: ");
+ currentHandler.appendHtml(writer);
+ StreamDebugMetrics.Snapshot summaryMetrics =
+ currentHandler.streamDebugMetrics.getSummaryMetrics();
+ if (summaryMetrics.isClientClosed()) {
+ writer.write(" client closed");
+ }
+ writer.format(
+ " current stream is %dms old, last send %dms, last response %dms\n",
+ summaryMetrics.streamAge(),
+ summaryMetrics.timeSinceLastSend(),
+ summaryMetrics.timeSinceLastResponse());
+ }
+
StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics();
summaryMetrics
.restartMetrics()
@@ -304,6 +366,8 @@ public final void appendSummaryHtml(PrintWriter writer) {
}
/**
+ * Add specific debug state for the logical stream.
+ *
* @implNote Don't require synchronization on stream, see the {@link
* #appendSummaryHtml(PrintWriter)} comment.
*/
@@ -315,6 +379,9 @@ public final synchronized void halfClose() {
debugMetrics.recordHalfClose();
clientClosed = true;
try {
+ if (currentPhysicalStream != null) {
+ currentPhysicalStream.streamDebugMetrics.recordHalfClose();
+ }
requestObserver.onCompleted();
} catch (ResettableThrowingStreamObserver.StreamClosedException e) {
logger.warn("Stream was previously closed.");
@@ -354,11 +421,17 @@ public final void shutdown() {
}
}
- protected abstract void shutdownInternal();
+ protected synchronized void shutdownInternal() {}
/** Returns true if the stream was torn down and should not be restarted internally. */
- private synchronized boolean maybeTearDownStream() {
- if (isShutdown || (clientClosed && !hasPendingRequests())) {
+ private synchronized boolean maybeTearDownStream(PhysicalStreamHandler doneStream) {
+ if (clientClosed && !doneStream.hasPendingRequests()) {
+ shutdown();
+ }
+
+ if (isShutdown) {
+ // Once we have background closing physicalStreams we will need to improve this to wait for
+ // all of the work of the logical stream to be complete.
streamRegistry.remove(AbstractWindmillStream.this);
finishLatch.countDown();
executor.shutdownNow();
@@ -369,23 +442,49 @@ private synchronized boolean maybeTearDownStream() {
}
private class ResponseObserver implements StreamObserver<ResponseT> {
+ private final PhysicalStreamHandler handler;
+
+ ResponseObserver(PhysicalStreamHandler handler) {
+ this.handler = handler;
+ }
@Override
public void onNext(ResponseT response) {
backoff.reset();
debugMetrics.recordResponse();
- onResponse(response);
+ handler.streamDebugMetrics.recordResponse();
+ handler.onResponse(response);
}
@Override
public void onError(Throwable t) {
- if (maybeTearDownStream()) {
- return;
- }
+ executeSafely(() -> onPhysicalStreamCompletion(Status.fromThrowable(t), handler));
+ }
- Status errorStatus = Status.fromThrowable(t);
- recordStreamStatus(errorStatus);
+ @Override
+ public void onCompleted() {
+ executeSafely(() -> onPhysicalStreamCompletion(OK_STATUS, handler));
+ }
+ }
+
+ @SuppressWarnings("nullness")
+ private void clearPhysicalStreamForDebug() {
+ currentPhysicalStreamForDebug.set(null);
+ }
+ private void onPhysicalStreamCompletion(Status status, PhysicalStreamHandler handler) {
+ synchronized (this) {
+ if (currentPhysicalStream == handler) {
+ clearPhysicalStreamForDebug();
+ currentPhysicalStream = null;
+ }
+ }
+ handler.onDone(status);
+ if (maybeTearDownStream(handler)) {
+ return;
+ }
+ // Backoff on errors.
+ if (!status.isOk()) {
try {
long sleep = backoff.nextBackOffMillis();
debugMetrics.recordSleep(sleep);
@@ -394,54 +493,43 @@ public void onError(Throwable t) {
Thread.currentThread().interrupt();
return;
}
-
- executeSafely(AbstractWindmillStream.this::startStream);
- }
-
- @Override
- public void onCompleted() {
- if (maybeTearDownStream()) {
- return;
- }
- recordStreamStatus(OK_STATUS);
- executeSafely(AbstractWindmillStream.this::startStream);
}
+ recordStreamRestart(status);
+ startStream();
+ }
- private void recordStreamStatus(Status status) {
- int currentRestartCount = debugMetrics.incrementAndGetRestarts();
- if (status.isOk()) {
- String restartReason =
- "Stream completed successfully but did not complete requested operations, "
- + "recreating";
- logger.warn(restartReason);
- debugMetrics.recordRestartReason(restartReason);
- } else {
- int currentErrorCount = debugMetrics.incrementAndGetErrors();
- debugMetrics.recordRestartReason(status.toString());
- Throwable t = status.getCause();
- if (t instanceof StreamObserverCancelledException) {
- logger.error(
- "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}",
- getClass(),
- backendWorkerToken,
- t.getStackTrace(),
- t);
- } else if (currentRestartCount % logEveryNStreamFailures == 0) {
- // Don't log every restart since it will get noisy, and many errors transient.
- long nowMillis = Instant.now().getMillis();
- logger.debug(
- "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}"
- + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.",
- AbstractWindmillStream.this.getClass(),
- currentRestartCount,
- currentErrorCount,
- t,
- status,
- nowMillis - debugMetrics.getStartTimeMs(),
- debugMetrics
- .responseDebugString(nowMillis)
- .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING));
- }
+ private void recordStreamRestart(Status status) {
+ int currentRestartCount = debugMetrics.incrementAndGetRestarts();
+ if (status.isOk()) {
+ String restartReason =
+ "Stream completed successfully but did not complete requested operations, "
+ + "recreating";
+ logger.warn(restartReason);
+ debugMetrics.recordRestartReason(restartReason);
+ } else {
+ int currentErrorCount = debugMetrics.incrementAndGetErrors();
+ debugMetrics.recordRestartReason(status.toString());
+ Throwable t = status.getCause();
+ if (t instanceof StreamObserverCancelledException) {
+ logger.error(
+ "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}",
+ getClass(),
+ backendWorkerToken,
+ t.getStackTrace(),
+ t);
+ } else if (currentRestartCount % logEveryNStreamFailures == 0) {
+ // Don't log every restart since it will get noisy, and many errors transient.
+ long nowMillis = Instant.now().getMillis();
+ logger.debug(
+ "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}"
+ + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.",
+ AbstractWindmillStream.this.getClass(),
+ currentRestartCount,
+ currentErrorCount,
+ t,
+ status,
+ nowMillis - debugMetrics.getStartTimeMs(),
+ debugMetrics.responseDebugString(nowMillis).orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING));
}
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java
index 35d6fb5ea100..11690ecf279f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java
@@ -46,13 +46,15 @@ final class AppendableInputStream extends InputStream {
private final AtomicLong blockedStartMs;
private final BlockingDeque<InputStream> queue;
private final InputStream stream;
+ private final long deadlineSeconds;
- AppendableInputStream() {
+ AppendableInputStream(long deadlineSeconds) {
this.cancelled = new AtomicBoolean(false);
this.complete = new AtomicBoolean(false);
this.blockedStartMs = new AtomicLong();
this.queue = new LinkedBlockingDeque<>(QUEUE_MAX_CAPACITY);
this.stream = new SequenceInputStream(new InputStreamEnumeration());
+ this.deadlineSeconds = deadlineSeconds;
}
long getBlockedStartMs() {
@@ -71,6 +73,10 @@ int size() {
return queue.size();
}
+ long getDeadlineSeconds() {
+ return deadlineSeconds;
+ }
+
/** Appends a new InputStream to the tail of this stream. */
synchronized void append(InputStream chunk) {
try {
@@ -155,7 +161,7 @@ public boolean hasMoreElements() {
try {
blockedStartMs.set(Instant.now().getMillis());
- current = queue.poll(180, TimeUnit.SECONDS);
+ current = queue.poll(deadlineSeconds, TimeUnit.SECONDS);
if (current != null && current != POISON_PILL) {
return true;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
index 5d6f965e2d22..7a7b1a5cd27e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
import com.google.auto.value.AutoValue;
import java.io.PrintWriter;
@@ -44,6 +45,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue;
@@ -57,7 +59,18 @@ final class GrpcCommitWorkStream
private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE;
- private final ConcurrentMap<Long, PendingRequest> pending;
+ private static class StreamAndRequest {
+ StreamAndRequest(@Nullable CommitWorkPhysicalStreamHandler handler, PendingRequest request) {
+ this.handler = handler;
+ this.request = request;
+ }
+
+ final @Nullable CommitWorkPhysicalStreamHandler handler;
+ final PendingRequest request;
+ }
+
+ private final ConcurrentMap<Long, StreamAndRequest> pending = new ConcurrentHashMap<>();
+
private final AtomicLong idGenerator;
private final JobHeader jobHeader;
private final int streamingRpcBatchLimit;
@@ -82,7 +95,6 @@ private GrpcCommitWorkStream(
streamRegistry,
logEveryNStreamFailures,
backendWorkerToken);
- pending = new ConcurrentHashMap<>();
this.idGenerator = idGenerator;
this.jobHeader = jobHeader;
this.streamingRpcBatchLimit = streamingRpcBatchLimit;
@@ -119,12 +131,20 @@ public void appendSpecificHtml(PrintWriter writer) {
@Override
protected synchronized void onNewStream() throws WindmillStreamShutdownException {
trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build());
+ // Flush all pending requests that are no longer on active streams.
try (Batcher resendBatcher = new Batcher()) {
- for (Map.Entry<Long, PendingRequest> entry : pending.entrySet()) {
- if (!resendBatcher.canAccept(entry.getValue().getBytes())) {
+ for (Map.Entry<Long, StreamAndRequest> entry : pending.entrySet()) {
+ CommitWorkPhysicalStreamHandler requestHandler = entry.getValue().handler;
+ checkState(requestHandler != currentPhysicalStream);
+ // When we have streams closing in the background we should avoid retrying the requests
+ // active on those streams.
+
+ long id = entry.getKey();
+ PendingRequest request = entry.getValue().request;
+ if (!resendBatcher.canAccept(request.getBytes())) {
resendBatcher.flush();
}
- resendBatcher.add(entry.getKey(), entry.getValue());
+ resendBatcher.add(id, request);
}
}
}
@@ -139,42 +159,38 @@ public CommitWorkStream.RequestBatcher batcher() {
}
@Override
- protected boolean hasPendingRequests() {
- return !pending.isEmpty();
- }
-
- @Override
- protected void sendHealthCheck() throws WindmillStreamShutdownException {
- if (hasPendingRequests()) {
+ protected synchronized void sendHealthCheck() throws WindmillStreamShutdownException {
+ if (currentPhysicalStream != null && currentPhysicalStream.hasPendingRequests()) {
StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder();
builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID);
trySend(builder.build());
}
}
- @Override
- protected void onResponse(StreamingCommitResponse response) {
- CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler();
- for (int i = 0; i < response.getRequestIdCount(); ++i) {
- long requestId = response.getRequestId(i);
- if (requestId == HEARTBEAT_REQUEST_ID) {
- continue;
- }
+ private class CommitWorkPhysicalStreamHandler extends PhysicalStreamHandler {
+ @Override
+ public void onResponse(StreamingCommitResponse response) {
+ CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler();
+ for (int i = 0; i < response.getRequestIdCount(); ++i) {
+ long requestId = response.getRequestId(i);
+ if (requestId == HEARTBEAT_REQUEST_ID) {
+ continue;
+ }
+
+ // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may
+ // be omitted.
+ CommitStatus commitStatus =
+ i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK;
- // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may
- // be omitted.
- CommitStatus commitStatus =
- i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK;
-
- @Nullable PendingRequest pendingRequest = pending.remove(requestId);
- if (pendingRequest == null) {
- synchronized (this) {
- if (!isShutdown) {
- // Missing responses are expected after shutdown() because it removes them.
- LOG.error("Got unknown commit request ID: {}", requestId);
- }
+ @Nullable StreamAndRequest entry = pending.remove(requestId);
+ if (entry == null) {
+ LOG.error("Got unknown commit request ID: {}", requestId);
+ continue;
}
- } else {
+ if (entry.handler != this) {
+ LOG.error("Got commit request id {} on unexpected stream", requestId);
+ }
+ PendingRequest pendingRequest = entry.request;
try {
pendingRequest.completeWithStatus(commitStatus);
} catch (RuntimeException e) {
@@ -185,16 +201,40 @@ protected void onResponse(StreamingCommitResponse response) {
failureHandler.addError(commitStatus, e);
}
}
+
+ failureHandler.throwIfNonEmpty();
+ }
+
+ @Override
+ public boolean hasPendingRequests() {
+ return pending.entrySet().stream().anyMatch(e -> e.getValue().handler == this);
}
- failureHandler.throwIfNonEmpty();
+ @Override
+ public void onDone(Status status) {
+ if (status.isOk() && hasPendingRequests()) {
+ LOG.warn("Unexpected requests without responses on drained physical stream, retrying.");
+ }
+ }
+
+ @Override
+ public void appendHtml(PrintWriter writer) {
+ writer.format(
+ "CommitWorkStream: %d pending",
+ pending.entrySet().stream().filter(e -> e.getValue().handler == this).count());
+ }
+ }
+
+ @Override
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new CommitWorkPhysicalStreamHandler();
}
@Override
- protected void shutdownInternal() {
- Iterator<PendingRequest> pendingRequests = pending.values().iterator();
+ protected synchronized void shutdownInternal() {
+ Iterator<StreamAndRequest> pendingRequests = pending.values().iterator();
while (pendingRequests.hasNext()) {
- PendingRequest pendingRequest = pendingRequests.next();
+ PendingRequest pendingRequest = pendingRequests.next().request;
pendingRequest.abort();
pendingRequests.remove();
}
@@ -230,10 +270,14 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest)
.setSerializedWorkItemCommit(pendingRequest.serializedCommit());
StreamingCommitWorkRequest chunk = requestBuilder.build();
synchronized (this) {
- if (!prepareForSend(id, pendingRequest)) {
+ if (isShutdown) {
pendingRequest.abort();
return;
}
+ pending.put(
+ id,
+ new StreamAndRequest(
+ (CommitWorkPhysicalStreamHandler) currentPhysicalStream, pendingRequest));
trySend(chunk);
}
}
@@ -256,10 +300,16 @@ private void issueBatchedRequest(Map<Long, PendingRequest> requests)
}
StreamingCommitWorkRequest request = requestBuilder.build();
synchronized (this) {
- if (!prepareForSend(requests)) {
+ if (isShutdown) {
requests.forEach((ignored, pendingRequest) -> pendingRequest.abort());
return;
}
+ for (Map.Entry<Long, PendingRequest> entry : requests.entrySet()) {
+ pending.put(
+ entry.getKey(),
+ new StreamAndRequest(
+ (CommitWorkPhysicalStreamHandler) currentPhysicalStream, entry.getValue()));
+ }
trySend(request);
}
}
@@ -269,11 +319,15 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest)
checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId.");
ByteString serializedCommit = pendingRequest.serializedCommit();
synchronized (this) {
- if (!prepareForSend(id, pendingRequest)) {
+ if (isShutdown) {
pendingRequest.abort();
return;
}
+ pending.put(
+ id,
+ new StreamAndRequest(
+ (CommitWorkPhysicalStreamHandler) currentPhysicalStream, pendingRequest));
for (int i = 0;
i < serializedCommit.size();
i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
@@ -300,24 +354,6 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest)
}
}
- /** Returns true if prepare for send succeeded. */
- private synchronized boolean prepareForSend(long id, PendingRequest request) {
- if (!isShutdown) {
- pending.put(id, request);
- return true;
- }
- return false;
- }
-
- /** Returns true if prepare for send succeeded. */
- private synchronized boolean prepareForSend(Map<Long, PendingRequest> requests) {
- if (!isShutdown) {
- pending.putAll(requests);
- return true;
- }
- return false;
- }
-
@AutoValue
abstract static class PendingRequest {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
index f7960b7f68dc..938ec1c693c7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -46,6 +46,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -80,15 +81,6 @@ final class GrpcDirectGetWorkStream
private final GetDataClient getDataClient;
private final AtomicReference<StreamingGetWorkRequest> lastRequest;
- /**
- * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they
- * come in. Once all chunks for a response has been received, the chunk is processed and the
- * buffer is cleared.
- *
- * @implNote Buffers are not persisted across stream restarts.
- */
- private final ConcurrentMap<Long, GetWorkResponseChunkAssembler> workItemAssemblers;
-
private final boolean requestBatchedGetWorkResponse;
private GrpcDirectGetWorkStream(
@@ -118,7 +110,6 @@ private GrpcDirectGetWorkStream(
backendWorkerToken);
this.requestHeader = requestHeader;
this.workItemScheduler = workItemScheduler;
- this.workItemAssemblers = new ConcurrentHashMap<>();
this.heartbeatSender = heartbeatSender;
this.workCommitter = workCommitter;
this.getDataClient = getDataClient;
@@ -199,9 +190,47 @@ private void maybeSendRequestExtension(GetWorkBudget extension) {
}
}
+ private class DirectGetWorkPhysicalStreamHandler extends PhysicalStreamHandler {
+ /**
+ * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they
+ * come in. Once all chunks for a response has been received, the chunk is processed and the
+ * buffer is cleared.
+ *
+ * @implNote Buffers are not persisted across stream restarts.
+ */
+ final ConcurrentMap<Long, GetWorkResponseChunkAssembler> workItemAssemblers =
+ new ConcurrentHashMap<>();
+
+ @Override
+ public void onResponse(StreamingGetWorkResponseChunk response) {
+ workItemAssemblers
+ .computeIfAbsent(response.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
+ .append(response)
+ .forEach(GrpcDirectGetWorkStream.this::consumeAssembledWorkItem);
+ }
+
+ @Override
+ public boolean hasPendingRequests() {
+ return false;
+ }
+
+ @Override
+ public void onDone(Status status) {}
+
+ @Override
+ public void appendHtml(PrintWriter writer) {
+ // Number of buffers is same as distinct workers that sent work on this stream.
+ writer.format("%d buffers", workItemAssemblers.size());
+ }
+ }
+
+ @Override
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new DirectGetWorkPhysicalStreamHandler();
+ }
+
@Override
protected synchronized void onNewStream() throws WindmillStreamShutdownException {
- workItemAssemblers.clear();
budgetTracker.reset();
GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension();
StreamingGetWorkRequest request =
@@ -219,17 +248,9 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
trySend(request);
}
- @Override
- protected boolean hasPendingRequests() {
- return false;
- }
-
@Override
public void appendSpecificHtml(PrintWriter writer) {
- // Number of buffers is same as distinct workers that sent work on this stream.
- writer.format(
- "GetWorkStream: %d buffers, " + "last sent request: %s; ",
- workItemAssemblers.size(), lastRequest.get());
+ writer.format("GetWorkStream: last sent request: %s; ", lastRequest.get());
writer.print(budgetTracker.debugString());
}
@@ -238,17 +259,6 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException {
trySend(HEALTH_CHECK_REQUEST);
}
- @Override
- protected void shutdownInternal() {}
-
- @Override
- protected void onResponse(StreamingGetWorkResponseChunk chunk) {
- workItemAssemblers
- .computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
- .append(chunk)
- .forEach(this::consumeAssembledWorkItem);
- }
-
private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
WorkItem workItem = assembledWorkItem.workItem();
GetWorkResponseChunkAssembler.ComputationMetadata metadata =
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index 2ae58a60e7b1..7de074122a3c 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -18,7 +18,9 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verifyNotNull;
import java.io.IOException;
import java.io.InputStream;
@@ -34,7 +36,9 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
@@ -55,7 +59,12 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
+import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,10 +77,21 @@ final class GrpcGetDataStream
private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST =
StreamingGetDataRequest.newBuilder().build();
- /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} */
+ static final FluentBackoff BACK_OFF_FACTORY =
+ FluentBackoff.DEFAULT
+ .withInitialBackoff(Duration.millis(10))
+ .withMaxBackoff(Duration.standardSeconds(10));
+
+ /**
+ * @implNote {@link QueuedBatch} objects in the queue should also be guarded by {@code this}.
+ * Batches should be sent from the front of the queue and only removed from the queue once
+ * added to the pending set of a physical stream.
+ */
+ @GuardedBy("this")
private final Deque batches;
- private final Map pending;
+ private final Supplier batchesDebugSizeSupplier;
+
private final AtomicLong idGenerator;
private final JobHeader jobHeader;
private final int streamingRpcBatchLimit;
@@ -105,8 +125,11 @@ private GrpcGetDataStream(
this.idGenerator = idGenerator;
this.jobHeader = jobHeader;
this.streamingRpcBatchLimit = streamingRpcBatchLimit;
- this.batches = new ConcurrentLinkedDeque<>();
- this.pending = new ConcurrentHashMap<>();
+ // A concurrent deque is used so that we can observe the size without synchronization on "this".
+ // Otherwise the deque is accessed via batches, which carries a @GuardedBy annotation.
+ ConcurrentLinkedDeque batches = new ConcurrentLinkedDeque<>();
+ this.batches = batches;
+ this.batchesDebugSizeSupplier = batches::size;
this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
this.processHeartbeatResponses = processHeartbeatResponses;
}
@@ -153,45 +176,109 @@ private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest)
trySend(getDataRequest);
}
- @Override
- protected synchronized void onNewStream() throws WindmillStreamShutdownException {
- trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build());
- if (clientClosed) {
- // We rely on close only occurring after all methods on the stream have returned.
- // Since the requestKeyedData and requestGlobalData methods are blocking this
- // means there should be no pending requests.
- verify(!hasPendingRequests(), "Pending requests not expected if we've half-closed.");
- } else {
+ class GetDataPhysicalStreamHandler extends PhysicalStreamHandler {
+ private final ConcurrentHashMap pending =
+ new ConcurrentHashMap<>();
+
+ public void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException {
+ // Inserting into pending must be synchronized with send() to ensure duplicates are not
+ // sent on stream reconnect.
+ for (QueuedRequest request : batch.requestsReadOnly()) {
+ boolean alreadyPresent = pending.put(request.id(), request.getResponseStream()) != null;
+ verify(!alreadyPresent, "Request already sent, id: %s", request.id());
+ }
+
+ if (!trySend(batch.asGetDataRequest())) {
+ // The stream broke before this call went through; onNewStream will retry the fetch.
+ LOG.debug("GetData stream broke before call started.");
+ }
+ }
+
+ @Override
+ public void onResponse(StreamingGetDataResponse chunk) {
+ checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
+ checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
+ onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());
+
+ for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
+ long requestId = chunk.getRequestId(i);
+ boolean completeResponse = chunk.getRemainingBytesForResponse() == 0;
+ AppendableInputStream responseStream =
+ verifyNotNull(
+ completeResponse ? pending.remove(requestId) : pending.get(requestId),
+ "No pending response stream");
+ responseStream.append(chunk.getSerializedResponse(i).newInput());
+ if (completeResponse) {
+ responseStream.complete();
+ }
+ }
+ }
+
+ @Override
+ public boolean hasPendingRequests() {
+ return !pending.isEmpty();
+ }
+
+ @Override
+ public void onDone(Status status) {
+ if (status.isOk() && hasPendingRequests()) {
+ LOG.warn("Pending requests not expected on successful GetData stream flushing.");
+ }
for (AppendableInputStream responseStream : pending.values()) {
responseStream.cancel();
}
+ pending.clear();
+ }
+
+ @Override
+ public void appendHtml(PrintWriter writer) {
+ writer.format("%d pending requests [", pending.size());
+ for (Map.Entry entry : pending.entrySet()) {
+ writer.format("Stream %d ", entry.getKey());
+ if (entry.getValue().isCancelled()) {
+ writer.append("cancelled ");
+ }
+ if (entry.getValue().isComplete()) {
+ writer.append("complete ");
+ }
+ int queueSize = entry.getValue().size();
+ if (queueSize > 0) {
+ writer.format("%d queued responses ", queueSize);
+ }
+ long blockedMs = entry.getValue().getBlockedStartMs();
+ if (blockedMs > 0) {
+ writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs);
+ }
+ }
+ writer.append("]");
}
}
@Override
- protected boolean hasPendingRequests() {
- return !pending.isEmpty() || !batches.isEmpty();
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new GetDataPhysicalStreamHandler();
}
@Override
- protected void onResponse(StreamingGetDataResponse chunk) {
- checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
- checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
- onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());
-
- for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
- @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
- if (responseStream == null) {
- synchronized (this) {
- // shutdown()/shutdownInternal() cleans up pending, else we expect a pending
- // responseStream for every response.
- verify(isShutdown, "No pending response stream");
- }
- continue;
- }
- responseStream.append(chunk.getSerializedResponse(i).newInput());
- if (chunk.getRemainingBytesForResponse() == 0) {
- responseStream.complete();
+ protected synchronized void onNewStream() throws WindmillStreamShutdownException {
+ trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build());
+ while (!batches.isEmpty()) {
+ QueuedBatch batch = checkNotNull(batches.peekFirst());
+ verify(!batch.isEmpty());
+ if (!batch.isFinalized()) break;
+ try {
+ verify(
+ batch == batches.pollFirst(),
+ "Sent GetDataStream request batch removed before send() was complete.");
+ checkNotNull((GetDataPhysicalStreamHandler) currentPhysicalStream).sendBatch(batch);
+ // Notify all waiters with requests in this batch as well as the sender
+ // of the next batch (if one exists).
+ batch.notifySent();
+ } catch (Exception e) {
+ LOG.debug("Batch failed to send on new stream", e);
+ // Free waiters if the send() failed.
+ batch.notifyFailed();
+ throw e;
}
}
}
@@ -204,14 +291,17 @@ private long uniqueId() {
public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request)
throws WindmillStreamShutdownException {
return issueRequest(
- QueuedRequest.forComputation(uniqueId(), computation, request),
+ QueuedRequest.forComputation(
+ uniqueId(), computation, request, physicalStreamDeadlineSeconds),
KeyedGetDataResponse::parseFrom);
}
@Override
public GlobalData requestGlobalData(GlobalDataRequest request)
throws WindmillStreamShutdownException {
- return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom);
+ return issueRequest(
+ QueuedRequest.global(uniqueId(), request, physicalStreamDeadlineSeconds),
+ GlobalData::parseFrom);
}
@Override
@@ -284,8 +374,8 @@ public void onHeartbeatResponse(List resp
}
@Override
- protected void sendHealthCheck() throws WindmillStreamShutdownException {
- if (hasPendingRequests()) {
+ protected synchronized void sendHealthCheck() throws WindmillStreamShutdownException {
+ if (currentPhysicalStream != null && currentPhysicalStream.hasPendingRequests()) {
trySend(HEALTH_CHECK_REQUEST);
}
}
@@ -294,8 +384,14 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException {
protected synchronized void shutdownInternal() {
// Stream has been explicitly closed. Drain pending input streams and request batches.
// Future calls to send RPCs will fail.
- pending.values().forEach(AppendableInputStream::cancel);
- pending.clear();
+ final @Nullable GetDataPhysicalStreamHandler currentGetDataStream =
+ (GetDataPhysicalStreamHandler) currentPhysicalStream;
+ if (currentGetDataStream != null) {
+ for (AppendableInputStream ais : currentGetDataStream.pending.values()) {
+ ais.cancel();
+ }
+ currentGetDataStream.pending.clear();
+ }
batches.forEach(
batch -> {
batch.markFinalized();
@@ -306,30 +402,12 @@ protected synchronized void shutdownInternal() {
@Override
public void appendSpecificHtml(PrintWriter writer) {
- writer.format(
- "GetDataStream: %d queued batches, %d pending requests [", batches.size(), pending.size());
- for (Map.Entry entry : pending.entrySet()) {
- writer.format("Stream %d ", entry.getKey());
- if (entry.getValue().isCancelled()) {
- writer.append("cancelled ");
- }
- if (entry.getValue().isComplete()) {
- writer.append("complete ");
- }
- int queueSize = entry.getValue().size();
- if (queueSize > 0) {
- writer.format("%d queued responses ", queueSize);
- }
- long blockedMs = entry.getValue().getBlockedStartMs();
- if (blockedMs > 0) {
- writer.format("blocked for %dms", Instant.now().getMillis() - blockedMs);
- }
- }
- writer.append("]");
+ writer.format("GetDataStream: %d queued batches", batchesDebugSizeSupplier.get());
}
private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn)
throws WindmillStreamShutdownException {
+ final BackOff backoff = BACK_OFF_FACTORY.backoff();
while (true) {
request.resetResponseStream();
try {
@@ -342,12 +420,15 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn {
ResponseT parse(InputStream input) throws IOException;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java
index ef7f5b20bb07..318738893f0d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java
@@ -51,25 +51,33 @@ static class QueuedRequest {
private final ComputationOrGlobalDataRequest dataRequest;
private AppendableInputStream responseStream;
- private QueuedRequest(long id, ComputationOrGlobalDataRequest dataRequest) {
+ private QueuedRequest(
+ long id, ComputationOrGlobalDataRequest dataRequest, long deadlineSeconds) {
this.id = id;
this.dataRequest = dataRequest;
- responseStream = new AppendableInputStream();
+ responseStream = new AppendableInputStream(deadlineSeconds);
}
static QueuedRequest forComputation(
- long id, String computation, KeyedGetDataRequest keyedGetDataRequest) {
+ long id,
+ String computation,
+ KeyedGetDataRequest keyedGetDataRequest,
+ long deadlineSeconds) {
ComputationGetDataRequest computationGetDataRequest =
ComputationGetDataRequest.newBuilder()
.setComputationId(computation)
.addRequests(keyedGetDataRequest)
.build();
return new QueuedRequest(
- id, ComputationOrGlobalDataRequest.computation(computationGetDataRequest));
+ id,
+ ComputationOrGlobalDataRequest.computation(computationGetDataRequest),
+ deadlineSeconds);
}
- static QueuedRequest global(long id, GlobalDataRequest globalDataRequest) {
- return new QueuedRequest(id, ComputationOrGlobalDataRequest.global(globalDataRequest));
+ static QueuedRequest global(
+ long id, GlobalDataRequest globalDataRequest, long deadlineSeconds) {
+ return new QueuedRequest(
+ id, ComputationOrGlobalDataRequest.global(globalDataRequest), deadlineSeconds);
}
static Comparator globalRequestsFirst() {
@@ -93,7 +101,7 @@ AppendableInputStream getResponseStream() {
}
void resetResponseStream() {
- this.responseStream = new AppendableInputStream();
+ this.responseStream = new AppendableInputStream(responseStream.getDeadlineSeconds());
}
public ComputationOrGlobalDataRequest getDataRequest() {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
index 63342b1f6124..a1c758eac446 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import java.io.PrintWriter;
-import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
@@ -35,6 +34,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,7 +52,6 @@ final class GrpcGetWorkStream
private final GetWorkRequest request;
private final WorkItemReceiver receiver;
- private final Map workItemAssemblers;
private final AtomicLong inflightMessages;
private final AtomicLong inflightBytes;
private final boolean requestBatchedGetWorkResponse;
@@ -81,7 +80,6 @@ private GrpcGetWorkStream(
backendWorkerToken);
this.request = request;
this.receiver = receiver;
- this.workItemAssemblers = new ConcurrentHashMap<>();
this.inflightMessages = new AtomicLong();
this.inflightBytes = new AtomicLong();
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
@@ -131,9 +129,41 @@ private void sendRequestExtension(long moreItems, long moreBytes) {
});
}
+ private class GetWorkPhysicalStreamHandler extends PhysicalStreamHandler {
+
+ private final ConcurrentHashMap workItemAssemblers =
+ new ConcurrentHashMap<>();
+
+ @Override
+ public void onResponse(StreamingGetWorkResponseChunk response) {
+ workItemAssemblers
+ .computeIfAbsent(response.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
+ .append(response)
+ .forEach(GrpcGetWorkStream.this::consumeAssembledWorkItem);
+ }
+
+ @Override
+ public boolean hasPendingRequests() {
+ return false;
+ }
+
+ @Override
+ public void onDone(Status status) {}
+
+ @Override
+ public void appendHtml(PrintWriter writer) {
+ // The number of buffers is the same as the number of distinct workers that sent work on this stream.
+ writer.format("%d buffers", workItemAssemblers.size());
+ }
+ }
+
+ @Override
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new GetWorkPhysicalStreamHandler();
+ }
+
@Override
protected synchronized void onNewStream() throws WindmillStreamShutdownException {
- workItemAssemblers.clear();
inflightMessages.set(request.getMaxItems());
inflightBytes.set(request.getMaxBytes());
trySend(
@@ -143,20 +173,11 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
.build());
}
- @Override
- protected void shutdownInternal() {}
-
- @Override
- protected boolean hasPendingRequests() {
- return false;
- }
-
@Override
public void appendSpecificHtml(PrintWriter writer) {
- // Number of buffers is same as distinct workers that sent work on this stream.
writer.format(
- "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed",
- workItemAssemblers.size(), inflightMessages.intValue(), inflightBytes.intValue());
+ "GetWorkStream: %d inflight messages allowed, %d inflight bytes allowed",
+ inflightMessages.intValue(), inflightBytes.intValue());
}
@Override
@@ -164,14 +185,6 @@ protected void sendHealthCheck() throws WindmillStreamShutdownException {
trySend(HEALTH_CHECK);
}
- @Override
- protected void onResponse(StreamingGetWorkResponseChunk chunk) {
- workItemAssemblers
- .computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
- .append(chunk)
- .forEach(this::consumeAssembledWorkItem);
- }
-
private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
receiver.receiveWork(
assembledWorkItem.computationMetadata().computationId(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
index cb30a3e6a500..9b99b3bda909 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java
@@ -32,6 +32,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -92,15 +93,6 @@ public static GrpcGetWorkerMetadataStream create(
serverMappingUpdater);
}
- /**
- * Each instance of {@link AbstractWindmillStream} owns its own responseObserver that calls
- * onResponse().
- */
- @Override
- protected void onResponse(WorkerMetadataResponse response) {
- extractWindmillEndpointsFrom(response).ifPresent(serverMappingConsumer);
- }
-
/**
* Acquires the {@link #metadataLock} Returns {@link Optional} if the
* metadataVersion in the response is not stale (older or equal to current {@link
@@ -127,16 +119,30 @@ private Optional extractWindmillEndpointsFrom(
}
@Override
- protected void onNewStream() throws WindmillStreamShutdownException {
- trySend(workerMetadataRequest);
- }
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new PhysicalStreamHandler() {
- @Override
- protected void shutdownInternal() {}
+ @Override
+ public void onResponse(WorkerMetadataResponse response) {
+ extractWindmillEndpointsFrom(response).ifPresent(serverMappingConsumer);
+ }
+
+ @Override
+ public boolean hasPendingRequests() {
+ return false;
+ }
+
+ @Override
+ public void onDone(Status status) {}
+
+ @Override
+ public void appendHtml(PrintWriter writer) {}
+ };
+ }
@Override
- protected boolean hasPendingRequests() {
- return false;
+ protected void onNewStream() throws WindmillStreamShutdownException {
+ trySend(workerMetadataRequest);
}
@Override
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java
index b4f3e854bf19..e97416c004fa 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java
@@ -37,6 +37,8 @@ public abstract TerminatingStreamObserver from(
Function, StreamObserver> clientFactory,
StreamObserver responseObserver);
+ public abstract long getDeadlineSeconds();
+
private static class Direct extends StreamObserverFactory {
private final long deadlineSeconds;
private final int messagesBetweenIsReadyChecks;
@@ -59,5 +61,10 @@ public TerminatingStreamObserver from(
return new DirectStreamObserver<>(
phaser, outboundObserver, deadlineSeconds, messagesBetweenIsReadyChecks);
}
+
+ @Override
+ public long getDeadlineSeconds() {
+ return deadlineSeconds;
+ }
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index bead1ffd5b24..c1696d8a70ab 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -4190,7 +4190,9 @@ public void processElement(ProcessContext c) {
static class TestExceptionFn extends DoFn {
- boolean firstTime = true;
+ // Note that the use of static works because this DoFn is only used in a single test. We need
+ // to use static as the DoFn is not cached after user-code exceptions.
+ static boolean firstTime = true;
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java
index d9a1397e1c7c..92c081591c73 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java
@@ -32,6 +32,7 @@
import java.util.function.Function;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.CallStreamObserver;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
@@ -157,16 +158,28 @@ private TestStream(
}
@Override
- protected void onResponse(Integer response) {}
+ protected PhysicalStreamHandler newResponseHandler() {
+ return new PhysicalStreamHandler() {
- @Override
- protected void onNewStream() {
- numStarts.incrementAndGet();
+ @Override
+ public void onResponse(Integer response) {}
+
+ @Override
+ public boolean hasPendingRequests() {
+ return false;
+ }
+
+ @Override
+ public void onDone(Status status) {}
+
+ @Override
+ public void appendHtml(PrintWriter writer) {}
+ };
}
@Override
- protected boolean hasPendingRequests() {
- return false;
+ protected void onNewStream() {
+ numStarts.incrementAndGet();
}
private void testSend() throws WindmillStreamShutdownException {
@@ -192,9 +205,6 @@ private void waitForHealthChecks(int expectedHealthChecks) {
@Override
protected void appendSpecificHtml(PrintWriter writer) {}
-
- @Override
- protected void shutdownInternal() {}
}
private static class TestCallStreamObserver extends CallStreamObserver {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java
index fdd213223987..8cb831f15f62 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java
@@ -43,7 +43,11 @@ public class WindmillStreamPoolTest {
private final ConcurrentHashMap<
TestWindmillStream, WindmillStreamPool.StreamData>
holds = new ConcurrentHashMap<>();
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+
+ @Rule
+ public transient Timeout globalTimeout =
+ Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build();
+
private List> streams;
@Before
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java
new file mode 100644
index 000000000000..85c3c71663f1
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/FakeWindmillGrpcService.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.LinkedBlockingQueue;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
+import org.hamcrest.Matchers;
+import org.junit.rules.ErrorCollector;
+
+class FakeWindmillGrpcService
+ extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+ private final ErrorCollector errorCollector;
+
+ @GuardedBy("this")
+ private boolean failOnNewStreams = false;
+
+ public FakeWindmillGrpcService(ErrorCollector errorCollector) {
+ this.errorCollector = errorCollector;
+ }
+
+ public static class StreamInfo {
+ public StreamInfo(StreamObserver responseObserver) {
+ this.responseObserver = responseObserver;
+ }
+
+ public final StreamObserver responseObserver;
+ public final BlockingQueue requests = new LinkedBlockingQueue<>(1000);
+ public final CompletableFuture onDone = new CompletableFuture<>();
+ };
+
+ private final BlockingQueue commitStreams = new LinkedBlockingQueue<>(1000);
+ private final BlockingQueue getDataStreams = new LinkedBlockingQueue<>(1000);
+
+ private static class StreamInfoObserver implements StreamObserver {
+ private final StreamInfo streamInfo;
+ private final ErrorCollector errorCollector;
+
+ public StreamInfoObserver(
+ StreamInfo streamInfo, ErrorCollector errorCollector) {
+ this.streamInfo = streamInfo;
+ this.errorCollector = errorCollector;
+ }
+
+ @Override
+ public void onNext(RequestT request) {
+ errorCollector.checkThat(streamInfo.requests.add(request), Matchers.is(true));
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ streamInfo.onDone.complete(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ streamInfo.onDone.complete(null);
+ }
+ }
+
+ public static class CommitStreamInfo
+ extends StreamInfo {
+ CommitStreamInfo(StreamObserver responseObserver) {
+ super(responseObserver);
+ }
+ }
+
+ @Override
+ public StreamObserver commitWorkStream(
+ StreamObserver responseObserver) {
+ CommitStreamInfo info = new CommitStreamInfo(responseObserver);
+ synchronized (this) {
+ errorCollector.checkThat(failOnNewStreams, Matchers.is(false));
+ errorCollector.checkThat(commitStreams.offer(info), Matchers.is(true));
+ }
+ return new StreamInfoObserver<>(info, errorCollector);
+ }
+
+ public CommitStreamInfo waitForConnectedCommitStream() throws InterruptedException {
+ return commitStreams.take();
+ }
+
+ public synchronized void expectNoMoreStreams() {
+ failOnNewStreams = true;
+ errorCollector.checkThat(commitStreams.isEmpty(), Matchers.is(true));
+ errorCollector.checkThat(getDataStreams.isEmpty(), Matchers.is(true));
+ }
+
+ public static class GetDataStreamInfo
+ extends StreamInfo {
+ GetDataStreamInfo(StreamObserver responseObserver) {
+ super(responseObserver);
+ }
+ }
+
+ @Override
+ public StreamObserver getDataStream(
+ StreamObserver responseObserver) {
+ GetDataStreamInfo info = new GetDataStreamInfo(responseObserver);
+ synchronized (this) {
+ errorCollector.checkThat(failOnNewStreams, Matchers.is(false));
+ errorCollector.checkThat(getDataStreams.offer(info), Matchers.is(true));
+ }
+ return new StreamInfoObserver<>(info, errorCollector);
+ }
+
+ public GetDataStreamInfo waitForConnectedGetDataStream() throws InterruptedException {
+ return getDataStreams.take();
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java
index 31a851a702ee..7619668220d4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java
@@ -18,18 +18,20 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static com.google.common.truth.Truth.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.argThat;
-import static org.mockito.Mockito.inOrder;
-import static org.mockito.Mockito.spy;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.atomic.AtomicReference;
-import javax.annotation.Nullable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
@@ -39,21 +41,20 @@
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.ServerCallStreamObserver;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.util.MutableHandlerRegistry;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ErrorCollector;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import org.mockito.InOrder;
@RunWith(JUnit4.class)
public class GrpcCommitWorkStreamTest {
+
private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest";
private static final Windmill.JobHeader TEST_JOB_HEADER =
Windmill.JobHeader.newBuilder()
@@ -63,10 +64,20 @@ public class GrpcCommitWorkStreamTest {
.build();
private static final String COMPUTATION_ID = "computationId";
+ @SuppressWarnings("InlineMeInliner") // inline `Strings.repeat()` - Java 11+ API only
+ private static final ByteString LARGE_BYTE_STRING =
+ ByteString.copyFromUtf8(Strings.repeat("a", 2 * 1024 * 1024));
+
+ @Rule public final ErrorCollector errorCollector = new ErrorCollector();
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
- private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+
+ @Rule
+ public transient Timeout globalTimeout =
+ Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build();
+
+ private final FakeWindmillGrpcService fakeService = new FakeWindmillGrpcService(errorCollector);
private ManagedChannel inProcessChannel;
+ private Server inProcessServer;
private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) {
return Windmill.WorkItemCommitRequest.newBuilder()
@@ -79,27 +90,25 @@ private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value)
@Before
public void setUp() throws IOException {
- Server server =
- InProcessServerBuilder.forName(FAKE_SERVER_NAME)
- .fallbackHandlerRegistry(serviceRegistry)
- .directExecutor()
- .build()
- .start();
-
+ inProcessServer =
+ grpcCleanup.register(
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .addService(fakeService)
+ .directExecutor()
+ .build()
+ .start());
inProcessChannel =
grpcCleanup.register(
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
- grpcCleanup.register(server);
- grpcCleanup.register(inProcessChannel);
}
@After
public void cleanUp() {
+ inProcessServer.shutdownNow();
inProcessChannel.shutdownNow();
}
- private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) {
- serviceRegistry.addService(testStub);
+ private GrpcCommitWorkStream createCommitWorkStream() {
GrpcCommitWorkStream commitWorkStream =
(GrpcCommitWorkStream)
GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
@@ -110,16 +119,12 @@ private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub tes
}
@Test
- public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+ public void testShutdown_abortsActiveCommits() throws InterruptedException, ExecutionException {
int numCommits = 5;
CountDownLatch commitProcessed = new CountDownLatch(numCommits);
Set onDone = new HashSet<>();
- TestCommitWorkStreamRequestObserver requestObserver =
- spy(new TestCommitWorkStreamRequestObserver());
- CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
- GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
- InOrder requestObserverVerifier = inOrder(requestObserver);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
batcher.commitWorkItem(
@@ -133,120 +138,298 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException {
} catch (StreamObserverCancelledException ignored) {
}
- // Verify that we sent the commits above in a request + the initial header.
- requestObserverVerifier
- .verify(requestObserver)
- .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER)));
- requestObserverVerifier
- .verify(requestObserver)
- .onNext(argThat(request -> !request.getCommitChunkList().isEmpty()));
- requestObserverVerifier.verifyNoMoreInteractions();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
+ // The next request should have some chunks.
+ assertThat(streamInfo.requests.take().getCommitChunkList()).isNotEmpty();
// We won't get responses so we will have some pending requests.
- assertTrue(commitWorkStream.hasPendingRequests());
+ assertThat(commitProcessed.getCount()).isGreaterThan(0);
commitWorkStream.shutdown();
+ streamInfo.onDone.get();
+
commitProcessed.await();
assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED);
}
@Test
- public void testCommitWorkItem_afterShutdown() {
+ public void testCommitWorkItem_abortsCommitsSentAfterShutdown()
+ throws InterruptedException, ExecutionException {
int numCommits = 5;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
- CommitWorkStreamTestStub testStub =
- new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver());
- GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
+ commitWorkStream.shutdown();
+ assertNotNull(streamInfo.onDone.get());
+ AtomicBoolean allAborted = new AtomicBoolean(true);
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
- assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {}));
+ assertTrue(
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ (status) -> {
+ if (status != Windmill.CommitStatus.ABORTED) {
+ allAborted.set(false);
+ }
+ commitProcessed.countDown();
+ }));
}
}
- commitWorkStream.shutdown();
+ commitProcessed.await();
+ assertTrue(allAborted.get());
+ }
+
+ @Test
+ public void testCommitWorkItem_retryOnNewStream() throws Exception {
+ int numCommits = 5;
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
- AtomicReference commitStatus = new AtomicReference<>();
+ final AtomicBoolean allOk = new AtomicBoolean(true);
+ final CountDownLatch firstResponsesDone = new CountDownLatch(2);
+ final CountDownLatch secondResponsesDone = new CountDownLatch(3);
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
+ int finalI = i;
assertTrue(
- batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set));
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ (status) -> {
+ if (status != Windmill.CommitStatus.OK) {
+ allOk.set(false);
+ }
+ if (finalI == 0 || finalI == 4) {
+ firstResponsesDone.countDown();
+ } else {
+ secondResponsesDone.countDown();
+ }
+ }));
}
}
+ Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take();
+ assertEquals(5, request.getCommitChunkCount());
+ for (int i = 0; i < 5; ++i) {
+ assertEquals(i + 1, request.getCommitChunk(i).getRequestId());
+ Windmill.WorkItemCommitRequest parsedRequest =
+ Windmill.WorkItemCommitRequest.parseFrom(
+ request.getCommitChunk(i).getSerializedWorkItemCommit());
+ assertEquals(i, parsedRequest.getWorkToken());
+ }
+ // Send back that 1 and 5 finished.
+ streamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).addRequestId(5).build());
+ firstResponsesDone.await();
+
+ // Simulate that the server breaks.
+ streamInfo.responseObserver.onError(new IOException("test error"));
+
+ // The stream should reconnect and retry the requests.
+ FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo =
+ waitForConnectionAndConsumeHeader();
+ Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take();
+ assertEquals(3, reconnectRequest.getCommitChunkCount());
+ for (int i = 0; i < 3; ++i) {
+ assertEquals(i + 2, reconnectRequest.getCommitChunk(i).getRequestId());
+ Windmill.WorkItemCommitRequest parsedRequest =
+ Windmill.WorkItemCommitRequest.parseFrom(
+ reconnectRequest.getCommitChunk(i).getSerializedWorkItemCommit());
+ assertEquals(i + 1, parsedRequest.getWorkToken());
+ }
+ // Send back that 2 and 3 finished.
+ reconnectStreamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).addRequestId(3).build());
+ reconnectStreamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(4).build());
+ secondResponsesDone.await();
- assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED);
+ assertThat(reconnectStreamInfo.requests).isEmpty();
+ assertThat(streamInfo.requests).isEmpty();
+ assertTrue(allOk.get());
}
@Test
- public void testSend_notCalledAfterShutdown() {
+ public void testCommitWorkItem_retryOnNewStreamHalfClose() throws Exception {
int numCommits = 5;
- CountDownLatch commitProcessed = new CountDownLatch(numCommits);
-
- TestCommitWorkStreamRequestObserver requestObserver =
- spy(new TestCommitWorkStreamRequestObserver());
- InOrder requestObserverVerifier = inOrder(requestObserver);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
- CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
- GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+ final AtomicBoolean allOk = new AtomicBoolean(true);
+ final CountDownLatch firstResponsesDone = new CountDownLatch(2);
+ final CountDownLatch secondResponsesDone = new CountDownLatch(3);
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
+ int finalI = i;
assertTrue(
batcher.commitWorkItem(
COMPUTATION_ID,
workItemCommitRequest(i),
- commitStatus -> commitProcessed.countDown()));
+ (status) -> {
+ if (status != Windmill.CommitStatus.OK) {
+ allOk.set(false);
+ }
+ if (finalI == 0 || finalI == 4) {
+ firstResponsesDone.countDown();
+ } else {
+ secondResponsesDone.countDown();
+ }
+ }));
}
+ }
+ Windmill.StreamingCommitWorkRequest request = streamInfo.requests.take();
+ assertEquals(5, request.getCommitChunkCount());
+ for (int i = 0; i < 5; ++i) {
+ assertEquals(i + 1, request.getCommitChunk(i).getRequestId());
+ Windmill.WorkItemCommitRequest parsedRequest =
+ Windmill.WorkItemCommitRequest.parseFrom(
+ request.getCommitChunk(i).getSerializedWorkItemCommit());
+ assertEquals(i, parsedRequest.getWorkToken());
+ }
+ // Half-close the logical stream. This shouldn't prevent reconnection of the physical stream
+ // from succeeding.
+ commitWorkStream.halfClose();
+ assertNull(streamInfo.onDone.get());
+
+ // Send back that 1 and 5 finished.
+ streamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(1).addRequestId(5).build());
+ firstResponsesDone.await();
+
+ // Simulate that the server breaks.
+ streamInfo.responseObserver.onError(new IOException("test error"));
+
+ // The stream should reconnect and retry the requests.
+ FakeWindmillGrpcService.CommitStreamInfo reconnectStreamInfo =
+ waitForConnectionAndConsumeHeader();
+
+ // We don't expect any more streams since we finish successfully below.
+ fakeService.expectNoMoreStreams();
+
+ Windmill.StreamingCommitWorkRequest reconnectRequest = reconnectStreamInfo.requests.take();
+ assertEquals(3, reconnectRequest.getCommitChunkCount());
+ for (int i = 0; i < 3; ++i) {
+ assertEquals(i + 2, reconnectRequest.getCommitChunk(i).getRequestId());
+ Windmill.WorkItemCommitRequest parsedRequest =
+ Windmill.WorkItemCommitRequest.parseFrom(
+ reconnectRequest.getCommitChunk(i).getSerializedWorkItemCommit());
+ assertEquals(i + 1, parsedRequest.getWorkToken());
+ }
+ assertNull(streamInfo.onDone.get());
+
+ // Send back that 2 and 3 finished and then 4 finishes.
+ reconnectStreamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(2).addRequestId(3).build());
+ reconnectStreamInfo.responseObserver.onNext(
+ Windmill.StreamingCommitResponse.newBuilder().addRequestId(4).build());
+ // The half-close completes
+ reconnectStreamInfo.responseObserver.onCompleted();
+ secondResponsesDone.await();
+
+ assertThat(reconnectStreamInfo.requests).isEmpty();
+ assertThat(streamInfo.requests).isEmpty();
+ assertTrue(allOk.get());
+ }
+
+ @Test
+ public void testSend_notCalledAfterShutdown_Single()
+ throws ExecutionException, InterruptedException {
+ int numCommits = 1;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
+
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ assertTrue(
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(0),
+ commitStatus -> {
+ errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED));
+ errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L));
+ commitProcessed.countDown();
+ }));
// Shutdown the stream before we exit the try-with-resources block which will try to send()
// the batched request.
commitWorkStream.shutdown();
}
+ commitProcessed.await();
- // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends
- // the header, which happens before we shutdown.
- requestObserverVerifier
- .verify(requestObserver)
- .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER)));
- requestObserverVerifier.verify(requestObserver).onError(any());
- requestObserverVerifier.verifyNoMoreInteractions();
+ assertNotNull(streamInfo.onDone.get());
+ assertThat(streamInfo.requests).isEmpty();
}
- private static class TestCommitWorkStreamRequestObserver
- implements StreamObserver {
- private @Nullable StreamObserver responseObserver;
-
- @Override
- public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {}
-
- @Override
- public void onError(Throwable throwable) {}
+ @Test
+ public void testSend_notCalledAfterShutdown_Batch()
+ throws ExecutionException, InterruptedException {
+ int numCommits = 2;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
- @Override
- public void onCompleted() {
- if (responseObserver != null) {
- responseObserver.onCompleted();
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ assertTrue(
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ commitStatus -> {
+ errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED));
+ errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L));
+ commitProcessed.countDown();
+ }));
}
+ // Shutdown the stream before we exit the try-with-resources block which will try to send()
+ // the batched request.
+ commitWorkStream.shutdown();
}
+ commitProcessed.await();
+
+ assertNotNull(streamInfo.onDone.get());
+ assertThat(streamInfo.requests).isEmpty();
}
- private static class CommitWorkStreamTestStub
- extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
- private final TestCommitWorkStreamRequestObserver requestObserver;
- private @Nullable StreamObserver responseObserver;
+ @Test
+ public void testSend_notCalledAfterShutdown_Multichunk()
+ throws ExecutionException, InterruptedException {
+ int numCommits = 1;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream();
+ FakeWindmillGrpcService.CommitStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
- private CommitWorkStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) {
- this.requestObserver = requestObserver;
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ assertTrue(
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(0)
+ .toBuilder()
+ .addBagUpdates(Windmill.TagBag.newBuilder().setTag(LARGE_BYTE_STRING).build())
+ .build(),
+ commitStatus -> {
+ errorCollector.checkThat(commitStatus, equalTo(Windmill.CommitStatus.ABORTED));
+ errorCollector.checkThat(commitProcessed.getCount(), greaterThan(0L));
+ commitProcessed.countDown();
+ }));
+ // Shutdown the stream before we exit the try-with-resources block which will try to send()
+ // the batched request.
+ commitWorkStream.shutdown();
}
+ commitProcessed.await();
+ assertNotNull(streamInfo.onDone.get());
+ assertThat(streamInfo.requests).isEmpty();
+ }
- @Override
- public StreamObserver commitWorkStream(
- StreamObserver responseObserver) {
- if (this.responseObserver == null) {
- ((ServerCallStreamObserver) responseObserver)
- .setOnCancelHandler(() -> {});
- this.responseObserver = responseObserver;
- requestObserver.responseObserver = this.responseObserver;
- }
-
- return requestObserver;
+ private FakeWindmillGrpcService.CommitStreamInfo waitForConnectionAndConsumeHeader() {
+ try {
+ FakeWindmillGrpcService.CommitStreamInfo info = fakeService.waitForConnectedCommitStream();
+ Windmill.StreamingCommitWorkRequest request = info.requests.take();
+ errorCollector.checkThat(request.getHeader(), is(TEST_JOB_HEADER));
+ assertEquals(0, request.getCommitChunkCount());
+ return info;
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
}
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
index 9b0cc006464f..1014242317de 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
@@ -81,7 +81,11 @@ public class GrpcDirectGetWorkStreamTest {
private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest";
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+
+ @Rule
+ public transient Timeout globalTimeout =
+ Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build();
+
private ManagedChannel inProcessChannel;
private GrpcDirectGetWorkStream stream;
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java
index 8a23b4f51b5a..150db4ed4815 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java
@@ -36,6 +36,8 @@
@RunWith(JUnit4.class)
public class GrpcGetDataStreamRequestsTest {
+ private static final int DEADLINE_SECONDS = 10;
+
@Test
public void testQueuedRequest_globalRequestsFirstComparator() {
List requests = new ArrayList<>();
@@ -49,7 +51,7 @@ public void testQueuedRequest_globalRequestsFirstComparator() {
.build();
requests.add(
GrpcGetDataStreamRequests.QueuedRequest.forComputation(
- 1, "computation1", keyedGetDataRequest1));
+ 1, "computation1", keyedGetDataRequest1, DEADLINE_SECONDS));
Windmill.KeyedGetDataRequest keyedGetDataRequest2 =
Windmill.KeyedGetDataRequest.newBuilder()
@@ -61,7 +63,7 @@ public void testQueuedRequest_globalRequestsFirstComparator() {
.build();
requests.add(
GrpcGetDataStreamRequests.QueuedRequest.forComputation(
- 2, "computation2", keyedGetDataRequest2));
+ 2, "computation2", keyedGetDataRequest2, DEADLINE_SECONDS));
Windmill.GlobalDataRequest globalDataRequest =
Windmill.GlobalDataRequest.newBuilder()
@@ -72,7 +74,8 @@ public void testQueuedRequest_globalRequestsFirstComparator() {
.build())
.setComputationId("computation1")
.build();
- requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest));
+ requests.add(
+ GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest, DEADLINE_SECONDS));
requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst());
@@ -94,7 +97,7 @@ public void testQueuedBatch_asGetDataRequest() {
.build();
queuedBatch.addRequest(
GrpcGetDataStreamRequests.QueuedRequest.forComputation(
- 1, "computation1", keyedGetDataRequest1));
+ 1, "computation1", keyedGetDataRequest1, DEADLINE_SECONDS));
Windmill.KeyedGetDataRequest keyedGetDataRequest2 =
Windmill.KeyedGetDataRequest.newBuilder()
@@ -106,7 +109,7 @@ public void testQueuedBatch_asGetDataRequest() {
.build();
queuedBatch.addRequest(
GrpcGetDataStreamRequests.QueuedRequest.forComputation(
- 2, "computation2", keyedGetDataRequest2));
+ 2, "computation2", keyedGetDataRequest2, DEADLINE_SECONDS));
Windmill.GlobalDataRequest globalDataRequest =
Windmill.GlobalDataRequest.newBuilder()
@@ -117,7 +120,8 @@ public void testQueuedBatch_asGetDataRequest() {
.build())
.setComputationId("computation1")
.build();
- queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest));
+ queuedBatch.addRequest(
+ GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest, DEADLINE_SECONDS));
Windmill.StreamingGetDataRequest getDataRequest = queuedBatch.asGetDataRequest();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java
index 4cf21c2adfdf..e954f2cc7105 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java
@@ -18,9 +18,11 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static com.google.common.truth.Truth.assertThat;
-import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import java.io.IOException;
import java.util.List;
@@ -32,7 +34,6 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
@@ -41,16 +42,13 @@
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.ServerCallStreamObserver;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.testing.GrpcCleanupRule;
-import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.util.MutableHandlerRegistry;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
+import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
-import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ErrorCollector;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -65,34 +63,38 @@ public class GrpcGetDataStreamTest {
.setProjectId("test_project")
.build();
+ @Rule public final ErrorCollector errorCollector = new ErrorCollector();
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
- private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+
+ @Rule
+ public transient Timeout globalTimeout =
+ Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build();
+
+ private final FakeWindmillGrpcService fakeService = new FakeWindmillGrpcService(errorCollector);
private ManagedChannel inProcessChannel;
+ private Server inProcessServer;
@Before
public void setUp() throws IOException {
- Server server =
- InProcessServerBuilder.forName(FAKE_SERVER_NAME)
- .fallbackHandlerRegistry(serviceRegistry)
- .directExecutor()
- .build()
- .start();
-
+ inProcessServer =
+ grpcCleanup.register(
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .addService(fakeService)
+ .directExecutor()
+ .build()
+ .start());
inProcessChannel =
grpcCleanup.register(
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
- grpcCleanup.register(server);
- grpcCleanup.register(inProcessChannel);
}
@After
public void cleanUp() {
+ inProcessServer.shutdownNow();
inProcessChannel.shutdownNow();
}
- private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) {
- serviceRegistry.addService(testStub);
+ private GrpcGetDataStream createGetDataStream() {
GrpcGetDataStream getDataStream =
(GrpcGetDataStream)
GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
@@ -104,53 +106,51 @@ private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) {
}
@Test
- public void testRequestKeyedData() {
- GetDataStreamTestStub testStub =
- new GetDataStreamTestStub(new TestGetDataStreamRequestObserver());
- GrpcGetDataStream getDataStream = createGetDataStream(testStub);
+ public void testRequestKeyedData() throws InterruptedException {
+ GrpcGetDataStream getDataStream = createGetDataStream();
+ FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
+
// These will block until they are successfully sent.
+ Windmill.KeyedGetDataRequest keyedGetDataRequest =
+ Windmill.KeyedGetDataRequest.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(1)
+ .setCacheToken(1)
+ .setWorkToken(1)
+ .build();
+
CompletableFuture sendFuture =
CompletableFuture.supplyAsync(
() -> {
try {
- return getDataStream.requestKeyedData(
- "computationId",
- Windmill.KeyedGetDataRequest.newBuilder()
- .setKey(ByteString.EMPTY)
- .setShardingKey(1)
- .setCacheToken(1)
- .setWorkToken(1)
- .build());
+ return getDataStream.requestKeyedData("computationId", keyedGetDataRequest);
} catch (WindmillStreamShutdownException e) {
throw new RuntimeException(e);
}
});
- // Sleep a bit to allow future to run.
- Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
+ Windmill.StreamingGetDataRequest request = streamInfo.requests.take();
+ assertThat(request.getRequestIdList()).containsExactly(1L);
+ assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0));
- Windmill.KeyedGetDataResponse response =
+ Windmill.KeyedGetDataResponse keyedGetDataResponse =
Windmill.KeyedGetDataResponse.newBuilder()
.setShardingKey(1)
.setKey(ByteString.EMPTY)
.build();
- testStub.injectResponse(
+ streamInfo.responseObserver.onNext(
Windmill.StreamingGetDataResponse.newBuilder()
.addRequestId(1)
- .addSerializedResponse(response.toByteString())
- .setRemainingBytesForResponse(0)
+ .addSerializedResponse(keyedGetDataResponse.toByteString())
.build());
- assertThat(sendFuture.join()).isEqualTo(response);
+ assertThat(sendFuture.join()).isEqualTo(keyedGetDataResponse);
}
@Test
- @Ignore("https://github.com/apache/beam/issues/28957")
public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() {
- GetDataStreamTestStub testStub =
- new GetDataStreamTestStub(new TestGetDataStreamRequestObserver());
- GrpcGetDataStream getDataStream = createGetDataStream(testStub);
+ GrpcGetDataStream getDataStream = createGetDataStream();
int numSendThreads = 5;
ExecutorService getDataStreamSenders = Executors.newFixedThreadPool(numSendThreads);
CountDownLatch waitForSendAttempt = new CountDownLatch(1);
@@ -210,48 +210,106 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow
}
}
- private static class TestGetDataStreamRequestObserver
-      implements StreamObserver<Windmill.StreamingGetDataRequest> {
-    private @Nullable StreamObserver<Windmill.StreamingGetDataResponse> responseObserver;
+ @Test
+ public void testRequestKeyedData_reconnectOnStreamError() throws InterruptedException {
+ GrpcGetDataStream getDataStream = createGetDataStream();
+ FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
+
+ // These will block until they are successfully sent.
+ Windmill.KeyedGetDataRequest keyedGetDataRequest =
+ Windmill.KeyedGetDataRequest.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(1)
+ .setCacheToken(1)
+ .setWorkToken(1)
+ .build();
+
+    CompletableFuture<Windmill.KeyedGetDataResponse> sendFuture =
+ CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ return getDataStream.requestKeyedData("computationId", keyedGetDataRequest);
+ } catch (WindmillStreamShutdownException e) {
+ throw new RuntimeException(e);
+ }
+ });
- @Override
- public void onNext(Windmill.StreamingGetDataRequest streamingGetDataRequest) {}
+ Windmill.StreamingGetDataRequest request = streamInfo.requests.take();
+ assertThat(request.getRequestIdList()).containsExactly(1L);
+ assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0));
- @Override
- public void onError(Throwable throwable) {}
+ // Simulate an error on the grpc stream, this should trigger a retry of the request internal to
+ // the stream.
+ streamInfo.responseObserver.onError(new IOException("test error"));
- @Override
- public void onCompleted() {
- if (responseObserver != null) {
- responseObserver.onCompleted();
+ streamInfo = waitForConnectionAndConsumeHeader();
+ while (true) {
+ request = streamInfo.requests.poll(5, TimeUnit.SECONDS);
+ if (request != null) break;
+ if (sendFuture.isDone()) {
+ fail("Unexpected send completion " + sendFuture);
}
}
+ assertThat(request.getRequestIdList()).containsExactly(1L);
+ assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0));
+
+ getDataStream.shutdown();
+ assertThrows(RuntimeException.class, sendFuture::join);
}
- private static class GetDataStreamTestStub
- extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
- private final TestGetDataStreamRequestObserver requestObserver;
-    private @Nullable StreamObserver<Windmill.StreamingGetDataResponse> responseObserver;
+ @Test
+ public void testRequestKeyedData_reconnectOnStreamErrorAfterHalfClose()
+ throws InterruptedException, ExecutionException {
+ GrpcGetDataStream getDataStream = createGetDataStream();
+ FakeWindmillGrpcService.GetDataStreamInfo streamInfo = waitForConnectionAndConsumeHeader();
- private GetDataStreamTestStub(TestGetDataStreamRequestObserver requestObserver) {
- this.requestObserver = requestObserver;
- }
+ // These will block until they are successfully sent.
+ Windmill.KeyedGetDataRequest keyedGetDataRequest =
+ Windmill.KeyedGetDataRequest.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(1)
+ .setCacheToken(1)
+ .setWorkToken(1)
+ .build();
- @Override
-    public StreamObserver<Windmill.StreamingGetDataRequest> getDataStream(
-        StreamObserver<Windmill.StreamingGetDataResponse> responseObserver) {
- if (this.responseObserver == null) {
-        ((ServerCallStreamObserver<Windmill.StreamingGetDataResponse>) responseObserver)
- .setOnCancelHandler(() -> {});
- this.responseObserver = responseObserver;
- requestObserver.responseObserver = this.responseObserver;
- }
+    CompletableFuture<Windmill.KeyedGetDataResponse> sendFuture =
+ CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ return getDataStream.requestKeyedData("computationId", keyedGetDataRequest);
+ } catch (WindmillStreamShutdownException e) {
+ throw new RuntimeException(e);
+ }
+ });
- return requestObserver;
- }
+ Windmill.StreamingGetDataRequest request = streamInfo.requests.take();
+ assertThat(request.getRequestIdList()).containsExactly(1L);
+ assertEquals(keyedGetDataRequest, request.getStateRequest(0).getRequests(0));
+
+ // Close the stream.
+ getDataStream.halfClose();
+ assertNull(streamInfo.onDone.get());
+
+ // Simulate an error on the grpc stream, this should trigger an error on all
+ // existing requests but no new connection since we half-closed and nothing left after
+ // responding with errors.
+ fakeService.expectNoMoreStreams();
+ streamInfo.responseObserver.onError(new IOException("test error"));
+ assertThrows(RuntimeException.class, sendFuture::join);
+
+ getDataStream.shutdown();
+ }
- private void injectResponse(Windmill.StreamingGetDataResponse getDataResponse) {
- checkNotNull(responseObserver).onNext(getDataResponse);
+ private FakeWindmillGrpcService.GetDataStreamInfo waitForConnectionAndConsumeHeader() {
+ try {
+ FakeWindmillGrpcService.GetDataStreamInfo info = fakeService.waitForConnectedGetDataStream();
+ Windmill.StreamingGetDataRequest request = info.requests.take();
+ errorCollector.checkThat(request.getHeader(), Matchers.is(TEST_JOB_HEADER));
+ assertEquals(0, request.getRequestIdCount());
+ assertEquals(0, request.getComputationHeartbeatRequestCount());
+ return info;
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
}
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
index 9a9f93d5c6c9..3c3e2a6579f6 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java
@@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -237,7 +236,8 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
}
@Test
- public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() {
+ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry()
+ throws InterruptedException {
GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer());
@@ -250,7 +250,9 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() {
assertTrue(streamFactory.streamRegistry().contains(stream));
stream.halfClose();
- assertFalse(streamFactory.streamRegistry().contains(stream));
+ while (streamFactory.streamRegistry().contains(stream)) {
+ Thread.sleep(100);
+ }
}
@Test
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
index 899d99b50352..e52b6e8de4bf 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
@@ -117,7 +117,11 @@ public class GrpcWindmillServerTest {
private final long clientId = 10L;
private final Set openedChannels = new HashSet<>();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
- @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+
+ @Rule
+ public transient Timeout globalTimeout =
+ Timeout.builder().withTimeout(10, TimeUnit.MINUTES).withLookingForStuckThread(true).build();
+
@Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@Rule public ErrorCollector errorCollector = new ErrorCollector();
private Server server;
@@ -482,7 +486,6 @@ public void onCompleted() {
}
@Test
- @SuppressWarnings("FutureReturnValueIgnored")
public void testStreamingGetData() throws Exception {
// This server responds to GetDataRequests with responses that mirror the requests.
serviceRegistry.addService(
@@ -623,7 +626,7 @@ private void flushResponse() {
for (int i = 0; i < 100; ++i) {
final String key = "key" + i;
final String s = i % 5 == 0 ? largeString(i) : "tag";
- executor.submit(
+ executor.execute(
() -> {
try {
errorCollector.checkThat(