Skip to content

Commit 7bec450

Browse files
authored
[Dataflow Streaming] Optimize failed key processing by indexing workitems by sharding key (#33755)
1 parent 9064743 commit 7bec450

File tree

6 files changed

+319
-165
lines changed

6 files changed

+319
-165
lines changed

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java

Lines changed: 60 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
package org.apache.beam.runners.dataflow.worker.streaming;
1919

2020
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
21-
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap;
2221

2322
import java.io.PrintWriter;
24-
import java.util.ArrayDeque;
25-
import java.util.Collection;
26-
import java.util.Deque;
2723
import java.util.HashMap;
2824
import java.util.Iterator;
25+
import java.util.LinkedHashMap;
2926
import java.util.Map;
3027
import java.util.Map.Entry;
3128
import java.util.Optional;
@@ -36,14 +33,13 @@
3633
import javax.annotation.concurrent.ThreadSafe;
3734
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
3835
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
36+
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation;
3937
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
4038
import org.apache.beam.sdk.annotations.Internal;
4139
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
4240
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
4341
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
44-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
4542
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
46-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
4743
import org.joda.time.Duration;
4844
import org.joda.time.Instant;
4945
import org.slf4j.Logger;
@@ -63,11 +59,11 @@ public final class ActiveWorkState {
6359
private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50;
6460

6561
/**
66-
* Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
67-
* Queue<Work>} is actively processing.
62+
* Map from shardingKey to {@link Work} for the key. The first item in the {@link LinkedHashMap}
63+
* is actively processing.
6864
*/
6965
@GuardedBy("this")
70-
private final Map<ShardedKey, Deque<ExecutableWork>> activeWork;
66+
private final Map<Long /*shardingKey*/, LinkedHashMap<WorkId, ExecutableWork>> activeWork;
7167

7268
@GuardedBy("this")
7369
private final WindmillStateCache.ForComputation computationStateCache;
@@ -81,8 +77,8 @@ public final class ActiveWorkState {
8177
private GetWorkBudget activeGetWorkBudget;
8278

8379
private ActiveWorkState(
84-
Map<ShardedKey, Deque<ExecutableWork>> activeWork,
85-
WindmillStateCache.ForComputation computationStateCache) {
80+
Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
81+
ForComputation computationStateCache) {
8682
this.activeWork = activeWork;
8783
this.computationStateCache = computationStateCache;
8884
this.activeGetWorkBudget = GetWorkBudget.noBudget();
@@ -94,7 +90,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState
9490

9591
@VisibleForTesting
9692
static ActiveWorkState forTesting(
97-
Map<ShardedKey, Deque<ExecutableWork>> activeWork,
93+
Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
9894
WindmillStateCache.ForComputation computationStateCache) {
9995
return new ActiveWorkState(activeWork, computationStateCache);
10096
}
@@ -124,28 +120,30 @@ private static String elapsedString(Instant start, Instant end) {
124120
*/
125121
synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork) {
126122
ShardedKey shardedKey = executableWork.work().getShardedKey();
127-
Deque<ExecutableWork> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
123+
long shardingKey = shardedKey.shardingKey();
124+
LinkedHashMap<WorkId, ExecutableWork> workQueue =
125+
activeWork.computeIfAbsent(shardingKey, (unused) -> new LinkedHashMap<>());
128126
// This key does not have any work queued up on it. Create one, insert Work, and mark the work
129127
// to be executed.
130-
if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
131-
workQueue.addLast(executableWork);
132-
activeWork.put(shardedKey, workQueue);
128+
if (workQueue.isEmpty()) {
129+
workQueue.put(executableWork.id(), executableWork);
133130
incrementActiveWorkBudget(executableWork.work());
134131
return ActivateWorkResult.EXECUTE;
135132
}
136133

137134
// Check to see if we have this work token queued.
138-
Iterator<ExecutableWork> workIterator = workQueue.iterator();
135+
Iterator<Entry<WorkId, ExecutableWork>> workIterator = workQueue.entrySet().iterator();
139136
while (workIterator.hasNext()) {
140-
ExecutableWork queuedWork = workIterator.next();
137+
ExecutableWork queuedWork = workIterator.next().getValue();
141138
if (queuedWork.id().equals(executableWork.id())) {
142139
return ActivateWorkResult.DUPLICATE;
143140
}
144-
if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()) {
141+
if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()
142+
&& queuedWork.work().getShardedKey().equals(executableWork.work().getShardedKey())) {
145143
if (executableWork.id().workToken() > queuedWork.id().workToken()) {
146144
// Check to see if the queuedWork is active. We only want to remove it if it is NOT
147145
// currently active.
148-
if (!queuedWork.equals(workQueue.peek())) {
146+
if (!queuedWork.equals(Preconditions.checkNotNull(firstValue(workQueue)))) {
149147
workIterator.remove();
150148
decrementActiveWorkBudget(queuedWork.work());
151149
}
@@ -157,7 +155,7 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
157155
}
158156

159157
// Queue the work for later processing.
160-
workQueue.addLast(executableWork);
158+
workQueue.put(executableWork.id(), executableWork);
161159
incrementActiveWorkBudget(executableWork.work());
162160
return ActivateWorkResult.QUEUED;
163161
}
@@ -167,54 +165,29 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
167165
*
168166
* @param failedWork a map from sharding_key to tokens for the corresponding work.
169167
*/
170-
synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
171-
// Note we can't construct a ShardedKey and look it up in activeWork directly since
172-
// HeartbeatResponse doesn't include the user key.
173-
for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
174-
Collection<WorkId> failedWorkIds = failedWork.get(entry.getKey().shardingKey());
175-
for (WorkId failedWorkId : failedWorkIds) {
176-
for (ExecutableWork queuedWork : entry.getValue()) {
177-
WorkItem workItem = queuedWork.work().getWorkItem();
178-
if (workItem.getWorkToken() == failedWorkId.workToken()
179-
&& workItem.getCacheToken() == failedWorkId.cacheToken()) {
180-
LOG.debug(
181-
"Failing work "
182-
+ computationStateCache.getComputation()
183-
+ " "
184-
+ entry.getKey().shardingKey()
185-
+ " "
186-
+ failedWorkId.workToken()
187-
+ " "
188-
+ failedWorkId.cacheToken()
189-
+ ". The work will be retried and is not lost.");
190-
queuedWork.work().setFailed();
191-
break;
192-
}
193-
}
168+
synchronized void failWorkForKey(ImmutableList<WorkIdWithShardingKey> failedWork) {
169+
for (WorkIdWithShardingKey failedId : failedWork) {
170+
@Nullable
171+
LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(failedId.shardingKey());
172+
if (workQueue == null) {
173+
// Work could complete/fail before heartbeat response arrives
174+
continue;
175+
}
176+
@Nullable ExecutableWork executableWork = workQueue.get(failedId.workId());
177+
if (executableWork == null) {
178+
continue;
194179
}
180+
executableWork.work().setFailed();
181+
LOG.debug(
182+
"Failing work {} {}. The work will be retried and is not lost.",
183+
computationStateCache.getComputation(),
184+
failedId);
195185
}
196186
}
197187

198-
/**
199-
* Returns a read only view of current active work.
200-
*
201-
* @implNote Do not return a reference to the underlying workQueue as iterations over it will
202-
* cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
203-
* structure.
204-
*/
205-
synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
206-
return activeWork.entrySet().stream()
207-
.collect(
208-
flatteningToImmutableListMultimap(
209-
Entry::getKey,
210-
e ->
211-
e.getValue().stream()
212-
.map(executableWork -> (RefreshableWork) executableWork.work())));
213-
}
214-
215188
synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
216189
return activeWork.values().stream()
217-
.flatMap(Deque::stream)
190+
.flatMap(workMap -> workMap.values().stream())
218191
.map(ExecutableWork::work)
219192
.filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
220193
.collect(toImmutableList());
@@ -236,7 +209,8 @@ private synchronized void decrementActiveWorkBudget(Work work) {
236209
*/
237210
synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
238211
ShardedKey shardedKey, WorkId workId) {
239-
@Nullable Queue<ExecutableWork> workQueue = activeWork.get(shardedKey);
212+
@Nullable
213+
LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(shardedKey.shardingKey());
240214
if (workQueue == null) {
241215
// Work may have been completed due to clearing of stuck commits.
242216
LOG.warn(
@@ -251,14 +225,15 @@ synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
251225
}
252226

253227
private synchronized void removeCompletedWorkFromQueue(
254-
Queue<ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
255-
@Nullable ExecutableWork completedWork = workQueue.peek();
256-
if (completedWork == null) {
228+
LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
229+
Iterator<Entry<WorkId, ExecutableWork>> completedWorkIterator = workQueue.entrySet().iterator();
230+
if (!completedWorkIterator.hasNext()) {
257231
// Work may have been completed due to clearing of stuck commits.
258232
LOG.warn("Active key {} without work, expected token {}", shardedKey, workId);
259233
return;
260234
}
261235

236+
ExecutableWork completedWork = completedWorkIterator.next().getValue();
262237
if (!completedWork.id().equals(workId)) {
263238
// Work may have been completed due to clearing of stuck commits.
264239
LOG.warn(
@@ -271,19 +246,18 @@ private synchronized void removeCompletedWorkFromQueue(
271246
completedWork.id());
272247
return;
273248
}
274-
275249
// We consumed the matching work item.
276-
workQueue.remove();
250+
completedWorkIterator.remove();
277251
decrementActiveWorkBudget(completedWork.work());
278252
}
279253

254+
@SuppressWarnings("ReferenceEquality")
280255
private synchronized Optional<ExecutableWork> getNextWork(
281-
Queue<ExecutableWork> workQueue, ShardedKey shardedKey) {
282-
Optional<ExecutableWork> nextWork = Optional.ofNullable(workQueue.peek());
256+
LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey) {
257+
Optional<ExecutableWork> nextWork = Optional.ofNullable(firstValue(workQueue));
283258
if (!nextWork.isPresent()) {
284-
Preconditions.checkState(workQueue == activeWork.remove(shardedKey));
259+
Preconditions.checkState(workQueue == activeWork.remove(shardedKey.shardingKey()));
285260
}
286-
287261
return nextWork;
288262
}
289263

@@ -302,22 +276,26 @@ synchronized void invalidateStuckCommits(
302276
}
303277
}
304278

279+
private static @Nullable ExecutableWork firstValue(Map<WorkId, ExecutableWork> map) {
280+
Iterator<Entry<WorkId, ExecutableWork>> iterator = map.entrySet().iterator();
281+
return iterator.hasNext() ? iterator.next().getValue() : null;
282+
}
283+
305284
private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
306285
Instant stuckCommitDeadline) {
307286
// Determine the stuck commit keys but complete them outside the loop iterating over
308287
// activeWork as completeWork may delete the entry from activeWork.
309288
ImmutableMap.Builder<ShardedKey, WorkId> stuckCommits = ImmutableMap.builder();
310-
for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
311-
ShardedKey shardedKey = entry.getKey();
312-
@Nullable ExecutableWork executableWork = entry.getValue().peek();
289+
for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
290+
@Nullable ExecutableWork executableWork = firstValue(entry.getValue());
313291
if (executableWork != null) {
314292
Work work = executableWork.work();
315293
if (work.isStuckCommittingAt(stuckCommitDeadline)) {
316294
LOG.error(
317295
"Detected key {} stuck in COMMITTING state since {}, completing it with error.",
318-
shardedKey,
296+
work.getShardedKey(),
319297
work.getStateStartTime());
320-
stuckCommits.put(shardedKey, work.id());
298+
stuckCommits.put(work.getShardedKey(), work.id());
321299
}
322300
}
323301
}
@@ -353,9 +331,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
353331
// Use StringBuilder because we are appending in loop.
354332
StringBuilder activeWorkStatus = new StringBuilder();
355333
int commitsPendingCount = 0;
356-
for (Map.Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
357-
Queue<ExecutableWork> workQueue = Preconditions.checkNotNull(entry.getValue());
358-
Work activeWork = Preconditions.checkNotNull(workQueue.peek()).work();
334+
for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
335+
LinkedHashMap<WorkId, ExecutableWork> workQueue =
336+
Preconditions.checkNotNull(entry.getValue());
337+
Work activeWork = Preconditions.checkNotNull(firstValue(workQueue)).work();
359338
WorkItem workItem = activeWork.getWorkItem();
360339
if (activeWork.isCommitPending()) {
361340
if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
3030
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
3131
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
32-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
3332
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
34-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
3533
import org.joda.time.Instant;
3634

3735
/**
@@ -120,7 +118,7 @@ public boolean activateWork(ExecutableWork executableWork) {
120118
}
121119
}
122120

123-
public void failWork(Multimap<Long, WorkId> failedWork) {
121+
public void failWork(ImmutableList<WorkIdWithShardingKey> failedWork) {
124122
activeWorkState.failWorkForKey(failedWork);
125123
}
126124

@@ -146,10 +144,6 @@ private void forceExecute(ExecutableWork executableWork) {
146144
executor.forceExecute(executableWork, executableWork.work().getSerializedWorkItemSize());
147145
}
148146

149-
public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
150-
return activeWorkState.getReadOnlyActiveWork();
151-
}
152-
153147
public ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
154148
return activeWorkState.getRefreshableWork(refreshDeadline);
155149
}

runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
2525
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse;
2626
import org.apache.beam.sdk.annotations.Internal;
27-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap;
28-
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
27+
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
2928

3029
/**
3130
* Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that is invalid from
@@ -34,6 +33,7 @@
3433
@Internal
3534
public final class WorkHeartbeatResponseProcessor
3635
implements Consumer<List<ComputationHeartbeatResponse>> {
36+
3737
/** Fetches a {@link ComputationState} for a computationId. */
3838
private final Function<String, Optional<ComputationState>> computationStateFetcher;
3939

@@ -46,23 +46,23 @@ public WorkHeartbeatResponseProcessor(
4646
@Override
4747
public void accept(List<ComputationHeartbeatResponse> responses) {
4848
for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
49-
// Maps sharding key to (work token, cache token) for work that should be marked failed.
50-
Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
49+
ImmutableList.Builder<WorkIdWithShardingKey> failedWorkBuilder = ImmutableList.builder();
5150
for (HeartbeatResponse heartbeatResponse :
5251
computationHeartbeatResponse.getHeartbeatResponsesList()) {
5352
if (heartbeatResponse.getFailed()) {
54-
failedWork.put(
55-
heartbeatResponse.getShardingKey(),
53+
WorkId workId =
5654
WorkId.builder()
5755
.setWorkToken(heartbeatResponse.getWorkToken())
5856
.setCacheToken(heartbeatResponse.getCacheToken())
59-
.build());
57+
.build();
58+
failedWorkBuilder.add(
59+
WorkIdWithShardingKey.create(heartbeatResponse.getShardingKey(), workId));
6060
}
6161
}
6262

6363
computationStateFetcher
6464
.apply(computationHeartbeatResponse.getComputationId())
65-
.ifPresent(state -> state.failWork(failedWork));
65+
.ifPresent(state -> state.failWork(failedWorkBuilder.build()));
6666
}
6767
}
6868
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.runners.dataflow.worker.streaming;
19+
20+
import com.google.auto.value.AutoValue;
21+
22+
@AutoValue
23+
abstract class WorkIdWithShardingKey {
24+
25+
public static WorkIdWithShardingKey create(long shardingKey, WorkId workId) {
26+
return new AutoValue_WorkIdWithShardingKey(shardingKey, workId);
27+
}
28+
29+
public abstract long shardingKey();
30+
31+
public abstract WorkId workId();
32+
}

0 commit comments

Comments
 (0)