Commit 451ee3d

[Spark] Skip unused outputs of ParDo in SparkRunner (#33771) (#33772)
* [spark] Skip unused outputs of ParDo in SparkRunner
* Update runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
  Co-authored-by: Jan Lukavský <je.ik@seznam.cz>
* [spark] spotless
* [spark] Refactor according to review feedback
* [spark] Fix compile and spotless
1 parent b76e45a commit 451ee3d

8 files changed: 337 additions & 9 deletions

runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java

Lines changed: 3 additions & 3 deletions
@@ -116,7 +116,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
 
     // Filter out obsolete PCollections to only cache when absolutely necessary
     Map<TupleTag<?>, PCollection<?>> outputs =
-        skipObsoleteOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt);
+        skipUnconsumedOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt);
 
     if (outputs.size() > 1) {
       // In case of multiple outputs / tags, map each tag to a column by index.
@@ -206,12 +206,12 @@ public Dataset<WindowedValue<T>> resolve(
   }
 
   /**
-   * Filter out obsolete, unused output tags except for {@code mainTag}.
+   * Filter out output tags which are not consumed by any transform, except for {@code mainTag}.
    *
    * <p>This can help to avoid unnecessary caching in case of multiple outputs if only {@code
    * mainTag} is consumed.
    */
-  private Map<TupleTag<?>, PCollection<?>> skipObsoleteOutputs(
+  private Map<TupleTag<?>, PCollection<?>> skipUnconsumedOutputs(
      Map<TupleTag<?>, PCollection<?>> outputs,
      TupleTag<?> mainTag,
      TupleTagList otherTags,
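
For orientation, here is a minimal sketch of the pipeline shape this change targets (hypothetical tag names and DoFn, not taken from the commit): a ParDo declares an additional output tag that nothing downstream reads, so only the consumed main output needs to be considered when deciding what to cache.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class UnconsumedOutputExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Hypothetical tags: "main" is consumed downstream, "debug" is never read.
    TupleTag<String> mainTag = new TupleTag<String>("main") {};
    TupleTag<String> debugTag = new TupleTag<String>("debug") {};

    PCollectionTuple outputs =
        pipeline
            .apply(Create.of("foo", "bar"))
            .apply(
                ParDo.of(
                        new DoFn<String, String>() {
                          @ProcessElement
                          public void processElement(ProcessContext c) {
                            c.output(c.element()); // main output
                            c.output(debugTag, "seen " + c.element()); // side output nobody reads
                          }
                        })
                    .withOutputTags(mainTag, TupleTagList.of(debugTag)));

    // Only the main output has a consumer; the "debug" PCollection stays a leaf.
    // With this change, skipUnconsumedOutputs() drops the leaf tag, so the runner
    // can avoid caching the ParDo result just because a second tag exists.
    outputs.get(mainTag).apply(Count.globally());

    // pipeline.run() omitted: the pruning happens at translation time in the Spark runner.
  }
}
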
runners/spark/src/main/java/org/apache/beam/runners/spark/DependentTransformsVisitor.java

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark;
+
+import java.util.Map;
+import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ * Traverses the pipeline to populate information on how many {@link
+ * org.apache.beam.sdk.transforms.PTransform}s consume / depend on each {@link PCollection} in
+ * the pipeline.
+ */
+class DependentTransformsVisitor extends SparkRunner.Evaluator {
+
+  DependentTransformsVisitor(
+      SparkPipelineTranslator translator, EvaluationContext evaluationContext) {
+    super(translator, evaluationContext);
+  }
+
+  @Override
+  public void doVisitTransform(TransformHierarchy.Node node) {
+
+    for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+      ctxt.reportPCollectionConsumed(entry.getValue());
+    }
+
+    for (PCollection<?> pOut : node.getOutputs().values()) {
+      ctxt.reportPCollectionProduced(pOut);
+    }
+  }
+}

runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java

Lines changed: 8 additions & 2 deletions
@@ -214,6 +214,7 @@ public SparkPipelineResult run(final Pipeline pipeline) {
 
     // update the cache candidates
     updateCacheCandidates(pipeline, translator, evaluationContext);
+    updateDependentTransforms(pipeline, translator, evaluationContext);
 
     // update GBK candidates for memory optimized transform
     pipeline.traverseTopologically(new GroupByKeyVisitor(translator, evaluationContext));
@@ -275,8 +276,13 @@ static void detectTranslationMode(Pipeline pipeline, SparkPipelineOptions pipeli
   /** Evaluator that update/populate the cache candidates. */
   public static void updateCacheCandidates(
       Pipeline pipeline, SparkPipelineTranslator translator, EvaluationContext evaluationContext) {
-    CacheVisitor cacheVisitor = new CacheVisitor(translator, evaluationContext);
-    pipeline.traverseTopologically(cacheVisitor);
+    pipeline.traverseTopologically(new CacheVisitor(translator, evaluationContext));
+  }
+
+  /** Evaluator that update/populate information about dependent transforms for pCollections. */
+  public static void updateDependentTransforms(
+      Pipeline pipeline, SparkPipelineTranslator translator, EvaluationContext evaluationContext) {
+    pipeline.traverseTopologically(new DependentTransformsVisitor(translator, evaluationContext));
   }
 
   /** The translation mode of the Beam Pipeline. */

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java

Lines changed: 41 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
@@ -61,6 +62,7 @@ public class EvaluationContext {
   private final Map<PValue, Dataset> datasets = new LinkedHashMap<>();
   private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>();
   private final Set<Dataset> leaves = new LinkedHashSet<>();
+  private final Map<PCollection<?>, Integer> pCollectionConsumptionMap = new HashMap<>();
   private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
   private AppliedPTransform<?, ?, ?> currentTransform;
   private final SparkPCollectionView pviews = new SparkPCollectionView();
@@ -307,6 +309,45 @@ public <K, V> boolean isCandidateForGroupByKeyAndWindow(GroupByKey<K, V> transfo
     return groupByKeyCandidatesForMemoryOptimizedTranslation.containsKey(transform);
   }
 
+  /**
+   * Reports that given {@link PCollection} is consumed by a {@link PTransform} in the pipeline.
+   *
+   * @see #isLeaf(PCollection)
+   */
+  public void reportPCollectionConsumed(PCollection<?> pCollection) {
+    int count = this.pCollectionConsumptionMap.getOrDefault(pCollection, 0);
+    this.pCollectionConsumptionMap.put(pCollection, count + 1);
+  }
+
+  /**
+   * Reports that given {@link PCollection} is produced by a {@link PTransform} in the pipeline.
+   *
+   * @see #isLeaf(PCollection)
+   */
+  public void reportPCollectionProduced(PCollection<?> pCollection) {
+    this.pCollectionConsumptionMap.computeIfAbsent(pCollection, k -> 0);
+  }
+
+  /**
+   * Get the map of {@link PCollection} to the number of {@link PTransform}s consuming it.
+   *
+   * @return an unmodifiable view of the consumption counts
+   */
+  public Map<PCollection<?>, Integer> getPCollectionConsumptionMap() {
+    return Collections.unmodifiableMap(pCollectionConsumptionMap);
+  }
+
+  /**
+   * Check if given {@link PCollection} is a leaf or not. {@link PCollection} is a leaf when there
+   * is no other {@link PTransform} consuming it / depending on it.
+   *
+   * @param pCollection to be checked if it is a leaf
+   * @return true if pCollection is leaf; otherwise false
+   */
+  public boolean isLeaf(PCollection<?> pCollection) {
+    return this.pCollectionConsumptionMap.get(pCollection) == 0;
+  }
+
  <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
    @SuppressWarnings("unchecked")
    BoundedDataset<T> boundedDataset = (BoundedDataset<T>) datasets.get(pcollection);
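
As a side note, the counting scheme behind these new methods can be sketched in isolation (a hypothetical stand-in class, with strings in place of PCollections; not the commit's code): producing a collection registers it with zero consumers, every consuming transform increments the count, and a count that stays at zero marks a leaf whose caching can be skipped.

import java.util.HashMap;
import java.util.Map;

public class ConsumptionCountSketch {

  private final Map<String, Integer> consumptionCounts = new HashMap<>();

  void reportProduced(String pCollection) {
    // Producing a PCollection registers it with zero consumers.
    consumptionCounts.computeIfAbsent(pCollection, k -> 0);
  }

  void reportConsumed(String pCollection) {
    // Every consuming transform bumps the count by one.
    consumptionCounts.merge(pCollection, 1, Integer::sum);
  }

  boolean isLeaf(String pCollection) {
    // A PCollection nobody consumes is a leaf.
    return consumptionCounts.getOrDefault(pCollection, 0) == 0;
  }

  public static void main(String[] args) {
    ConsumptionCountSketch ctx = new ConsumptionCountSketch();
    ctx.reportProduced("mainOut");
    ctx.reportProduced("debugOut");
    ctx.reportConsumed("mainOut");              // one downstream transform reads mainOut
    System.out.println(ctx.isLeaf("mainOut"));  // false
    System.out.println(ctx.isLeaf("debugOut")); // true -> its caching can be skipped
  }
}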

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java

Lines changed: 43 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.spark.translation;
 
 import static org.apache.beam.runners.spark.translation.TranslationUtils.canAvoidRddSerialization;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
 
 import java.util.Arrays;
@@ -70,13 +71,15 @@
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -428,13 +431,14 @@ public void evaluate(
         Map<String, PCollectionView<?>> sideInputMapping =
             ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
 
+        TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();
         MultiDoFnFunction<InputT, OutputT> multiDoFnFunction =
             new MultiDoFnFunction<>(
                 metricsAccum,
                 stepName,
                 doFn,
                 context.getSerializableOptions(),
-                transform.getMainOutputTag(),
+                mainOutputTag,
                 transform.getAdditionalOutputTags().getAll(),
                 inputCoder,
                 outputCoders,
@@ -460,7 +464,13 @@ public void evaluate(
           all = inRDD.mapPartitionsToPair(multiDoFnFunction);
         }
 
-        Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
+        // Filter out obsolete PCollections to only cache when absolutely necessary
+        Map<TupleTag<?>, PCollection<?>> outputs =
+            skipUnconsumedOutputs(
+                context.getOutputs(transform),
+                mainOutputTag,
+                transform.getAdditionalOutputTags(),
+                context);
         if (hasMultipleOutputs(outputs)) {
           StorageLevel level = StorageLevel.fromString(context.storageLevel());
           if (canAvoidRddSerialization(level)) {
@@ -498,6 +508,37 @@ private boolean hasMultipleOutputs(Map<TupleTag<?>, PCollection<?>> outputs) {
         return outputs.size() > 1;
       }
 
+      /**
+       * Filter out output tags which are not consumed by any transform, except for {@code mainTag}.
+       *
+       * <p>This can help to avoid unnecessary caching in case of multiple outputs if only {@code
+       * mainTag} is consumed.
+       */
+      private Map<TupleTag<?>, PCollection<?>> skipUnconsumedOutputs(
+          Map<TupleTag<?>, PCollection<?>> outputs,
+          TupleTag<?> mainTag,
+          TupleTagList otherTags,
+          EvaluationContext cxt) {
+        switch (outputs.size()) {
+          case 1:
+            return outputs; // always keep main output
+          case 2:
+            TupleTag<?> otherTag = otherTags.get(0);
+            return cxt.isLeaf(checkStateNotNull(outputs.get(otherTag)))
+                ? Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag)))
+                : outputs;
+          default:
+            Map<TupleTag<?>, PCollection<?>> filtered =
+                Maps.newHashMapWithExpectedSize(outputs.size());
+            for (Map.Entry<TupleTag<?>, PCollection<?>> e : outputs.entrySet()) {
+              if (e.getKey().equals(mainTag) || !cxt.isLeaf(e.getValue())) {
+                filtered.put(e.getKey(), e.getValue());
+              }
+            }
+            return filtered;
+        }
+      }
+
       @Override
       public String toNativeString() {
         return "mapPartitions(new <fn>())";
runners/spark/src/test/java/org/apache/beam/runners/spark/DependentTransformsVisitorTest.java

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.List;
+import java.util.Objects;
+import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.translation.TransformTranslator;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+/** Tests of {@link DependentTransformsVisitor}. */
+public class DependentTransformsVisitorTest {
+
+  @ClassRule public static SparkContextRule contextRule = new SparkContextRule();
+
+  @Rule public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+  @Test
+  public void testCountDependentTransformsOnApplyAndSideInputs() {
+    SparkPipelineOptions options = contextRule.createPipelineOptions();
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));
+
+    // First use of pCollection.
+    PCollection<Long> leaf1 = pCollection.apply(Count.globally());
+    // Second use of pCollection.
+    PCollectionView<List<String>> view = pCollection.apply("yyy", View.asList());
+
+    PCollection<String> leaf2 =
+        pipeline
+            .apply(Create.of("foo", "baz"))
+            .apply(
+                ParDo.of(
+                        new DoFn<String, String>() {
+                          @ProcessElement
+                          public void processElement(ProcessContext processContext) {
+                            if (processContext.sideInput(view).contains(processContext.element())) {
+                              processContext.output(processContext.element());
+                            }
+                          }
+                        })
+                    .withSideInputs(view));
+
+    EvaluationContext ctxt =
+        new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
+    TransformTranslator.Translator translator = new TransformTranslator.Translator();
+    pipeline.traverseTopologically(new DependentTransformsVisitor(translator, ctxt));
+
+    assertEquals(2, ctxt.getPCollectionConsumptionMap().get(pCollection).intValue());
+    assertEquals(0, ctxt.getPCollectionConsumptionMap().get(leaf1).intValue());
+    assertEquals(0, ctxt.getPCollectionConsumptionMap().get(leaf2).intValue());
+    assertEquals(2, ctxt.getPCollectionConsumptionMap().get(view.getPCollection()).intValue());
+  }
+
+  @Test
+  public void testCountDependentTransformsOnSideOutputs() {
+    SparkPipelineOptions options = contextRule.createPipelineOptions();
+    Pipeline pipeline = Pipeline.create(options);
+
+    TupleTag<String> passOutTag = new TupleTag<>("passOut");
+    TupleTag<Long> lettersCountOutTag = new TupleTag<>("lettersOut");
+    TupleTag<Long> wordCountOutTag = new TupleTag<>("wordsOut");
+
+    PCollectionTuple result =
+        pipeline
+            .apply(Create.of("foo", "baz"))
+            .apply(
+                ParDo.of(
+                        new DoFn<String, String>() {
+                          @ProcessElement
+                          public void processElement(ProcessContext processContext) {
+                            String element = processContext.element();
+                            processContext.output(element);
+                            processContext.output(
+                                lettersCountOutTag,
+                                (long) Objects.requireNonNull(element).length());
+                            processContext.output(wordCountOutTag, 1L);
+                          }
+                        })
+                    .withOutputTags(
+                        passOutTag,
+                        TupleTagList.of(Lists.newArrayList(lettersCountOutTag, wordCountOutTag))));
+
+    // Consume the main output and the word-count side output; leave the letter-count side output alone.
+    result.get(wordCountOutTag).setCoder(VarLongCoder.of()).apply(Sum.longsGlobally());
+    result.get(lettersCountOutTag).setCoder(VarLongCoder.of());
+    result
+        .get(passOutTag)
+        .apply(
+            ParDo.of(
+                new DoFn<String, String>() {
+                  @ProcessElement
+                  public void processElement(ProcessContext processContext) {
+                    // do nothing
+                  }
+                }));
+
+    EvaluationContext ctxt =
+        new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
+    TransformTranslator.Translator translator = new TransformTranslator.Translator();
+    pipeline.traverseTopologically(new DependentTransformsVisitor(translator, ctxt));
+
+    assertEquals(1, ctxt.getPCollectionConsumptionMap().get(result.get(passOutTag)).intValue());
+    assertEquals(
+        1, ctxt.getPCollectionConsumptionMap().get(result.get(wordCountOutTag)).intValue());
+    assertEquals(
+        0, ctxt.getPCollectionConsumptionMap().get(result.get(lettersCountOutTag)).intValue());
+  }
+}

0 commit comments
