diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index 85fa1d67c6c3..b68f53121b86 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -45,6 +45,7 @@ @Internal @ThreadSafe public final class StreamingEngineWorkCommitter implements WorkCommitter { + private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; private static final String NO_BACKEND_WORKER_TOKEN = ""; @@ -99,19 +100,23 @@ public void start() { @Override public void commit(Commit commit) { - boolean isShutdown = !this.isRunning.get(); - if (commit.work().isFailed() || isShutdown) { - if (isShutdown) { - LOG.debug( - "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}, workId={} ].", - commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); - } + if (commit.work().isFailed()) { failCommit(commit); } else { commitQueue.put(commit); } + + // Do this check after adding to commitQueue, else commitQueue.put() can race with + // drainCommitQueue() in stop() and leave commits orphaned in the queue. + if (!this.isRunning.get()) { + LOG.debug( + "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}," + + " workId={} ].", + commit.computationId(), + commit.work().getShardedKey(), + commit.work().id()); + drainCommitQueue(); + } } @Override @@ -255,6 +260,7 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch @AutoBuilder public interface Builder { + Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index b4f63fa71618..01197622c24d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -34,13 +34,21 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.runners.dataflow.worker.FakeWindmillServer; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -67,6 +75,7 @@ @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @@ -75,9 +84,17 @@ public class StreamingEngineWorkCommitterTest { private Supplier> commitWorkStreamFactory; private static void waitForExpectedSetSize(Set s, int expectedSize) { + long deadline = System.currentTimeMillis() + 100 * 1000; // 100 seconds while (s.size() < expectedSize) { try { Thread.sleep(10); + if (System.currentTimeMillis() > deadline) { + throw new RuntimeException( + "Timed out waiting for expected set size to be: " + + expectedSize + + " but was: " + + s.size()); + } } catch (InterruptedException e) { throw new RuntimeException(e); } @@ -400,4 +417,61 @@ public void testMultipleCommitSendersSingleStream() { workCommitter.stop(); } + + @Test + public void testStop_drainsCommitQueue_concurrentCommit() + throws InterruptedException, ExecutionException, TimeoutException { + Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); + workCommitter = + StreamingEngineWorkCommitter.builder() + // Set the semaphore to only allow a single commit at a time. + // This creates a bottleneck on purpose to trigger race conditions during shutdown. + .setCommitByteSemaphore(WeightedSemaphore.create(1, (commit) -> 1)) + .setCommitWorkStreamFactory(commitWorkStreamFactory) + .setOnCommitComplete(completeCommits::add) + .build(); + + int numThreads = 5; + ExecutorService producer = Executors.newFixedThreadPool(numThreads); + AtomicBoolean producing = new AtomicBoolean(true); + AtomicLong sentCommits = new AtomicLong(0); + + workCommitter.start(); + + AtomicLong workToken = new AtomicLong(0); + List> futures = new ArrayList<>(numThreads); + for (int i = 0; i < numThreads; i++) { + futures.add( + producer.submit( + () -> { + while (producing.get()) { + Work work = createMockWork(workToken.getAndIncrement()); + WorkItemCommitRequest commitRequest = + WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + Commit commit = + Commit.create(commitRequest, createComputationState("computationId"), work); + workCommitter.commit(commit); + sentCommits.incrementAndGet(); + } + })); + } + + // Let it run for a bit + Thread.sleep(100); + + workCommitter.stop(); + producing.set(false); + producer.shutdown(); + assertTrue(producer.awaitTermination(10, TimeUnit.SECONDS)); + for (Future future : futures) { + future.get(10, TimeUnit.SECONDS); + } + + waitForExpectedSetSize(completeCommits, sentCommits.intValue()); + } }