diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index fe422939e535..9983bc6cd3b0 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -86,6 +86,8 @@
 import org.apache.beam.sdk.util.construction.BeamUrns;
 import org.apache.beam.sdk.util.construction.PTransformTranslation;
 import org.apache.beam.sdk.util.construction.Timer;
+import org.apache.beam.sdk.util.construction.graph.PipelineNode;
+import org.apache.beam.sdk.util.construction.graph.QueryablePipeline;
 import org.apache.beam.sdk.values.WindowedValue;
 import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.TextFormat;
@@ -161,6 +163,17 @@ public class ProcessBundleHandler {
   @VisibleForTesting final BundleProcessorCache bundleProcessorCache;
   private final Set<String> runnerCapabilities;
   private final @Nullable DataSampler dataSampler;
+  private final LoadingCache<String, TopologyCacheEntry> topologicalOrderCache;
+
+  private static class TopologyCacheEntry {
+    final ProcessBundleDescriptor descriptor;
+    final ImmutableList<String> order;
+
+    TopologyCacheEntry(ProcessBundleDescriptor descriptor, ImmutableList<String> order) {
+      this.descriptor = descriptor;
+      this.order = order;
+    }
+  }
 
   public ProcessBundleHandler(
       PipelineOptions options,
@@ -220,6 +233,43 @@ public ProcessBundleHandler(
     this.processWideCache = processWideCache;
     this.bundleProcessorCache = bundleProcessorCache;
     this.dataSampler = dataSampler;
+
+    // Initialize topological-order cache. Use same timeout idiom as BundleProcessorCache.
+    CacheBuilder<Object, Object> topoBuilder = CacheBuilder.newBuilder();
+    Duration topoTimeout = options.as(SdkHarnessOptions.class).getBundleProcessorCacheTimeout();
+    if (topoTimeout.compareTo(Duration.ZERO) > 0) {
+      topoBuilder = topoBuilder.expireAfterAccess(topoTimeout);
+    }
+    this.topologicalOrderCache =
+        topoBuilder.build(
+            new CacheLoader<String, TopologyCacheEntry>() {
+              @Override
+              public TopologyCacheEntry load(String descriptorId) throws Exception {
+                ProcessBundleDescriptor desc = fnApiRegistry.apply(descriptorId);
+                RunnerApi.Components comps =
+                    RunnerApi.Components.newBuilder()
+                        .putAllCoders(desc.getCodersMap())
+                        .putAllPcollections(desc.getPcollectionsMap())
+                        .putAllTransforms(desc.getTransformsMap())
+                        .putAllWindowingStrategies(desc.getWindowingStrategiesMap())
+                        .build();
+                QueryablePipeline qp =
+                    QueryablePipeline.forTransforms(desc.getTransformsMap().keySet(), comps);
+                ImmutableList.Builder<String> ids = ImmutableList.builder();
+                for (PipelineNode.PTransformNode node : qp.getTopologicallyOrderedTransforms()) {
+                  ids.add(node.getId());
+                }
+                ImmutableList<String> topo = ids.build();
+                // Treat incomplete topo as a cycle/error so loader fails and caller falls back.
+                if (topo.size() != desc.getTransformsMap().size()) {
+                  throw new IllegalStateException(
+                      String.format(
+                          "Topological ordering incomplete for descriptor %s: %d of %d",
+                          descriptorId, topo.size(), desc.getTransformsMap().size()));
+                }
+                return new TopologyCacheEntry(desc, topo);
+              }
+            });
   }
 
   private void addRunnerAndConsumersForPTransformRecursively(
@@ -770,15 +820,43 @@ public void discard() {
   private BundleProcessor createBundleProcessor(
       String bundleId, ProcessBundleRequest processBundleRequest) throws IOException {
-    ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId);
-    SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = HashMultimap.create();
+    // Prepare per-bundle state trackers / registries first.
     BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar =
         new BundleProgressReporter.InMemory();
     MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle =
         new MetricsEnvironmentStateForBundle();
     ExecutionStateTracker stateTracker = executionStateSampler.create();
     bundleProgressReporterAndRegistrar.register(stateTracker);
+    HashSet<String> processedPTransformIds = new HashSet<>();
+
+    // Resolve descriptor + transform order from cache (descriptor+topo cached together) or
+    // fall back to single-fetch descriptor + descriptor order.
+    ProcessBundleDescriptor bundleDescriptor;
+    Iterable<String> transformIds;
+    try {
+      TopologyCacheEntry entry = topologicalOrderCache.get(bundleId);
+      bundleDescriptor = entry.descriptor;
+      transformIds = entry.order;
+    } catch (Exception e) {
+      LOG.warn(
+          "Topological ordering failed for descriptor {}. Falling back to descriptor order. Cause: {}",
+          bundleId,
+          e.toString());
+      // Fall back: fetch descriptor once and use descriptor-order iteration.
+      bundleDescriptor = fnApiRegistry.apply(bundleId);
+      transformIds = bundleDescriptor.getTransformsMap().keySet();
+    }
+
+    // Build a multimap of PCollection ids to PTransform ids which consume said PCollections
+    SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = HashMultimap.create();
+    for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {
+      for (String pCollectionId : entry.getValue().getInputsMap().values()) {
+        pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey());
+      }
+    }
+
+    // Now that bundleDescriptor is known, construct the consumer registry.
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         new PCollectionConsumerRegistry(
             stateTracker,
@@ -786,7 +864,6 @@ private BundleProcessor createBundleProcessor(
             bundleProgressReporterAndRegistrar,
             bundleDescriptor,
             dataSampler);
-    HashSet<String> processedPTransformIds = new HashSet<>();
 
     PTransformFunctionRegistry startFunctionRegistry =
         new PTransformFunctionRegistry(shortIds, stateTracker, Urns.START_BUNDLE_MSECS);
@@ -795,13 +872,6 @@ private BundleProcessor createBundleProcessor(
     List<ThrowingRunnable> resetFunctions = new ArrayList<>();
     List<ThrowingRunnable> tearDownFunctions = new ArrayList<>();
 
-    // Build a multimap of PCollection ids to PTransform ids which consume said PCollections
-    for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {
-      for (String pCollectionId : entry.getValue().getInputsMap().values()) {
-        pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey());
-      }
-    }
-
     // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor was
     // specified.
     HandleStateCallsForBundle beamFnStateClient;
@@ -843,32 +913,31 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
             bundleFinalizationCallbackRegistrations,
             runnerCapabilities);
 
-    // Create a BeamFnStateClient
-    for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {
-
-      // Skip anything which isn't a root.
-      // Also force data output transforms to be unconditionally instantiated (see BEAM-10450).
-      // TODO: Remove source as a root and have it be triggered by the Runner.
-      if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn())
-          && !DATA_OUTPUT_URN.equals(entry.getValue().getSpec().getUrn())
-          && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn())
-          && !PTransformTranslation.READ_TRANSFORM_URN.equals(
-              entry.getValue().getSpec().getUrn())) {
+    // Build components once for this descriptor.
+    final RunnerApi.Components components =
+        RunnerApi.Components.newBuilder()
+            .putAllCoders(bundleDescriptor.getCodersMap())
+            .putAllPcollections(bundleDescriptor.getPcollectionsMap())
+            .putAllWindowingStrategies(bundleDescriptor.getWindowingStrategiesMap())
+            .build();
+
+    for (String transformId : transformIds) {
+      PTransform pTransform = bundleDescriptor.getTransformsMap().get(transformId);
+      if (pTransform == null) {
+        continue; // defensive
+      }
+      if (!DATA_INPUT_URN.equals(pTransform.getSpec().getUrn())
+          && !DATA_OUTPUT_URN.equals(pTransform.getSpec().getUrn())
+          && !JAVA_SOURCE_URN.equals(pTransform.getSpec().getUrn())
+          && !PTransformTranslation.READ_TRANSFORM_URN.equals(pTransform.getSpec().getUrn())) {
         continue;
       }
-
-      RunnerApi.Components components =
-          RunnerApi.Components.newBuilder()
-              .putAllCoders(bundleDescriptor.getCodersMap())
-              .putAllPcollections(bundleDescriptor.getPcollectionsMap())
-              .putAllWindowingStrategies(bundleDescriptor.getWindowingStrategiesMap())
-              .build();
-
+      ProcessBundleDescriptor finalBundleDescriptor = bundleDescriptor;
       addRunnerAndConsumersForPTransformRecursively(
           beamFnStateClient,
           beamFnDataClient,
-          entry.getKey(),
-          entry.getValue(),
+          transformId,
+          pTransform,
           bundleProcessor::getInstructionId,
           bundleProcessor::getCacheTokens,
           bundleProcessor::getBundleCache,
@@ -890,7 +959,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
             bundleProcessor.getInboundDataEndpoints().add(dataEndpoint);
           },
           (timerEndpoint) -> {
-            if (!bundleDescriptor.hasTimerApiServiceDescriptor()) {
+            if (!finalBundleDescriptor.hasTimerApiServiceDescriptor()) {
               throw new IllegalStateException(
                   String.format(
                       "Timers are unsupported because the "
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index a7a62571e38e..1b0aa3a2726a 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -66,6 +66,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.BeamFnDataReadRunner;
 import org.apache.beam.fn.harness.Cache;
@@ -146,6 +147,8 @@
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
@@ -1929,4 +1932,157 @@ public void testTimerRegistrationsFailIfNoTimerApiServiceDescriptorSpecified() t
   private static void throwException() {
     throw new IllegalStateException("TestException");
   }
+
+  @Test
+  public void testTopologicalOrderRespectsDependency() throws Exception {
+    // Build a descriptor A -> B -> C
+    ProcessBundleDescriptor processBundleDescriptor =
+        ProcessBundleDescriptor.newBuilder()
+            .putTransforms(
+                "A",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
+                    .putOutputs("A-out", "A-out-pc")
+                    .build())
+            .putTransforms(
+                "B",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
+                    .putInputs("B-in", "A-out-pc")
+                    .putOutputs("B-out", "B-out-pc")
+                    .build())
+            .putTransforms(
+                "C",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
+                    .putInputs("C-in", "B-out-pc")
+                    .build())
+            .putPcollections("A-out-pc", PCollection.getDefaultInstance())
+            .putPcollections("B-out-pc", PCollection.getDefaultInstance())
+            .build();
+
+    Map<String, ProcessBundleDescriptor> registry =
+        ImmutableMap.of("chain", processBundleDescriptor);
+    final AtomicInteger calls = new AtomicInteger(0);
+    Function<String, ProcessBundleDescriptor> fnApiRegistry =
+        id -> {
+          calls.incrementAndGet();
+          return registry.get(id);
+        };
+
+    ProcessBundleHandler handler =
+        new ProcessBundleHandler(
+            PipelineOptionsFactory.create(),
+            Collections.emptySet(),
+            fnApiRegistry::apply,
+            beamFnDataClient,
+            null,
+            null,
+            new ShortIdMap(),
+            executionStateSampler,
+            ImmutableMap.of(DATA_INPUT_URN, (context) -> {}, DATA_OUTPUT_URN, (context) -> {}),
+            Caches.noop(),
+            new BundleProcessorCache(Duration.ZERO),
+            null);
+
+    // Access the private topologicalOrderCache and verify ordering
+    java.lang.reflect.Field f =
+        ProcessBundleHandler.class.getDeclaredField("topologicalOrderCache");
+    f.setAccessible(true);
+    @SuppressWarnings("unchecked")
+    LoadingCache<String, ?> cache = (LoadingCache<String, ?>) f.get(handler);
+
+    // Cache holds a TopologyCacheEntry; extract its 'order' field reflectively.
+    Object entry = cache.get("chain");
+    java.lang.reflect.Field orderField = entry.getClass().getDeclaredField("order");
+    orderField.setAccessible(true);
+    @SuppressWarnings("unchecked")
+    ImmutableList<String> topo = (ImmutableList<String>) orderField.get(entry);
+
+    // Cover all transforms
+    assertEquals(processBundleDescriptor.getTransformsMap().size(), topo.size());
+    // Ensure producer -> consumer ordering: A before B before C
+    assertTrue(topo.indexOf("A") >= 0);
+    assertTrue(topo.indexOf("B") >= 0);
+    assertTrue(topo.indexOf("C") >= 0);
+    assertTrue(topo.indexOf("A") < topo.indexOf("B"));
+    assertTrue(topo.indexOf("B") < topo.indexOf("C"));
+    // Loader should have invoked fnApiRegistry exactly once.
+    assertEquals(1, calls.get());
+  }
+
+  @Test
+  public void testProcessBundleCreatesRunnersForAllTransformsUsingTopologicalCache()
+      throws Exception {
+    // Build a descriptor A -> B -> C
+    ProcessBundleDescriptor processBundleDescriptor =
+        ProcessBundleDescriptor.newBuilder()
+            .putTransforms(
+                "A",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
+                    .putOutputs("A-out", "A-out-pc")
+                    .build())
+            .putTransforms(
+                "B",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
+                    .putInputs("B-in", "A-out-pc")
+                    .putOutputs("B-out", "B-out-pc")
+                    .build())
+            .putTransforms(
+                "C",
+                PTransform.newBuilder()
+                    .setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
+                    .putInputs("C-in", "B-out-pc")
+                    .build())
+            .putPcollections("A-out-pc", PCollection.getDefaultInstance())
+            .putPcollections("B-out-pc", PCollection.getDefaultInstance())
+            .build();
+
+    Map<String, ProcessBundleDescriptor> registry =
+        ImmutableMap.of("chain", processBundleDescriptor);
+    final AtomicInteger calls = new AtomicInteger(0);
+    Function<String, ProcessBundleDescriptor> fnApiRegistry =
+        id -> {
+          calls.incrementAndGet();
+          return registry.get(id);
+        };
+
+    // Record which transforms had runners created.
+    final List<String> transformsProcessed = new ArrayList<>();
+    PTransformRunnerFactory recorderFactory =
+        (context) -> transformsProcessed.add(context.getPTransformId());
+
+    ProcessBundleHandler handler =
+        new ProcessBundleHandler(
+            PipelineOptionsFactory.create(),
+            Collections.emptySet(),
+            fnApiRegistry::apply,
+            beamFnDataClient,
+            null,
+            null,
+            new ShortIdMap(),
+            executionStateSampler,
+            ImmutableMap.of(DATA_INPUT_URN, recorderFactory, DATA_OUTPUT_URN, recorderFactory),
+            Caches.noop(),
+            new BundleProcessorCache(Duration.ZERO),
+            null);
+
+    // processBundle should cause creation of runners for all transforms
+    handler.processBundle(
+        InstructionRequest.newBuilder()
+            .setInstructionId("instr-chain")
+            .setProcessBundle(
+                ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("chain"))
+            .build());
+
+    // All transforms should have had their runner factory invoked.
+    assertEquals(processBundleDescriptor.getTransformsMap().size(), transformsProcessed.size());
+    assertTrue(transformsProcessed.contains("A"));
+    assertTrue(transformsProcessed.contains("B"));
+    assertTrue(transformsProcessed.contains("C"));
+    // fnApiRegistry should have been consulted exactly once for the descriptor during cache load.
+    assertEquals(1, calls.get());
+  }
 }