Skip to content

Commit 0b6a93f

Browse files
authored
[#37994] Fix NullPointerException in Spark Runner with multiple outputs and serialization (#38011)
* [#37994] Fix NullPointerException in Spark Runner with multiple outputs and serialization
* trigger test.
1 parent 4da3e55 commit 0b6a93f

File tree

6 files changed

+87
-3
lines changed

6 files changed

+87
-3
lines changed

.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
"https://github.com/apache/beam/pull/34155": "noting that PR #34155 should run this test",
1414
"https://github.com/apache/beam/pull/34560": "noting that PR #34560 should run this test",
1515
"https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
16-
"https://github.com/apache/beam/pull/35316": "noting that PR #35316 should run this test"
16+
"https://github.com/apache/beam/pull/35316": "noting that PR #35316 should run this test",
17+
"https://github.com/apache/beam/pull/38011": "noting that PR #38011 should run this test"
1718
}

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.HashMap;
2828
import java.util.Iterator;
2929
import java.util.Map;
30+
import java.util.Objects;
3031
import org.apache.beam.runners.core.SystemReduceFn;
3132
import org.apache.beam.runners.spark.SparkPipelineOptions;
3233
import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -486,6 +487,9 @@ public void evaluate(
486487
TranslationUtils.getTupleTagCoders(outputs);
487488
all =
488489
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
490+
.filter(
491+
Objects::nonNull) // skip nulls to save on encoding, nulls are tags that are
492+
// not read
489493
.persist(level)
490494
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
491495
}

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,14 @@ public static Map<TupleTag<?>, Coder<WindowedValue<?>>> getTupleTagCoders(
445445
return tuple2 -> {
446446
TupleTag<?> tupleTag = tuple2._1;
447447
WindowedValue<?> windowedValue = tuple2._2;
448-
return new Tuple2<>(
449-
tupleTag, ValueAndCoderLazySerializable.of(windowedValue, coderMap.get(tupleTag)));
448+
Coder<WindowedValue<?>> coder = coderMap.get(tupleTag);
449+
if (coder == null) {
450+
// there is no coder as this output is unconsumed and is not read anywhere, so coder is
451+
// pruned
452+
// from coderMap
453+
return null;
454+
}
455+
return new Tuple2<>(tupleTag, ValueAndCoderLazySerializable.of(windowedValue, coder));
450456
};
451457
}
452458

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Iterator;
2727
import java.util.List;
2828
import java.util.Map;
29+
import java.util.Objects;
2930
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
3031
import org.apache.beam.runners.spark.coders.CoderHelpers;
3132
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
@@ -234,6 +235,9 @@ public void evaluate(
234235
TranslationUtils.getTupleTagCoders(outputs);
235236
all =
236237
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
238+
.filter(
239+
Objects
240+
::nonNull) // skip nulls to save on encoding, nulls are tags that are not read
237241
.cache()
238242
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
239243

runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Iterator;
2828
import java.util.List;
2929
import java.util.Map;
30+
import java.util.Objects;
3031
import java.util.Queue;
3132
import java.util.concurrent.LinkedBlockingQueue;
3233
import java.util.stream.Collectors;
@@ -593,6 +594,10 @@ public void evaluate(
593594
TranslationUtils.getTupleTagCoders(outputs);
594595
all =
595596
all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
597+
.filter(
598+
Objects
599+
::nonNull) // skip nulls to save on encoding, nulls are tags that are not
600+
// read
596601
.cache()
597602
.mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
598603
}

runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@
4040
import org.apache.beam.sdk.coders.VarIntCoder;
4141
import org.apache.beam.sdk.transforms.Count;
4242
import org.apache.beam.sdk.transforms.Create;
43+
import org.apache.beam.sdk.transforms.DoFn;
4344
import org.apache.beam.sdk.transforms.PTransform;
45+
import org.apache.beam.sdk.transforms.ParDo;
4446
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
4547
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
4648
import org.apache.beam.sdk.values.KV;
4749
import org.apache.beam.sdk.values.PBegin;
4850
import org.apache.beam.sdk.values.PCollection;
4951
import org.apache.beam.sdk.values.PCollectionTuple;
5052
import org.apache.beam.sdk.values.TupleTag;
53+
import org.apache.beam.sdk.values.TupleTagList;
5154
import org.apache.beam.sdk.values.WindowedValue;
5255
import org.apache.beam.sdk.values.WindowedValues;
5356
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
@@ -247,4 +250,65 @@ public void testMultipleOutputParDoShouldHaveFilterWhenSideOutputIsConsumed() {
247250
assertTrue(parsed.stream().anyMatch(e -> e.getName().contains(tag.getId())));
248251
}
249252
}
253+
254+
@Test
255+
public void testMultipleOutputParDoWithUnconsumedSideOutputAndSerializationStorageLevel() {
256+
Pipeline p = Pipeline.create();
257+
TupleTag<String> tag1 = new TupleTag<String>("tag1") {};
258+
TupleTag<String> tag2 = new TupleTag<String>("tag2") {};
259+
TupleTag<String> tag3 = new TupleTag<String>("tag3") {};
260+
261+
SparkPipelineOptions options = contextRule.createPipelineOptions();
262+
// Force serialization by setting storage level to MEMORY_AND_DISK_SER
263+
options.setStorageLevel("MEMORY_AND_DISK_SER");
264+
265+
TransformTranslator.Translator translator = new TransformTranslator.Translator();
266+
267+
PTransform<PBegin, PCollection<String>> createTransform = Create.of("foo", "bar");
268+
269+
PCollectionTuple pCollectionTuple =
270+
p.apply("Create Values", createTransform)
271+
.apply(
272+
"Multiple Output ParDo",
273+
ParDo.of(new MultiOutputDoFn(tag1, tag2, tag3))
274+
.withOutputTags(tag1, TupleTagList.of(tag2).and(tag3)));
275+
276+
// consume tag1 and tag2
277+
pCollectionTuple.get(tag1).apply("Count1", Count.globally());
278+
pCollectionTuple.get(tag2).apply("Count2", Count.globally());
279+
280+
p.replaceAll(SparkTransformOverrides.getDefaultOverrides(false));
281+
282+
EvaluationContext ctxt = new EvaluationContext(contextRule.getSparkContext(), p, options);
283+
SparkRunner.initAccumulators(options, ctxt.getSparkContext());
284+
SparkRunner.updateDependentTransforms(p, translator, ctxt);
285+
286+
// This should not throw NullPointerException
287+
p.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt));
288+
289+
// Also trigger some action on the RDD to ensure serialization happens
290+
@SuppressWarnings("unchecked")
291+
BoundedDataset<String> dataset =
292+
(BoundedDataset<String>) ctxt.borrowDataset(pCollectionTuple.get(tag1));
293+
dataset.getRDD().count();
294+
}
295+
296+
private static class MultiOutputDoFn extends DoFn<String, String> {
297+
private final TupleTag<String> tag1;
298+
private final TupleTag<String> tag2;
299+
private final TupleTag<String> tag3;
300+
301+
MultiOutputDoFn(TupleTag<String> tag1, TupleTag<String> tag2, TupleTag<String> tag3) {
302+
this.tag1 = tag1;
303+
this.tag2 = tag2;
304+
this.tag3 = tag3;
305+
}
306+
307+
@ProcessElement
308+
public void process(@Element String input, MultiOutputReceiver outputReceiver) {
309+
outputReceiver.get(tag1).output(input);
310+
outputReceiver.get(tag2).output(input);
311+
outputReceiver.get(tag3).output(input);
312+
}
313+
}
250314
}

0 commit comments

Comments
 (0)