Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.Timer;
import org.apache.beam.sdk.util.construction.graph.PipelineNode;
import org.apache.beam.sdk.util.construction.graph.QueryablePipeline;
import org.apache.beam.sdk.values.WindowedValue;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.TextFormat;
Expand Down Expand Up @@ -161,6 +163,17 @@ public class ProcessBundleHandler {
@VisibleForTesting final BundleProcessorCache bundleProcessorCache;
private final Set<String> runnerCapabilities;
private final @Nullable DataSampler dataSampler;
private final LoadingCache<String, TopologyCacheEntry> topologicalOrderCache;

private static class TopologyCacheEntry {
final ProcessBundleDescriptor descriptor;
final ImmutableList<String> order;

TopologyCacheEntry(ProcessBundleDescriptor descriptor, ImmutableList<String> order) {
this.descriptor = descriptor;
this.order = order;
}
}

public ProcessBundleHandler(
PipelineOptions options,
Expand Down Expand Up @@ -220,6 +233,43 @@ public ProcessBundleHandler(
this.processWideCache = processWideCache;
this.bundleProcessorCache = bundleProcessorCache;
this.dataSampler = dataSampler;

// Initialize topological-order cache. Use same timeout idiom as BundleProcessorCache.
CacheBuilder<Object, Object> topoBuilder = CacheBuilder.newBuilder();
Duration topoTimeout = options.as(SdkHarnessOptions.class).getBundleProcessorCacheTimeout();
if (topoTimeout.compareTo(Duration.ZERO) > 0) {
topoBuilder = topoBuilder.expireAfterAccess(topoTimeout);
}
this.topologicalOrderCache =
topoBuilder.build(
new CacheLoader<String, TopologyCacheEntry>() {
@Override
public TopologyCacheEntry load(String descriptorId) throws Exception {
ProcessBundleDescriptor desc = fnApiRegistry.apply(descriptorId);
RunnerApi.Components comps =
RunnerApi.Components.newBuilder()
.putAllCoders(desc.getCodersMap())
.putAllPcollections(desc.getPcollectionsMap())
.putAllTransforms(desc.getTransformsMap())
.putAllWindowingStrategies(desc.getWindowingStrategiesMap())
.build();
QueryablePipeline qp =
QueryablePipeline.forTransforms(desc.getTransformsMap().keySet(), comps);
ImmutableList.Builder<String> ids = ImmutableList.builder();
for (PipelineNode.PTransformNode node : qp.getTopologicallyOrderedTransforms()) {
ids.add(node.getId());
}
ImmutableList<String> topo = ids.build();
// Treat incomplete topo as a cycle/error so loader fails and caller falls back.
if (topo.size() != desc.getTransformsMap().size()) {
throw new IllegalStateException(
String.format(
"Topological ordering incomplete for descriptor %s: %d of %d",
descriptorId, topo.size(), desc.getTransformsMap().size()));
}
return new TopologyCacheEntry(desc, topo);
}
});
}

private void addRunnerAndConsumersForPTransformRecursively(
Expand Down Expand Up @@ -770,23 +820,50 @@ public void discard() {

private BundleProcessor createBundleProcessor(
String bundleId, ProcessBundleRequest processBundleRequest) throws IOException {
ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId);

SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = HashMultimap.create();
// Prepare per-bundle state trackers / registries first.
BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar =
new BundleProgressReporter.InMemory();
MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle =
new MetricsEnvironmentStateForBundle();
ExecutionStateTracker stateTracker = executionStateSampler.create();
bundleProgressReporterAndRegistrar.register(stateTracker);
HashSet<String> processedPTransformIds = new HashSet<>();

// Resolve descriptor + transform order from cache (descriptor+topo cached together) or
// fall back to single-fetch descriptor + descriptor order.
ProcessBundleDescriptor bundleDescriptor;
Iterable<String> transformIds;
try {
TopologyCacheEntry entry = topologicalOrderCache.get(bundleId);
bundleDescriptor = entry.descriptor;
transformIds = entry.order;
} catch (Exception e) {
LOG.warn(
"Topological ordering failed for descriptor {}. Falling back to descriptor order. Cause: {}",
bundleId,
e.toString());
// Fall back: fetch descriptor once and use descriptor-order iteration.
bundleDescriptor = fnApiRegistry.apply(bundleId);
transformIds = bundleDescriptor.getTransformsMap().keySet();
}

// Build a multimap of PCollection ids to PTransform ids which consume said PCollections
SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = HashMultimap.create();
for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {
for (String pCollectionId : entry.getValue().getInputsMap().values()) {
pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey());
}
}

// Now that bundleDescriptor is known, construct the consumer registry.
PCollectionConsumerRegistry pCollectionConsumerRegistry =
new PCollectionConsumerRegistry(
stateTracker,
shortIds,
bundleProgressReporterAndRegistrar,
bundleDescriptor,
dataSampler);
HashSet<String> processedPTransformIds = new HashSet<>();

PTransformFunctionRegistry startFunctionRegistry =
new PTransformFunctionRegistry(shortIds, stateTracker, Urns.START_BUNDLE_MSECS);
Expand All @@ -795,13 +872,6 @@ private BundleProcessor createBundleProcessor(
List<ThrowingRunnable> resetFunctions = new ArrayList<>();
List<ThrowingRunnable> tearDownFunctions = new ArrayList<>();

// Build a multimap of PCollection ids to PTransform ids which consume said PCollections
for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {
for (String pCollectionId : entry.getValue().getInputsMap().values()) {
pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey());
}
}

// Instantiate a State API call handler depending on whether a State ApiServiceDescriptor was
// specified.
HandleStateCallsForBundle beamFnStateClient;
Expand Down Expand Up @@ -843,32 +913,31 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
bundleFinalizationCallbackRegistrations,
runnerCapabilities);

// Create a BeamFnStateClient
for (Map.Entry<String, PTransform> entry : bundleDescriptor.getTransformsMap().entrySet()) {

// Skip anything which isn't a root.
// Also force data output transforms to be unconditionally instantiated (see BEAM-10450).
// TODO: Remove source as a root and have it be triggered by the Runner.
if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn())
&& !DATA_OUTPUT_URN.equals(entry.getValue().getSpec().getUrn())
&& !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn())
&& !PTransformTranslation.READ_TRANSFORM_URN.equals(
entry.getValue().getSpec().getUrn())) {
// Build components once for this descriptor.
final RunnerApi.Components components =
RunnerApi.Components.newBuilder()
.putAllCoders(bundleDescriptor.getCodersMap())
.putAllPcollections(bundleDescriptor.getPcollectionsMap())
.putAllWindowingStrategies(bundleDescriptor.getWindowingStrategiesMap())
.build();

for (String transformId : transformIds) {
PTransform pTransform = bundleDescriptor.getTransformsMap().get(transformId);
if (pTransform == null) {
continue; // defensive
}
if (!DATA_INPUT_URN.equals(pTransform.getSpec().getUrn())
&& !DATA_OUTPUT_URN.equals(pTransform.getSpec().getUrn())
&& !JAVA_SOURCE_URN.equals(pTransform.getSpec().getUrn())
&& !PTransformTranslation.READ_TRANSFORM_URN.equals(pTransform.getSpec().getUrn())) {
continue;
}

RunnerApi.Components components =
RunnerApi.Components.newBuilder()
.putAllCoders(bundleDescriptor.getCodersMap())
.putAllPcollections(bundleDescriptor.getPcollectionsMap())
.putAllWindowingStrategies(bundleDescriptor.getWindowingStrategiesMap())
.build();

ProcessBundleDescriptor finalBundleDescriptor = bundleDescriptor;
addRunnerAndConsumersForPTransformRecursively(
beamFnStateClient,
beamFnDataClient,
entry.getKey(),
entry.getValue(),
transformId,
pTransform,
bundleProcessor::getInstructionId,
bundleProcessor::getCacheTokens,
bundleProcessor::getBundleCache,
Expand All @@ -890,7 +959,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
bundleProcessor.getInboundDataEndpoints().add(dataEndpoint);
},
(timerEndpoint) -> {
if (!bundleDescriptor.hasTimerApiServiceDescriptor()) {
if (!finalBundleDescriptor.hasTimerApiServiceDescriptor()) {
throw new IllegalStateException(
String.format(
"Timers are unsupported because the "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.BeamFnDataReadRunner;
import org.apache.beam.fn.harness.Cache;
Expand Down Expand Up @@ -146,6 +147,8 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
Expand Down Expand Up @@ -1929,4 +1932,157 @@ public void testTimerRegistrationsFailIfNoTimerApiServiceDescriptorSpecified() t
private static void throwException() {
throw new IllegalStateException("TestException");
}

@Test
public void testTopologicalOrderRespectsDependency() throws Exception {
// Build a descriptor A -> B -> C
ProcessBundleDescriptor processBundleDescriptor =
ProcessBundleDescriptor.newBuilder()
.putTransforms(
"A",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
.putOutputs("A-out", "A-out-pc")
.build())
.putTransforms(
"B",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
.putInputs("B-in", "A-out-pc")
.putOutputs("B-out", "B-out-pc")
.build())
.putTransforms(
"C",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
.putInputs("C-in", "B-out-pc")
.build())
.putPcollections("A-out-pc", PCollection.getDefaultInstance())
.putPcollections("B-out-pc", PCollection.getDefaultInstance())
.build();

Map<String, ProcessBundleDescriptor> registry =
ImmutableMap.of("chain", processBundleDescriptor);
final AtomicInteger calls = new AtomicInteger(0);
Function<String, ProcessBundleDescriptor> fnApiRegistry =
id -> {
calls.incrementAndGet();
return registry.get(id);
};

ProcessBundleHandler handler =
new ProcessBundleHandler(
PipelineOptionsFactory.create(),
Collections.emptySet(),
fnApiRegistry::apply,
beamFnDataClient,
null,
null,
new ShortIdMap(),
executionStateSampler,
ImmutableMap.of(DATA_INPUT_URN, (context) -> {}, DATA_OUTPUT_URN, (context) -> {}),
Caches.noop(),
new BundleProcessorCache(Duration.ZERO),
null);

// Access the private topologicalOrderCache and verify ordering
java.lang.reflect.Field f =
ProcessBundleHandler.class.getDeclaredField("topologicalOrderCache");
f.setAccessible(true);
@SuppressWarnings("unchecked")
LoadingCache<String, ?> cache = (LoadingCache<String, ?>) f.get(handler);

// Cache holds a TopologyCacheEntry; extract its 'order' field reflectively.
Object entry = cache.get("chain");
java.lang.reflect.Field orderField = entry.getClass().getDeclaredField("order");
orderField.setAccessible(true);
@SuppressWarnings("unchecked")
ImmutableList<String> topo = (ImmutableList<String>) orderField.get(entry);

// Cover all transforms
assertEquals(processBundleDescriptor.getTransformsMap().size(), topo.size());
// Ensure producer -> consumer ordering: A before B before C
assertTrue(topo.indexOf("A") >= 0);
assertTrue(topo.indexOf("B") >= 0);
assertTrue(topo.indexOf("C") >= 0);
assertTrue(topo.indexOf("A") < topo.indexOf("B"));
assertTrue(topo.indexOf("B") < topo.indexOf("C"));
// Loader should have invoked fnApiRegistry exactly once.
assertEquals(1, calls.get());
}

@Test
public void testProcessBundleCreatesRunnersForAllTransformsUsingTopologicalCache()
throws Exception {
// Build a descriptor A -> B -> C
ProcessBundleDescriptor processBundleDescriptor =
ProcessBundleDescriptor.newBuilder()
.putTransforms(
"A",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
.putOutputs("A-out", "A-out-pc")
.build())
.putTransforms(
"B",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
.putInputs("B-in", "A-out-pc")
.putOutputs("B-out", "B-out-pc")
.build())
.putTransforms(
"C",
PTransform.newBuilder()
.setSpec(FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
.putInputs("C-in", "B-out-pc")
.build())
.putPcollections("A-out-pc", PCollection.getDefaultInstance())
.putPcollections("B-out-pc", PCollection.getDefaultInstance())
.build();

Map<String, ProcessBundleDescriptor> registry =
ImmutableMap.of("chain", processBundleDescriptor);
final AtomicInteger calls = new AtomicInteger(0);
Function<String, ProcessBundleDescriptor> fnApiRegistry =
id -> {
calls.incrementAndGet();
return registry.get(id);
};

// Record which transforms had runners created.
final List<String> transformsProcessed = new ArrayList<>();
PTransformRunnerFactory recorderFactory =
(context) -> transformsProcessed.add(context.getPTransformId());

ProcessBundleHandler handler =
new ProcessBundleHandler(
PipelineOptionsFactory.create(),
Collections.emptySet(),
fnApiRegistry::apply,
beamFnDataClient,
null,
null,
new ShortIdMap(),
executionStateSampler,
ImmutableMap.of(DATA_INPUT_URN, recorderFactory, DATA_OUTPUT_URN, recorderFactory),
Caches.noop(),
new BundleProcessorCache(Duration.ZERO),
null);

// processBundle should cause creation of runners for all transforms
handler.processBundle(
InstructionRequest.newBuilder()
.setInstructionId("instr-chain")
.setProcessBundle(
ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("chain"))
.build());

// All transforms should have had their runner factory invoked.
assertEquals(processBundleDescriptor.getTransformsMap().size(), transformsProcessed.size());
assertTrue(transformsProcessed.contains("A"));
assertTrue(transformsProcessed.contains("B"));
assertTrue(transformsProcessed.contains("C"));
// fnApiRegistry should have been consulted exactly once for the descriptor during cache load.
assertEquals(1, calls.get());
}
}
Loading