Skip to content

Commit 6be76ae

Browse files
authored
[Dataflow Streaming] Enforce that get data requests for the same work item are not batched. (#36474)
Such batches are rejected by the backend so we should prevent them to avoid stuckness. Parallel requests for the same work item are unexpected but could be caused by bugs in the harness or by incorrect parallel state fetching from a single bundle.
1 parent a584688 commit 6be76ae

File tree

4 files changed

+316
-124
lines changed

4 files changed

+316
-124
lines changed

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ class GetDataPhysicalStreamHandler extends PhysicalStreamHandler {
193193
public void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException {
194194
// Synchronization of pending inserts is necessary with send to ensure duplicates are not
195195
// sent on stream reconnect.
196-
for (QueuedRequest request : batch.requestsReadOnly()) {
196+
for (QueuedRequest request : batch.requestsView()) {
197197
boolean alreadyPresent = pending.put(request.id(), request.getResponseStream()) != null;
198198
verify(!alreadyPresent, "Request already sent, id: %s", request.id());
199199
}
@@ -277,7 +277,7 @@ protected synchronized void onFlushPending(boolean isNewStream)
277277
}
278278
while (!batches.isEmpty()) {
279279
QueuedBatch batch = checkNotNull(batches.peekFirst());
280-
verify(!batch.isEmpty());
280+
verify(batch.requestsCount() > 0);
281281
if (!batch.isFinalized()) {
282282
break;
283283
}
@@ -482,17 +482,15 @@ private void queueRequestAndWait(QueuedRequest request)
482482

483483
batch = batches.isEmpty() ? null : batches.getLast();
484484
if (batch == null
485-
|| batch.isFinalized()
486-
|| batch.requestsCount() >= streamingRpcBatchLimit
487-
|| batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
488-
if (batch != null) {
489-
prevBatch = batch;
490-
}
485+
|| !batch.tryAddRequest(
486+
request, streamingRpcBatchLimit, AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE)) {
487+
// We need a new batch.
488+
prevBatch = batch; // may be null
491489
batch = new QueuedBatch();
492490
batches.addLast(batch);
493491
responsibleForSend = true;
492+
verify(batch.tryAddRequest(request, Integer.MAX_VALUE, Long.MAX_VALUE));
494493
}
495-
batch.addRequest(request);
496494
}
497495
if (responsibleForSend) {
498496
if (prevBatch == null) {
@@ -532,7 +530,7 @@ private synchronized void trySendBatch(QueuedBatch batch) throws WindmillStreamS
532530
// an error and will
533531
// resend requests (possibly with new batching).
534532
verify(batch == batches.pollFirst());
535-
verify(!batch.isEmpty());
533+
verify(batch.requestsCount() > 0);
536534
currentGetDataPhysicalStream.sendBatch(batch);
537535
// Notify all waiters with requests in this batch as well as the sender
538536
// of the next batch (if one exists).

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java

Lines changed: 107 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@
1919

2020
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
2121

22-
import com.google.auto.value.AutoOneOf;
2322
import java.util.ArrayList;
2423
import java.util.Collections;
2524
import java.util.Comparator;
25+
import java.util.HashSet;
2626
import java.util.List;
2727
import java.util.concurrent.CountDownLatch;
28-
import java.util.stream.Stream;
28+
import javax.annotation.Nullable;
2929
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
30-
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
3130
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
3231
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
3332
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
33+
import org.apache.beam.sdk.util.Preconditions;
3434
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
3535
import org.slf4j.Logger;
3636
import org.slf4j.LoggerFactory;
@@ -46,15 +46,42 @@ private static String debugFormat(long value) {
4646
return String.format("%016x", value);
4747
}
4848

49+
static class ComputationAndKeyRequest {
50+
private final String computation;
51+
private final KeyedGetDataRequest request;
52+
53+
ComputationAndKeyRequest(String computation, KeyedGetDataRequest request) {
54+
this.computation = computation;
55+
this.request = request;
56+
}
57+
58+
String getComputation() {
59+
return computation;
60+
}
61+
62+
KeyedGetDataRequest getKeyedGetDataRequest() {
63+
return request;
64+
}
65+
}
66+
4967
static class QueuedRequest {
5068
private final long id;
51-
private final ComputationOrGlobalDataRequest dataRequest;
69+
private final @Nullable ComputationAndKeyRequest computationAndKeyRequest;
70+
private final @Nullable GlobalDataRequest globalDataRequest;
5271
private AppendableInputStream responseStream;
5372

73+
private QueuedRequest(long id, GlobalDataRequest globalDataRequest, long deadlineSeconds) {
74+
this.id = id;
75+
this.computationAndKeyRequest = null;
76+
this.globalDataRequest = globalDataRequest;
77+
responseStream = new AppendableInputStream(deadlineSeconds);
78+
}
79+
5480
private QueuedRequest(
55-
long id, ComputationOrGlobalDataRequest dataRequest, long deadlineSeconds) {
81+
long id, ComputationAndKeyRequest computationAndKeyRequest, long deadlineSeconds) {
5682
this.id = id;
57-
this.dataRequest = dataRequest;
83+
this.computationAndKeyRequest = computationAndKeyRequest;
84+
this.globalDataRequest = null;
5885
responseStream = new AppendableInputStream(deadlineSeconds);
5986
}
6087

@@ -63,27 +90,19 @@ static QueuedRequest forComputation(
6390
String computation,
6491
KeyedGetDataRequest keyedGetDataRequest,
6592
long deadlineSeconds) {
66-
ComputationGetDataRequest computationGetDataRequest =
67-
ComputationGetDataRequest.newBuilder()
68-
.setComputationId(computation)
69-
.addRequests(keyedGetDataRequest)
70-
.build();
7193
return new QueuedRequest(
72-
id,
73-
ComputationOrGlobalDataRequest.computation(computationGetDataRequest),
74-
deadlineSeconds);
94+
id, new ComputationAndKeyRequest(computation, keyedGetDataRequest), deadlineSeconds);
7595
}
7696

7797
static QueuedRequest global(
7898
long id, GlobalDataRequest globalDataRequest, long deadlineSeconds) {
79-
return new QueuedRequest(
80-
id, ComputationOrGlobalDataRequest.global(globalDataRequest), deadlineSeconds);
99+
return new QueuedRequest(id, globalDataRequest, deadlineSeconds);
81100
}
82101

83102
static Comparator<QueuedRequest> globalRequestsFirst() {
84103
return (QueuedRequest r1, QueuedRequest r2) -> {
85-
boolean r1gd = r1.dataRequest.isGlobal();
86-
boolean r2gd = r2.dataRequest.isGlobal();
104+
boolean r1gd = r1.getKind() == Kind.GLOBAL;
105+
boolean r2gd = r2.getKind() == Kind.GLOBAL;
87106
return r1gd == r2gd ? 0 : (r1gd ? -1 : 1);
88107
};
89108
}
@@ -93,7 +112,13 @@ long id() {
93112
}
94113

95114
long byteSize() {
96-
return dataRequest.serializedSize();
115+
if (globalDataRequest != null) {
116+
return globalDataRequest.getSerializedSize();
117+
}
118+
Preconditions.checkStateNotNull(computationAndKeyRequest);
119+
return 10L
120+
+ computationAndKeyRequest.request.getSerializedSize()
121+
+ computationAndKeyRequest.getComputation().length();
97122
}
98123

99124
AppendableInputStream getResponseStream() {
@@ -104,22 +129,56 @@ void resetResponseStream() {
104129
this.responseStream = new AppendableInputStream(responseStream.getDeadlineSeconds());
105130
}
106131

107-
public ComputationOrGlobalDataRequest getDataRequest() {
108-
return dataRequest;
132+
enum Kind {
133+
COMPUTATION_AND_KEY_REQUEST,
134+
GLOBAL
135+
}
136+
137+
Kind getKind() {
138+
return computationAndKeyRequest != null ? Kind.COMPUTATION_AND_KEY_REQUEST : Kind.GLOBAL;
139+
}
140+
141+
ComputationAndKeyRequest getComputationAndKeyRequest() {
142+
return Preconditions.checkStateNotNull(computationAndKeyRequest);
143+
}
144+
145+
GlobalDataRequest getGlobalDataRequest() {
146+
return Preconditions.checkStateNotNull(globalDataRequest);
109147
}
110148

111149
void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder builder) {
112150
builder.addRequestId(id);
113-
if (dataRequest.isForComputation()) {
114-
builder.addStateRequest(dataRequest.computation());
115-
} else {
116-
builder.addGlobalDataRequest(dataRequest.global());
151+
switch (getKind()) {
152+
case COMPUTATION_AND_KEY_REQUEST:
153+
ComputationAndKeyRequest request = getComputationAndKeyRequest();
154+
builder
155+
.addStateRequestBuilder()
156+
.setComputationId(request.getComputation())
157+
.addRequests(request.request);
158+
break;
159+
case GLOBAL:
160+
builder.addGlobalDataRequest(getGlobalDataRequest());
161+
break;
117162
}
118163
}
119164

120165
@Override
121166
public final String toString() {
122-
return "QueuedRequest{" + "dataRequest=" + dataRequest + ", id=" + id + '}';
167+
StringBuilder result = new StringBuilder("QueuedRequest{id=").append(id).append(", ");
168+
if (getKind() == Kind.GLOBAL) {
169+
result.append("GetSideInput=").append(getGlobalDataRequest());
170+
} else {
171+
KeyedGetDataRequest key = getComputationAndKeyRequest().request;
172+
result
173+
.append("KeyedGetState=[shardingKey=")
174+
.append(debugFormat(key.getShardingKey()))
175+
.append("cacheToken=")
176+
.append(debugFormat(key.getCacheToken()))
177+
.append("workToken")
178+
.append(debugFormat(key.getWorkToken()))
179+
.append("]");
180+
}
181+
return result.append('}').toString();
123182
}
124183
}
125184

@@ -128,13 +187,14 @@ public final String toString() {
128187
*/
129188
static class QueuedBatch {
130189
private final List<QueuedRequest> requests = new ArrayList<>();
190+
private final HashSet<Long> workTokens = new HashSet<>();
131191
private final CountDownLatch sent = new CountDownLatch(1);
132192
private long byteSize = 0;
133193
private volatile boolean finalized = false;
134194
private volatile boolean failed = false;
135195

136196
/** Returns a read-only view of requests. */
137-
List<QueuedRequest> requestsReadOnly() {
197+
List<QueuedRequest> requestsView() {
138198
return Collections.unmodifiableList(requests);
139199
}
140200

@@ -155,18 +215,10 @@ Windmill.StreamingGetDataRequest asGetDataRequest() {
155215
return builder.build();
156216
}
157217

158-
boolean isEmpty() {
159-
return requests.isEmpty();
160-
}
161-
162218
int requestsCount() {
163219
return requests.size();
164220
}
165221

166-
long byteSize() {
167-
return byteSize;
168-
}
169-
170222
boolean isFinalized() {
171223
return finalized;
172224
}
@@ -176,9 +228,26 @@ void markFinalized() {
176228
}
177229

178230
/** Adds a request to the batch. */
179-
void addRequest(QueuedRequest request) {
231+
boolean tryAddRequest(QueuedRequest request, int countLimit, long byteLimit) {
232+
if (finalized) {
233+
return false;
234+
}
235+
if (requests.size() >= countLimit) {
236+
return false;
237+
}
238+
long estimatedBytes = request.byteSize();
239+
if (byteSize + estimatedBytes >= byteLimit) {
240+
return false;
241+
}
242+
243+
if (request.getKind() == QueuedRequest.Kind.COMPUTATION_AND_KEY_REQUEST
244+
&& !workTokens.add(request.getComputationAndKeyRequest().request.getWorkToken())) {
245+
return false;
246+
}
247+
// At this point we have added to work items so we must accept the item.
180248
requests.add(request);
181-
byteSize += request.byteSize();
249+
byteSize += estimatedBytes;
250+
return true;
182251
}
183252

184253
/**
@@ -227,75 +296,9 @@ void waitForSendOrFailNotification()
227296

228297
private ImmutableList<String> createStreamCancelledErrorMessages() {
229298
return requests.stream()
230-
.flatMap(
231-
request -> {
232-
switch (request.getDataRequest().getKind()) {
233-
case GLOBAL:
234-
return Stream.of("GetSideInput=" + request.getDataRequest().global());
235-
case COMPUTATION:
236-
return request.getDataRequest().computation().getRequestsList().stream()
237-
.map(
238-
keyedRequest ->
239-
"KeyedGetState=["
240-
+ "shardingKey="
241-
+ debugFormat(keyedRequest.getShardingKey())
242-
+ "cacheToken="
243-
+ debugFormat(keyedRequest.getCacheToken())
244-
+ "workToken"
245-
+ debugFormat(keyedRequest.getWorkToken())
246-
+ "]");
247-
default:
248-
// Will never happen switch is exhaustive.
249-
throw new IllegalStateException();
250-
}
251-
})
299+
.map(QueuedRequest::toString)
252300
.limit(STREAM_CANCELLED_ERROR_LOG_LIMIT)
253301
.collect(toImmutableList());
254302
}
255303
}
256-
257-
@AutoOneOf(ComputationOrGlobalDataRequest.Kind.class)
258-
abstract static class ComputationOrGlobalDataRequest {
259-
static ComputationOrGlobalDataRequest computation(
260-
ComputationGetDataRequest computationGetDataRequest) {
261-
return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest.computation(
262-
computationGetDataRequest);
263-
}
264-
265-
static ComputationOrGlobalDataRequest global(GlobalDataRequest globalDataRequest) {
266-
return AutoOneOf_GrpcGetDataStreamRequests_ComputationOrGlobalDataRequest.global(
267-
globalDataRequest);
268-
}
269-
270-
abstract Kind getKind();
271-
272-
abstract ComputationGetDataRequest computation();
273-
274-
abstract GlobalDataRequest global();
275-
276-
boolean isGlobal() {
277-
return getKind() == Kind.GLOBAL;
278-
}
279-
280-
boolean isForComputation() {
281-
return getKind() == Kind.COMPUTATION;
282-
}
283-
284-
long serializedSize() {
285-
switch (getKind()) {
286-
case GLOBAL:
287-
return global().getSerializedSize();
288-
case COMPUTATION:
289-
return computation().getSerializedSize();
290-
// this will never happen since the switch is exhaustive.
291-
default:
292-
throw new UnsupportedOperationException("unknown dataRequest type.");
293-
}
294-
}
295-
296-
enum Kind {
297-
COMPUTATION,
298-
GLOBAL
299-
}
300-
}
301304
}

0 commit comments

Comments
 (0)