diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
index 24fc17d4c74a..743ee4b948ff 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
@@ -4,5 +4,6 @@
   "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
   "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
   "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle"
 }
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
index 24fc17d4c74a..47d924953c51 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
@@ -4,5 +4,6 @@
   "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
   "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
   "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle validation"
 }
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index b4ba32c1cc95..415132fa7d2c 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -205,7 +205,6 @@ def commonLegacyExcludeCategories = [
   'org.apache.beam.sdk.testing.UsesGaugeMetrics',
   'org.apache.beam.sdk.testing.UsesMultimapState',
   'org.apache.beam.sdk.testing.UsesTestStream',
-  'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner
   'org.apache.beam.sdk.testing.UsesMetricsPusher',
   'org.apache.beam.sdk.testing.UsesBundleFinalizer',
   'org.apache.beam.sdk.testing.UsesBoundedTrieMetrics', // Dataflow QM as of now does not support returning back BoundedTrie in metric result.
@@ -452,7 +451,17 @@ task validatesRunner {
   excludedTests: [
     // TODO(https://github.com/apache/beam/issues/21472)
     'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
-  ]
+
+    // These tests use static state and don't work with remote execution.
+ 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful', + ] )) } @@ -474,7 +483,17 @@ task validatesRunnerStreaming { // GroupIntoBatches.withShardedKey not supported on streaming runner v1 // https://github.com/apache/beam/issues/22592 'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow', - ] + + // These tests use static state and don't work with remote execution. + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful', +] )) } @@ -543,8 +562,7 @@ task validatesRunnerV2 { excludedTests: [ 'org.apache.beam.sdk.transforms.ReshuffleTest.testReshuffleWithTimestampsStreaming', - // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle. - 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testFnCallSequenceStateful', + // These tests use static state and don't work with remote execution. 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', @@ -586,7 +604,7 @@ task validatesRunnerV2Streaming { 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState', 'org.apache.beam.sdk.transforms.GroupByKeyTest.testCombiningAccumulatingProcessingTime', - // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle. + // These tests use static state and don't work with remote execution. 
'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java index 91fb640a1757..d3f2aacc74d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java @@ -105,11 +105,32 @@ public DataflowMapTaskExecutor create( Networks.replaceDirectedNetworkNodes( network, createOutputReceiversTransform(stageName, counterSet)); - // Swap out all the ParallelInstruction nodes with Operation nodes - Networks.replaceDirectedNetworkNodes( - network, - createOperationTransformForParallelInstructionNodes( - stageName, network, options, readerFactory, sinkFactory, executionContext)); + // Swap out all the ParallelInstruction nodes with Operation nodes. While updating the network, + // we keep track of + // the created Operations so that if an exception is encountered we can properly abort started + // operations. + ArrayList createdOperations = new ArrayList<>(); + try { + Networks.replaceDirectedNetworkNodes( + network, + createOperationTransformForParallelInstructionNodes( + stageName, + network, + options, + readerFactory, + sinkFactory, + executionContext, + createdOperations)); + } catch (RuntimeException exn) { + for (Operation o : createdOperations) { + try { + o.abort(); + } catch (Exception exn2) { + exn.addSuppressed(exn2); + } + } + throw exn; + } // Collect all the operations within the network and attach all the operations as receivers // to preceding output receivers. 
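Reviewer note: the IntrinsicMapTaskExecutorFactory hunk above tracks each created Operation so that, if building a later node throws (for example from a DoFn's @Setup), the already-created operations are aborted and their teardown runs. A minimal standalone sketch of that pattern, with hypothetical names that are not the Beam worker classes:

// Illustrative sketch only: Op and OpFactory are hypothetical stand-ins for the
// worker's Operation and node-creation types; this is not the Beam implementation.
import java.util.ArrayList;
import java.util.List;

final class AbortOnFailureBuilder {
  interface Op {
    void abort() throws Exception; // analogous to Operation.abort(), which lets a DoFn tear down
  }

  interface OpFactory {
    Op create(); // analogous to creating an Operation; may run a DoFn's @Setup and throw
  }

  /** Creates every operation; if any factory throws, aborts the ones already created. */
  static List<Op> buildAll(List<OpFactory> factories) {
    ArrayList<Op> created = new ArrayList<>();
    try {
      for (OpFactory factory : factories) {
        created.add(factory.create());
      }
      return created;
    } catch (RuntimeException exn) {
      // Mirror the hunk above: give each already-created operation a chance to clean up,
      // attaching any secondary failures as suppressed exceptions before rethrowing.
      for (Op op : created) {
        try {
          op.abort();
        } catch (Exception suppressed) {
          exn.addSuppressed(suppressed);
        }
      }
      throw exn;
    }
  }
}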
@@ -144,7 +165,8 @@ Function createOperationTransformForParallelInstructionNodes( final PipelineOptions options, final ReaderFactory readerFactory, final SinkFactory sinkFactory, - final DataflowExecutionContext executionContext) { + final DataflowExecutionContext executionContext, + final List createdOperations) { return new TypeSafeNodeFunction(ParallelInstructionNode.class) { @Override @@ -156,20 +178,22 @@ public Node typedApply(ParallelInstructionNode node) { instruction.getOriginalName(), instruction.getSystemName(), instruction.getName()); + OperationNode result; try { DataflowOperationContext context = executionContext.createOperationContext(nameContext); if (instruction.getRead() != null) { - return createReadOperation( - network, node, options, readerFactory, executionContext, context); + result = + createReadOperation( + network, node, options, readerFactory, executionContext, context); } else if (instruction.getWrite() != null) { - return createWriteOperation(node, options, sinkFactory, executionContext, context); + result = createWriteOperation(node, options, sinkFactory, executionContext, context); } else if (instruction.getParDo() != null) { - return createParDoOperation(network, node, options, executionContext, context); + result = createParDoOperation(network, node, options, executionContext, context); } else if (instruction.getPartialGroupByKey() != null) { - return createPartialGroupByKeyOperation( - network, node, options, executionContext, context); + result = + createPartialGroupByKeyOperation(network, node, options, executionContext, context); } else if (instruction.getFlatten() != null) { - return createFlattenOperation(network, node, context); + result = createFlattenOperation(network, node, context); } else { throw new IllegalArgumentException( String.format("Unexpected instruction: %s", instruction)); @@ -177,6 +201,8 @@ public Node typedApply(ParallelInstructionNode node) { } catch (Exception e) { throw new RuntimeException(e); } + createdOperations.add(result.getOperation()); + return result; } }; } @@ -328,7 +354,6 @@ public Node typedApply(InstructionOutputNode input) { Coder coder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudOutput.getCodec())); - @SuppressWarnings("unchecked") ElementCounter outputCounter = new DataflowOutputCounter( cloudOutput.getName(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java index 877e3198e91d..58b95f286d55 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java @@ -18,8 +18,8 @@ package org.apache.beam.runners.dataflow.worker.util.common.worker; import java.io.Closeable; +import java.util.ArrayList; import java.util.List; -import java.util.ListIterator; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; @@ -36,7 +36,9 @@ public class MapTaskExecutor implements WorkExecutor { private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class); /** The operations in 
the map task, in execution order. */ - public final List operations; + public final ArrayList operations; + + private boolean closed = false; private final ExecutionStateTracker executionStateTracker; @@ -54,7 +56,7 @@ public MapTaskExecutor( CounterSet counters, ExecutionStateTracker executionStateTracker) { this.counters = counters; - this.operations = operations; + this.operations = new ArrayList<>(operations); this.executionStateTracker = executionStateTracker; } @@ -63,6 +65,7 @@ public CounterSet getOutputCounters() { return counters; } + /** May be reused if execute() returns without an exception being thrown. */ @Override public void execute() throws Exception { LOG.debug("Executing map task"); @@ -74,13 +77,11 @@ public void execute() throws Exception { // Starting a root operation such as a ReadOperation does the work // of processing the input dataset. LOG.debug("Starting operations"); - ListIterator iterator = operations.listIterator(operations.size()); - while (iterator.hasPrevious()) { + for (int i = operations.size() - 1; i >= 0; --i) { if (Thread.currentThread().isInterrupted()) { throw new InterruptedException("Worker aborted"); } - Operation op = iterator.previous(); - op.start(); + operations.get(i).start(); } // Finish operations, in forward-execution-order, so that a @@ -94,16 +95,13 @@ public void execute() throws Exception { op.finish(); } } catch (Exception | Error exn) { - LOG.debug("Aborting operations", exn); - for (Operation op : operations) { - try { - op.abort(); - } catch (Exception | Error exn2) { - exn.addSuppressed(exn2); - if (exn2 instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - } + try { + closeInternal(); + } catch (Exception closeExn) { + exn.addSuppressed(closeExn); + } + if (exn instanceof InterruptedException) { + Thread.currentThread().interrupt(); } throw exn; } @@ -164,6 +162,45 @@ public void abort() { } } + private void closeInternal() throws Exception { + if (closed) { + return; + } + LOG.debug("Aborting operations"); + @Nullable Exception exn = null; + for (Operation op : operations) { + try { + op.abort(); + } catch (Exception | Error exn2) { + if (exn2 instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (exn == null) { + if (exn2 instanceof Exception) { + exn = (Exception) exn2; + } else { + exn = new RuntimeException(exn2); + } + } else { + exn.addSuppressed(exn2); + } + } + } + closed = true; + if (exn != null) { + throw exn; + } + } + + @Override + public void close() { + try { + closeInternal(); + } catch (Exception e) { + LOG.error("Exception while closing MapTaskExecutor, ignoring", e); + } + } + @Override public List reportProducedEmptyOutput() { List emptyOutputSinkIndexes = Lists.newArrayList(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index a4cd5d6d8a6b..e61c2d1f4a03 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -415,6 +415,7 @@ private ExecuteWorkResult executeWork( // Release the execution state for another thread to use. 
computationState.releaseComputationWorkExecutor(computationWorkExecutor); + computationWorkExecutor = null; work.setState(Work.State.COMMIT_QUEUED); outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler)); @@ -422,11 +423,13 @@ private ExecuteWorkResult executeWork( return ExecuteWorkResult.create( outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); } catch (Throwable t) { - // If processing failed due to a thrown exception, close the executionState. Do not - // return/release the executionState back to computationState as that will lead to this - // executionState instance being reused. - LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); - computationWorkExecutor.invalidate(); + if (computationWorkExecutor != null) { + // If processing failed due to a thrown exception, close the executionState. Do not + // return/release the executionState back to computationState as that will lead to this + // executionState instance being reused. + LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t); + computationWorkExecutor.invalidate(); + } // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. throw t; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java index e77ae309d359..3443ae0022bc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java @@ -24,11 +24,16 @@ import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray; import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -52,6 +57,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; @@ -254,8 +260,9 @@ public void testExecutionContextPlumbing() throws Exception { List instructions = Arrays.asList( createReadInstruction("Read", ReaderFactoryTest.SingletonTestReaderFactory.class), - createParDoInstruction(0, 0, "DoFn1", "DoFnUserName"), - createParDoInstruction(1, 0, "DoFnWithContext", "DoFnWithContextUserName")); + createParDoInstruction(0, 0, "DoFn1", "DoFnUserName", new TestDoFn()), + createParDoInstruction( + 1, 0, "DoFnWithContext", "DoFnWithContextUserName", new TestDoFn())); MapTask 
mapTask = new MapTask(); mapTask.setStageName(STAGE); @@ -330,6 +337,7 @@ public void testCreateReadOperation() throws Exception { PCOLLECTION_ID)))); when(network.outDegree(instructionNode)).thenReturn(1); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( @@ -338,11 +346,13 @@ public void testCreateReadOperation() throws Exception { PipelineOptionsFactory.create(), readerRegistry, sinkRegistry, - BatchModeExecutionContext.forTesting(options, counterSet, "testStage")) + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + createdOperations) .apply(instructionNode); assertThat(operationNode, instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class)); ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(readOperation)); assertEquals(1, readOperation.receivers.length); assertEquals(0, readOperation.receivers[0].getReceiverCount()); @@ -391,6 +401,7 @@ public void testCreateWriteOperation() throws Exception { ParallelInstructionNode.create( createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"), ExecutionLocation.UNKNOWN); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( @@ -399,11 +410,13 @@ public void testCreateWriteOperation() throws Exception { options, readerRegistry, sinkRegistry, - BatchModeExecutionContext.forTesting(options, counterSet, "testStage")) + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + createdOperations) .apply(instructionNode); assertThat(operationNode, instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class)); WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(writeOperation)); assertEquals(0, writeOperation.receivers.length); assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState); @@ -461,17 +474,15 @@ public TestSink create( static ParallelInstruction createParDoInstruction( int producerIndex, int producerOutputNum, String systemName) { - return createParDoInstruction(producerIndex, producerOutputNum, systemName, ""); + return createParDoInstruction(producerIndex, producerOutputNum, systemName, "", new TestDoFn()); } static ParallelInstruction createParDoInstruction( - int producerIndex, int producerOutputNum, String systemName, String userName) { + int producerIndex, int producerOutputNum, String systemName, String userName, DoFn fn) { InstructionInput cloudInput = new InstructionInput(); cloudInput.setProducerInstructionIndex(producerIndex); cloudInput.setOutputNum(producerOutputNum); - TestDoFn fn = new TestDoFn(); - String serializedFn = StringUtils.byteArrayToJsonString( SerializableUtils.serializeToByteArray( @@ -541,14 +552,16 @@ public void testCreateParDoOperation() throws Exception { .getMultiOutputInfos() .get(0)))); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( - STAGE, network, options, readerRegistry, sinkRegistry, context) + STAGE, network, options, readerRegistry, sinkRegistry, context, createdOperations) .apply(instructionNode); assertThat(operationNode, 
instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class)); ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(parDoOperation)); assertEquals(1, parDoOperation.receivers.length); assertEquals(0, parDoOperation.receivers[0].getReceiverCount()); @@ -608,6 +621,7 @@ public void testCreatePartialGroupByKeyOperation() throws Exception { PCOLLECTION_ID)))); when(network.outDegree(instructionNode)).thenReturn(1); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( @@ -616,11 +630,13 @@ public void testCreatePartialGroupByKeyOperation() throws Exception { PipelineOptionsFactory.create(), readerRegistry, sinkRegistry, - BatchModeExecutionContext.forTesting(options, counterSet, "testStage")) + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + createdOperations) .apply(instructionNode); assertThat(operationNode, instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class)); ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(pgbkOperation)); assertEquals(1, pgbkOperation.receivers.length); assertEquals(0, pgbkOperation.receivers[0].getReceiverCount()); @@ -660,6 +676,7 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception { PCOLLECTION_ID)))); when(network.outDegree(instructionNode)).thenReturn(1); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( @@ -668,11 +685,13 @@ public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception { options, readerRegistry, sinkRegistry, - BatchModeExecutionContext.forTesting(options, counterSet, "testStage")) + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + createdOperations) .apply(instructionNode); assertThat(operationNode, instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class)); ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(pgbkOperation)); assertEquals(1, pgbkOperation.receivers.length); assertEquals(0, pgbkOperation.receivers[0].getReceiverCount()); @@ -738,6 +757,7 @@ public void testCreateFlattenOperation() throws Exception { PCOLLECTION_ID)))); when(network.outDegree(instructionNode)).thenReturn(1); + ArrayList createdOperations = new ArrayList<>(); Node operationNode = mapTaskExecutorFactory .createOperationTransformForParallelInstructionNodes( @@ -746,15 +766,108 @@ public void testCreateFlattenOperation() throws Exception { options, readerRegistry, sinkRegistry, - BatchModeExecutionContext.forTesting(options, counterSet, "testStage")) + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + createdOperations) .apply(instructionNode); assertThat(operationNode, instanceOf(OperationNode.class)); assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class)); FlattenOperation flattenOperation = (FlattenOperation) ((OperationNode) operationNode).getOperation(); + assertThat(createdOperations, contains(flattenOperation)); assertEquals(1, 
flattenOperation.receivers.length); assertEquals(0, flattenOperation.receivers[0].getReceiverCount()); assertEquals(Operation.InitializationState.UNSTARTED, flattenOperation.initializationState); } + + static class TestTeardownDoFn extends DoFn { + static AtomicInteger setupCalls = new AtomicInteger(); + static AtomicInteger teardownCalls = new AtomicInteger(); + + private final boolean throwExceptionOnSetup; + private boolean setupCalled = false; + + TestTeardownDoFn(boolean throwExceptionOnSetup) { + this.throwExceptionOnSetup = throwExceptionOnSetup; + } + + @Setup + public void setup() { + assertFalse(setupCalled); + setupCalled = true; + setupCalls.addAndGet(1); + if (throwExceptionOnSetup) { + throw new RuntimeException("Test setup exception"); + } + } + + @ProcessElement + public void process(ProcessContext c) { + fail("no elements should be processed"); + } + + @Teardown + public void teardown() { + assertTrue(setupCalled); + setupCalled = false; + teardownCalls.addAndGet(1); + } + } + + @Test + public void testCreateMapTaskExecutorException() throws Exception { + List instructions = + Arrays.asList( + createReadInstruction("Read"), + createParDoInstruction(0, 0, "DoFn1", "DoFn1", new TestTeardownDoFn(false)), + createParDoInstruction(0, 0, "DoFn2", "DoFn2", new TestTeardownDoFn(false)), + createParDoInstruction(0, 0, "ErrorFn", "", new TestTeardownDoFn(true)), + createParDoInstruction(0, 0, "DoFn3", "DoFn3", new TestTeardownDoFn(false)), + createFlattenInstruction(1, 0, 2, 0, "Flatten"), + createWriteInstruction(3, 0, "Write")); + + MapTask mapTask = new MapTask(); + mapTask.setStageName(STAGE); + mapTask.setSystemName("systemName"); + mapTask.setInstructions(instructions); + mapTask.setFactory(Transport.getJsonFactory()); + + assertThrows( + "Test setup exception", + RuntimeException.class, + () -> + mapTaskExecutorFactory.create( + mapTaskToNetwork.apply(mapTask), + options, + STAGE, + readerRegistry, + sinkRegistry, + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + counterSet, + idGenerator)); + assertEquals(3, TestTeardownDoFn.setupCalls.getAndSet(0)); + // We only tear-down the instruction we were unable to create. The other + // infos are cached within UserParDoFnFactory and not torn-down. + assertEquals(1, TestTeardownDoFn.teardownCalls.getAndSet(0)); + + assertThrows( + "Test setup exception", + RuntimeException.class, + () -> + mapTaskExecutorFactory.create( + mapTaskToNetwork.apply(mapTask), + options, + STAGE, + readerRegistry, + sinkRegistry, + BatchModeExecutionContext.forTesting(options, counterSet, "testStage"), + counterSet, + idGenerator)); + // The non-erroring functions are cached, and a new setup call is called on + // erroring dofn. + assertEquals(1, TestTeardownDoFn.setupCalls.get()); + // We only tear-down the instruction we were unable to create. The other + // infos are cached within UserParDoFnFactory and not torn-down. 
+ assertEquals(1, TestTeardownDoFn.teardownCalls.get()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java index bb92fca3d8be..9e45425562a3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java @@ -198,7 +198,7 @@ public void testOutputReceivers() throws Exception { new TestDoFn( ImmutableList.of( new TupleTag<>("tag1"), new TupleTag<>("tag2"), new TupleTag<>("tag3"))); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), @@ -279,7 +279,7 @@ public void testOutputReceivers() throws Exception { @SuppressWarnings("AssertionFailureIgnored") public void testUnexpectedNumberOfReceivers() throws Exception { TestDoFn fn = new TestDoFn(Collections.emptyList()); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), @@ -330,7 +330,7 @@ private List stackTraceFrameStrings(Throwable t) { @Test public void testErrorPropagation() throws Exception { TestErrorDoFn fn = new TestErrorDoFn(); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), @@ -423,7 +423,7 @@ public void testUndeclaredSideOutputs() throws Exception { new TupleTag<>("undecl1"), new TupleTag<>("undecl2"), new TupleTag<>("undecl3"))); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), @@ -485,7 +485,7 @@ public void processElement(ProcessContext c) throws Exception { } StateTestingDoFn fn = new StateTestingDoFn(); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), @@ -578,7 +578,7 @@ public void processElement(ProcessContext c) { } DoFn fn = new RepeaterDoFn(); - DoFnInfo fnInfo = + DoFnInfo fnInfo = DoFnInfo.forFn( fn, WindowingStrategy.globalDefault(), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index e16a8b9f88cf..df90bb96139d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3276,6 +3276,9 @@ public void testExceptionInvalidatesCache() throws Exception { TestCountingSource counter = new TestCountingSource(3).withThrowOnFirstSnapshot(true); + // Reset static state that may leak across tests. 
+ TestExceptionInvalidatesCacheFn.resetStaticState(); + TestCountingSource.resetStaticState(); List instructions = Arrays.asList( new ParallelInstruction() @@ -3310,7 +3313,10 @@ public void testExceptionInvalidatesCache() throws Exception { .build()); worker.start(); - // Three GetData requests + // Three GetData requests: + // - first processing has no state + // - recovering from checkpoint exception has no persisted state + // - recovering from processing exception recovers last committed state for (int i = 0; i < 3; i++) { ByteString state; if (i == 0 || i == 1) { @@ -3437,6 +3443,11 @@ public void testExceptionInvalidatesCache() throws Exception { parseCommitRequest(sb.toString())) .build())); } + + // Ensure that the invalidated dofn had tearDown called on them. + assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get()); + assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get()); + worker.stop(); } @@ -3484,7 +3495,7 @@ public void testActiveWorkRefresh() throws Exception { } @Test - public void testActiveWorkFailure() throws Exception { + public void testQueuedWorkFailure() throws Exception { List instructions = Arrays.asList( makeSourceInstruction(StringUtf8Coder.of()), @@ -3515,6 +3526,9 @@ public void testActiveWorkFailure() throws Exception { server.whenGetWorkCalled().thenReturn(workItem).thenReturn(workItemToFail); server.waitForEmptyWorkQueue(); + // Wait for key to schedule, it will be blocked. + BlockingFn.counter().acquire(1); + // Mock Windmill sending a heartbeat response failing the second work item while the first // is still processing. ComputationHeartbeatResponse.Builder failedHeartbeat = @@ -3534,6 +3548,64 @@ public void testActiveWorkFailure() throws Exception { server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5))); assertEquals(1, commits.size()); + assertEquals(0, BlockingFn.teardownCounter.get()); + assertEquals(1, BlockingFn.setupCounter.get()); + + worker.stop(); + } + + @Test + public void testActiveWorkFailure() throws Exception { + List instructions = + Arrays.asList( + makeSourceInstruction(StringUtf8Coder.of()), + makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()), + makeSinkInstruction(StringUtf8Coder.of(), 0)); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams("--activeWorkRefreshPeriodMillis=100") + .setInstructions(instructions) + .publishCounters() + .build()); + worker.start(); + + GetWorkResponse workItemToFail = + makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY); + long failedWorkToken = workItemToFail.getWork(0).getWork(0).getWorkToken(); + long failedCacheToken = workItemToFail.getWork(0).getWork(0).getCacheToken(); + GetWorkResponse workItem = + makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY); + + // Queue up the work item for the key. + server.whenGetWorkCalled().thenReturn(workItemToFail).thenReturn(workItem); + server.waitForEmptyWorkQueue(); + + // Wait for key to schedule, it will be blocked. + BlockingFn.counter().acquire(1); + + // Mock Windmill sending a heartbeat response failing the first work item while it is + // is processing. 
+ ComputationHeartbeatResponse.Builder failedHeartbeat = + ComputationHeartbeatResponse.newBuilder(); + failedHeartbeat + .setComputationId(DEFAULT_COMPUTATION_ID) + .addHeartbeatResponsesBuilder() + .setCacheToken(failedCacheToken) + .setWorkToken(failedWorkToken) + .setShardingKey(DEFAULT_SHARDING_KEY) + .setFailed(true); + server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build())); + + // Release the blocked call, there should not be a commit and the dofn should be invalidated. + BlockingFn.blocker().countDown(); + Map commits = + server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5))); + assertEquals(1, commits.size()); + + assertEquals(0, BlockingFn.teardownCounter.get()); + assertEquals(1, BlockingFn.setupCounter.get()); + worker.stop(); } @@ -4246,6 +4318,18 @@ static class BlockingFn extends DoFn implements TestRule { new AtomicReference<>(new CountDownLatch(1)); public static AtomicReference counter = new AtomicReference<>(new Semaphore(0)); public static AtomicInteger callCounter = new AtomicInteger(0); + public static AtomicInteger setupCounter = new AtomicInteger(0); + public static AtomicInteger teardownCounter = new AtomicInteger(0); + + @Setup + public void setup() { + setupCounter.incrementAndGet(); + } + + @Teardown + public void tearDown() { + teardownCounter.incrementAndGet(); + } @ProcessElement public void processElement(ProcessContext c) throws InterruptedException { @@ -4278,6 +4362,8 @@ public void evaluate() throws Throwable { blocker.set(new CountDownLatch(1)); counter.set(new Semaphore(0)); callCounter.set(0); + setupCounter.set(0); + teardownCounter.set(0); } } }; @@ -4397,11 +4483,33 @@ public void processElement(ProcessContext c) { static class TestExceptionInvalidatesCacheFn extends DoFn>, String> { - static boolean thrown = false; + public static AtomicInteger setupCallCount = new AtomicInteger(); + public static AtomicInteger tearDownCallCount = new AtomicInteger(); + private static boolean thrown = false; + private boolean setupCalled = false; + + static void resetStaticState() { + setupCallCount.set(0); + tearDownCallCount.set(0); + thrown = false; + } @StateId("int") private final StateSpec> counter = StateSpecs.value(VarIntCoder.of()); + @Setup + public void setUp() { + assertFalse(setupCalled); + setupCalled = true; + setupCallCount.addAndGet(1); + } + + @Teardown + public void tearDown() { + assertTrue(setupCalled); + tearDownCallCount.addAndGet(1); + } + @ProcessElement public void processElement(ProcessContext c, @StateId("int") ValueState state) throws Exception { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java index 6771e9dbb713..21e4d8c55e70 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java @@ -65,6 +65,11 @@ public static void setFinalizeTracker(List finalizeTracker) { TestCountingSource.finalizeTracker = finalizeTracker; } + public static void resetStaticState() { + finalizeTracker = null; + thrown = false; + } + public TestCountingSource(int numMessagesPerShard) { this(numMessagesPerShard, 0, false, false, true); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java index 2eeaa06eb5eb..188466a50572 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java @@ -519,4 +519,43 @@ public void testAbort() throws Exception { Mockito.verify(o2, atLeastOnce()).abortReadLoop(); Mockito.verify(stateTracker).deactivate(); } + + @Test + public void testCloseAbortsOperations() throws Exception { + Operation o1 = Mockito.mock(Operation.class); + Operation o2 = Mockito.mock(Operation.class); + List operations = Arrays.asList(o1, o2); + ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest()); + Mockito.verifyNoMoreInteractions(stateTracker); + try (MapTaskExecutor executor = new MapTaskExecutor(operations, counterSet, stateTracker)) {} + + Mockito.verify(o1).abort(); + Mockito.verify(o2).abort(); + } + + @Test + public void testExceptionAndThenCloseAbortsJustOnce() throws Exception { + Operation o1 = Mockito.mock(Operation.class); + Operation o2 = Mockito.mock(Operation.class); + Mockito.doThrow(new Exception("in start")).when(o2).start(); + + ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest()); + MapTaskExecutor executor = new MapTaskExecutor(Arrays.asList(o1, o2), counterSet, stateTracker); + try { + executor.execute(); + fail("Should have thrown"); + } catch (Exception e) { + } + InOrder inOrder = Mockito.inOrder(o2, stateTracker); + inOrder.verify(stateTracker).activate(); + inOrder.verify(o2).start(); + inOrder.verify(o2).abort(); + inOrder.verify(stateTracker).deactivate(); + + // Order of o1 abort doesn't matter + Mockito.verify(o1).abort(); + Mockito.verifyNoMoreInteractions(o1); + // Closing after already closed should not call abort again. + executor.close(); + } }
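Reviewer note: the new MapTaskExecutorTest cases above check that operations are aborted at most once across an execute() failure and a later close(). A self-contained sketch of that close-once contract, with hypothetical names that are not the actual MapTaskExecutor API:

// Illustrative sketch only: a standalone model of the "abort exactly once" behavior the
// tests above verify; this is not the Beam worker implementation.
import java.util.ArrayList;
import java.util.List;

final class CloseOnceExecutor implements AutoCloseable {
  interface Op {
    void start() throws Exception;
    void finish() throws Exception;
    void abort() throws Exception;
  }

  private final List<Op> operations;
  private boolean closed = false;

  CloseOnceExecutor(List<Op> operations) {
    this.operations = new ArrayList<>(operations);
  }

  /** Starts operations in reverse order, finishes in forward order; aborts all on failure. */
  void execute() throws Exception {
    try {
      for (int i = operations.size() - 1; i >= 0; --i) {
        operations.get(i).start();
      }
      for (Op op : operations) {
        op.finish();
      }
    } catch (Exception exn) {
      try {
        closeInternal(); // abort everything exactly once
      } catch (Exception closeExn) {
        exn.addSuppressed(closeExn);
      }
      throw exn;
    }
  }

  private void closeInternal() throws Exception {
    if (closed) {
      return; // already aborted after a failure, so a later close() is a no-op
    }
    closed = true;
    Exception first = null;
    for (Op op : operations) {
      try {
        op.abort();
      } catch (Exception e) {
        if (first == null) {
          first = e;
        } else {
          first.addSuppressed(e);
        }
      }
    }
    if (first != null) {
      throw first;
    }
  }

  @Override
  public void close() {
    try {
      closeInternal();
    } catch (Exception e) {
      // The real executor logs and ignores failures at this point; this sketch swallows them.
    }
  }
}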