1818package org .apache .beam .runners .dataflow .worker .streaming ;
1919
2020import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList .toImmutableList ;
21- import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableListMultimap .flatteningToImmutableListMultimap ;
2221
2322import java .io .PrintWriter ;
24- import java .util .ArrayDeque ;
25- import java .util .Collection ;
26- import java .util .Deque ;
2723import java .util .HashMap ;
2824import java .util .Iterator ;
25+ import java .util .LinkedHashMap ;
2926import java .util .Map ;
3027import java .util .Map .Entry ;
3128import java .util .Optional ;
3633import javax .annotation .concurrent .ThreadSafe ;
3734import org .apache .beam .runners .dataflow .worker .windmill .Windmill .WorkItem ;
3835import org .apache .beam .runners .dataflow .worker .windmill .state .WindmillStateCache ;
36+ import org .apache .beam .runners .dataflow .worker .windmill .state .WindmillStateCache .ForComputation ;
3937import org .apache .beam .runners .dataflow .worker .windmill .work .budget .GetWorkBudget ;
4038import org .apache .beam .sdk .annotations .Internal ;
4139import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .annotations .VisibleForTesting ;
4240import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Preconditions ;
4341import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableList ;
44- import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableListMultimap ;
4542import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .ImmutableMap ;
46- import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .collect .Multimap ;
4743import org .joda .time .Duration ;
4844import org .joda .time .Instant ;
4945import org .slf4j .Logger ;
@@ -63,11 +59,11 @@ public final class ActiveWorkState {
6359 private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50 ;
6460
6561 /**
66- * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
67- * Queue<Work>} is actively processing.
62+ * Map from shardingKey to the {@link ExecutableWork} queued for it, keyed by {@link WorkId} in
63+ * insertion order. The first entry of each {@link LinkedHashMap} is actively processing.
6864 */
6965 @ GuardedBy ("this" )
70- private final Map <ShardedKey , Deque < ExecutableWork >> activeWork ;
66+ private final Map <Long /*shardingKey*/ , LinkedHashMap < WorkId , ExecutableWork >> activeWork ;
7167
7268 @ GuardedBy ("this" )
7369 private final WindmillStateCache .ForComputation computationStateCache ;
@@ -81,8 +77,8 @@ public final class ActiveWorkState {
8177 private GetWorkBudget activeGetWorkBudget ;
8278
8379 private ActiveWorkState (
84- Map <ShardedKey , Deque < ExecutableWork >> activeWork ,
85- WindmillStateCache . ForComputation computationStateCache ) {
80+ Map <Long , LinkedHashMap < WorkId , ExecutableWork >> activeWork ,
81+ ForComputation computationStateCache ) {
8682 this .activeWork = activeWork ;
8783 this .computationStateCache = computationStateCache ;
8884 this .activeGetWorkBudget = GetWorkBudget .noBudget ();
@@ -94,7 +90,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState
9490
9591 @ VisibleForTesting
9692 static ActiveWorkState forTesting (
97- Map <ShardedKey , Deque < ExecutableWork >> activeWork ,
93+ Map <Long , LinkedHashMap < WorkId , ExecutableWork >> activeWork ,
9894 WindmillStateCache .ForComputation computationStateCache ) {
9995 return new ActiveWorkState (activeWork , computationStateCache );
10096 }
@@ -124,28 +120,30 @@ private static String elapsedString(Instant start, Instant end) {
124120 */
125121 synchronized ActivateWorkResult activateWorkForKey (ExecutableWork executableWork ) {
126122 ShardedKey shardedKey = executableWork .work ().getShardedKey ();
127- Deque <ExecutableWork > workQueue = activeWork .getOrDefault (shardedKey , new ArrayDeque <>());
123+ long shardingKey = shardedKey .shardingKey ();
124+ LinkedHashMap <WorkId , ExecutableWork > workQueue =
125+ activeWork .computeIfAbsent (shardingKey , (unused ) -> new LinkedHashMap <>());
128126 // This key does not have any work queued up on it. Create one, insert Work, and mark the work
129127 // to be executed.
130- if (!activeWork .containsKey (shardedKey ) || workQueue .isEmpty ()) {
131- workQueue .addLast (executableWork );
132- activeWork .put (shardedKey , workQueue );
128+ if (workQueue .isEmpty ()) {
129+ workQueue .put (executableWork .id (), executableWork );
133130 incrementActiveWorkBudget (executableWork .work ());
134131 return ActivateWorkResult .EXECUTE ;
135132 }
136133
137134 // Check to see if we have this work token queued.
138- Iterator <ExecutableWork > workIterator = workQueue .iterator ();
135+ Iterator <Entry < WorkId , ExecutableWork >> workIterator = workQueue . entrySet () .iterator ();
139136 while (workIterator .hasNext ()) {
140- ExecutableWork queuedWork = workIterator .next ();
137+ ExecutableWork queuedWork = workIterator .next (). getValue () ;
141138 if (queuedWork .id ().equals (executableWork .id ())) {
142139 return ActivateWorkResult .DUPLICATE ;
143140 }
144- if (queuedWork .id ().cacheToken () == executableWork .id ().cacheToken ()) {
141+ if (queuedWork .id ().cacheToken () == executableWork .id ().cacheToken ()
142+ && queuedWork .work ().getShardedKey ().equals (executableWork .work ().getShardedKey ())) {
145143 if (executableWork .id ().workToken () > queuedWork .id ().workToken ()) {
146144 // Check to see if the queuedWork is active. We only want to remove it if it is NOT
147145 // currently active.
148- if (!queuedWork .equals (workQueue . peek ( ))) {
146+ if (!queuedWork .equals (Preconditions . checkNotNull ( firstValue ( workQueue ) ))) {
149147 workIterator .remove ();
150148 decrementActiveWorkBudget (queuedWork .work ());
151149 }
@@ -157,7 +155,7 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
157155 }
158156
159157 // Queue the work for later processing.
160- workQueue .addLast ( executableWork );
158+ workQueue .put ( executableWork . id (), executableWork );
161159 incrementActiveWorkBudget (executableWork .work ());
162160 return ActivateWorkResult .QUEUED ;
163161 }
@@ -167,54 +165,29 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
167165 *
168166 * @param failedWork a list of work ids, each paired with its sharding_key, for the work to fail.
169167 */
170- synchronized void failWorkForKey (Multimap <Long , WorkId > failedWork ) {
171- // Note we can't construct a ShardedKey and look it up in activeWork directly since
172- // HeartbeatResponse doesn't include the user key.
173- for (Entry <ShardedKey , Deque <ExecutableWork >> entry : activeWork .entrySet ()) {
174- Collection <WorkId > failedWorkIds = failedWork .get (entry .getKey ().shardingKey ());
175- for (WorkId failedWorkId : failedWorkIds ) {
176- for (ExecutableWork queuedWork : entry .getValue ()) {
177- WorkItem workItem = queuedWork .work ().getWorkItem ();
178- if (workItem .getWorkToken () == failedWorkId .workToken ()
179- && workItem .getCacheToken () == failedWorkId .cacheToken ()) {
180- LOG .debug (
181- "Failing work "
182- + computationStateCache .getComputation ()
183- + " "
184- + entry .getKey ().shardingKey ()
185- + " "
186- + failedWorkId .workToken ()
187- + " "
188- + failedWorkId .cacheToken ()
189- + ". The work will be retried and is not lost." );
190- queuedWork .work ().setFailed ();
191- break ;
192- }
193- }
168+ synchronized void failWorkForKey (ImmutableList <WorkIdWithShardingKey > failedWork ) {
169+ for (WorkIdWithShardingKey failedId : failedWork ) {
170+ @ Nullable
171+ LinkedHashMap <WorkId , ExecutableWork > workQueue = activeWork .get (failedId .shardingKey ());
172+ if (workQueue == null ) {
173+ // Work could complete/fail before heartbeat response arrives
174+ continue ;
175+ }
176+ @ Nullable ExecutableWork executableWork = workQueue .get (failedId .workId ());
177+ if (executableWork == null ) {
178+ continue ;
194179 }
180+ executableWork .work ().setFailed ();
181+ LOG .debug (
182+ "Failing work {} {}. The work will be retried and is not lost." ,
183+ computationStateCache .getComputation (),
184+ failedId );
195185 }
196186 }
197187
198- /**
199- * Returns a read only view of current active work.
200- *
201- * @implNote Do not return a reference to the underlying workQueue as iterations over it will
202- * cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
203- * structure.
204- */
205- synchronized ImmutableListMultimap <ShardedKey , RefreshableWork > getReadOnlyActiveWork () {
206- return activeWork .entrySet ().stream ()
207- .collect (
208- flatteningToImmutableListMultimap (
209- Entry ::getKey ,
210- e ->
211- e .getValue ().stream ()
212- .map (executableWork -> (RefreshableWork ) executableWork .work ())));
213- }
214-
215188 synchronized ImmutableList <RefreshableWork > getRefreshableWork (Instant refreshDeadline ) {
216189 return activeWork .values ().stream ()
217- .flatMap (Deque :: stream )
190+ .flatMap (workMap -> workMap . values (). stream () )
218191 .map (ExecutableWork ::work )
219192 .filter (work -> !work .isFailed () && work .getStartTime ().isBefore (refreshDeadline ))
220193 .collect (toImmutableList ());
@@ -236,7 +209,8 @@ private synchronized void decrementActiveWorkBudget(Work work) {
236209 */
237210 synchronized Optional <ExecutableWork > completeWorkAndGetNextWorkForKey (
238211 ShardedKey shardedKey , WorkId workId ) {
239- @ Nullable Queue <ExecutableWork > workQueue = activeWork .get (shardedKey );
212+ @ Nullable
213+ LinkedHashMap <WorkId , ExecutableWork > workQueue = activeWork .get (shardedKey .shardingKey ());
240214 if (workQueue == null ) {
241215 // Work may have been completed due to clearing of stuck commits.
242216 LOG .warn (
@@ -251,14 +225,15 @@ synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
251225 }
252226
253227 private synchronized void removeCompletedWorkFromQueue (
254- Queue < ExecutableWork > workQueue , ShardedKey shardedKey , WorkId workId ) {
255- @ Nullable ExecutableWork completedWork = workQueue .peek ();
256- if (completedWork == null ) {
228+ LinkedHashMap < WorkId , ExecutableWork > workQueue , ShardedKey shardedKey , WorkId workId ) {
229+ Iterator < Entry < WorkId , ExecutableWork >> completedWorkIterator = workQueue .entrySet (). iterator ();
230+ if (! completedWorkIterator . hasNext () ) {
257231 // Work may have been completed due to clearing of stuck commits.
258232 LOG .warn ("Active key {} without work, expected token {}" , shardedKey , workId );
259233 return ;
260234 }
261235
236+ ExecutableWork completedWork = completedWorkIterator .next ().getValue ();
262237 if (!completedWork .id ().equals (workId )) {
263238 // Work may have been completed due to clearing of stuck commits.
264239 LOG .warn (
@@ -271,19 +246,18 @@ private synchronized void removeCompletedWorkFromQueue(
271246 completedWork .id ());
272247 return ;
273248 }
274-
275249 // We consumed the matching work item.
276- workQueue .remove ();
250+ completedWorkIterator .remove ();
277251 decrementActiveWorkBudget (completedWork .work ());
278252 }
279253
254+ @ SuppressWarnings ("ReferenceEquality" )
280255 private synchronized Optional <ExecutableWork > getNextWork (
281- Queue < ExecutableWork > workQueue , ShardedKey shardedKey ) {
282- Optional <ExecutableWork > nextWork = Optional .ofNullable (workQueue . peek ( ));
256+ LinkedHashMap < WorkId , ExecutableWork > workQueue , ShardedKey shardedKey ) {
257+ Optional <ExecutableWork > nextWork = Optional .ofNullable (firstValue ( workQueue ));
283258 if (!nextWork .isPresent ()) {
284- Preconditions .checkState (workQueue == activeWork .remove (shardedKey ));
259+ Preconditions .checkState (workQueue == activeWork .remove (shardedKey . shardingKey () ));
285260 }
286-
287261 return nextWork ;
288262 }
289263
@@ -302,22 +276,26 @@ synchronized void invalidateStuckCommits(
302276 }
303277 }
304278
279+ private static @ Nullable ExecutableWork firstValue (Map <WorkId , ExecutableWork > map ) {
280+ Iterator <Entry <WorkId , ExecutableWork >> iterator = map .entrySet ().iterator ();
281+ return iterator .hasNext () ? iterator .next ().getValue () : null ;
282+ }
283+
305284 private synchronized ImmutableMap <ShardedKey , WorkId > getStuckCommitsAt (
306285 Instant stuckCommitDeadline ) {
307286 // Determine the stuck commit keys but complete them outside the loop iterating over
308287 // activeWork as completeWork may delete the entry from activeWork.
309288 ImmutableMap .Builder <ShardedKey , WorkId > stuckCommits = ImmutableMap .builder ();
310- for (Entry <ShardedKey , Deque <ExecutableWork >> entry : activeWork .entrySet ()) {
311- ShardedKey shardedKey = entry .getKey ();
312- @ Nullable ExecutableWork executableWork = entry .getValue ().peek ();
289+ for (Entry <Long , LinkedHashMap <WorkId , ExecutableWork >> entry : activeWork .entrySet ()) {
290+ @ Nullable ExecutableWork executableWork = firstValue (entry .getValue ());
313291 if (executableWork != null ) {
314292 Work work = executableWork .work ();
315293 if (work .isStuckCommittingAt (stuckCommitDeadline )) {
316294 LOG .error (
317295 "Detected key {} stuck in COMMITTING state since {}, completing it with error." ,
318- shardedKey ,
296+ work . getShardedKey () ,
319297 work .getStateStartTime ());
320- stuckCommits .put (shardedKey , work .id ());
298+ stuckCommits .put (work . getShardedKey () , work .id ());
321299 }
322300 }
323301 }
@@ -353,9 +331,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
353331 // Use StringBuilder because we are appending in loop.
354332 StringBuilder activeWorkStatus = new StringBuilder ();
355333 int commitsPendingCount = 0 ;
356- for (Map .Entry <ShardedKey , Deque <ExecutableWork >> entry : activeWork .entrySet ()) {
357- Queue <ExecutableWork > workQueue = Preconditions .checkNotNull (entry .getValue ());
358- Work activeWork = Preconditions .checkNotNull (workQueue .peek ()).work ();
334+ for (Entry <Long , LinkedHashMap <WorkId , ExecutableWork >> entry : activeWork .entrySet ()) {
335+ LinkedHashMap <WorkId , ExecutableWork > workQueue =
336+ Preconditions .checkNotNull (entry .getValue ());
337+ Work activeWork = Preconditions .checkNotNull (firstValue (workQueue )).work ();
359338 WorkItem workItem = activeWork .getWorkItem ();
360339 if (activeWork .isCommitPending ()) {
361340 if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS ) {
0 commit comments