Skip to content

Commit 3d608d3

Browse files
committed
Fix nullness errors in SimpleDoFnRunner and DoFnInvoker
1 parent 4a756bf commit 3d608d3

File tree

2 files changed

+149
-44
lines changed

2 files changed

+149
-44
lines changed

runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java

Lines changed: 128 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.runners.core;
1919

20+
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
2021
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
2122
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
2223

@@ -63,6 +64,9 @@
6364
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
6465
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
6566
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
67+
import org.checkerframework.checker.initialization.qual.Initialized;
68+
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
69+
import org.checkerframework.checker.nullness.qual.NonNull;
6670
import org.checkerframework.checker.nullness.qual.Nullable;
6771
import org.joda.time.Duration;
6872
import org.joda.time.Instant;
@@ -78,11 +82,6 @@
7882
* @param <InputT> the type of the {@link DoFn} (main) input elements
7983
* @param <OutputT> the type of the {@link DoFn} (main) output elements
8084
*/
81-
@SuppressWarnings({
82-
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
83-
"nullness",
84-
"keyfor"
85-
}) // TODO(https://github.com/apache/beam/issues/20497)
8685
public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
8786

8887
private final PipelineOptions options;
@@ -502,13 +501,21 @@ public Object key() {
502501
}
503502

504503
@Override
505-
public Object sideInput(String tagId) {
506-
return sideInput(sideInputMapping.get(tagId));
504+
public @Nullable Object sideInput(String tagId) {
505+
PCollectionView<?> view =
506+
checkStateNotNull(sideInputMapping.get(tagId), "Side input tag %s not found", tagId);
507+
return sideInput(view);
507508
}
508509

509510
@Override
510511
public Object schemaElement(int index) {
511-
SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index);
512+
checkStateNotNull(
513+
doFnSchemaInformation,
514+
"attempt to access element via schema when no schema information provided");
515+
516+
SerializableFunction<InputT, Object> converter =
517+
(SerializableFunction<InputT, Object>)
518+
doFnSchemaInformation.getElementConverters().get(index);
512519
return converter.apply(element());
513520
}
514521

@@ -536,6 +543,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
536543

537544
@Override
538545
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
546+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
539547
return DoFnOutputReceivers.rowReceiver(this, mainOutputTag, mainOutputSchemaCoder);
540548
}
541549

@@ -575,14 +583,25 @@ public WatermarkEstimator<?> watermarkEstimator() {
575583
@Override
576584
public State state(String stateId, boolean alwaysFetched) {
577585
try {
586+
DoFnSignature.StateDeclaration stateDeclaration =
587+
checkStateNotNull(
588+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
589+
578590
StateSpec<?> spec =
579-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
591+
checkStateNotNull(
592+
(StateSpec<?>) stateDeclaration.field().get(fn),
593+
"Field %s corresponding to state id %s contained null",
594+
stateDeclaration.field(),
595+
stateId);
596+
597+
@NonNull
598+
@Initialized // unclear why checkerframework needs this help
580599
State state =
581600
stepContext
582601
.stateInternals()
583-
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
602+
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec<?>) spec));
584603
if (alwaysFetched) {
585-
return (State) ((ReadableState) state).readLater();
604+
return (State) ((ReadableState<?>) state).readLater();
586605
} else {
587606
return state;
588607
}
@@ -594,7 +613,16 @@ public State state(String stateId, boolean alwaysFetched) {
594613
@Override
595614
public Timer timer(String timerId) {
596615
try {
597-
TimerSpec spec = (TimerSpec) signature.timerDeclarations().get(timerId).field().get(fn);
616+
DoFnSignature.TimerDeclaration timerDeclaration =
617+
checkStateNotNull(
618+
signature.timerDeclarations().get(timerId), "timer not found: %s", timerId);
619+
TimerSpec spec =
620+
(TimerSpec)
621+
checkStateNotNull(
622+
timerDeclaration.field().get(fn),
623+
"Field %s corresponding to timer id %s contained null",
624+
timerDeclaration.field(),
625+
timerId);
598626
return new TimerInternalsTimer(
599627
window(), getNamespace(), timerId, spec, timestamp(), stepContext.timerInternals());
600628
} catch (IllegalAccessException e) {
@@ -605,8 +633,19 @@ public Timer timer(String timerId) {
605633
@Override
606634
public TimerMap timerFamily(String timerFamilyId) {
607635
try {
636+
DoFnSignature.TimerFamilyDeclaration timerFamilyDeclaration =
637+
checkStateNotNull(
638+
signature.timerFamilyDeclarations().get(timerFamilyId),
639+
"timer family not found: %s",
640+
timerFamilyId);
641+
608642
TimerSpec spec =
609-
(TimerSpec) signature.timerFamilyDeclarations().get(timerFamilyId).field().get(fn);
643+
(TimerSpec)
644+
checkStateNotNull(
645+
timerFamilyDeclaration.field().get(fn),
646+
"Field %s corresponding to timer family id %s contained null",
647+
timerFamilyDeclaration.field(),
648+
timerFamilyId);
610649
return new TimerInternalsTimerMap(
611650
timerFamilyId,
612651
window(),
@@ -760,6 +799,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
760799

761800
@Override
762801
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
802+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
763803
return DoFnOutputReceivers.rowReceiver(this, mainOutputTag, mainOutputSchemaCoder);
764804
}
765805

@@ -797,8 +837,18 @@ public WatermarkEstimator<?> watermarkEstimator() {
797837
@Override
798838
public State state(String stateId, boolean alwaysFetched) {
799839
try {
840+
DoFnSignature.StateDeclaration stateDeclaration =
841+
checkStateNotNull(
842+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
843+
800844
StateSpec<?> spec =
801-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
845+
checkStateNotNull(
846+
(StateSpec<?>) stateDeclaration.field().get(fn),
847+
"Field %s corresponding to state id %s contained null",
848+
stateDeclaration.field(),
849+
stateId);
850+
851+
@NonNull
802852
State state =
803853
stepContext
804854
.stateInternals()
@@ -816,7 +866,16 @@ public State state(String stateId, boolean alwaysFetched) {
816866
@Override
817867
public Timer timer(String timerId) {
818868
try {
819-
TimerSpec spec = (TimerSpec) signature.timerDeclarations().get(timerId).field().get(fn);
869+
DoFnSignature.TimerDeclaration timerDeclaration =
870+
checkStateNotNull(
871+
signature.timerDeclarations().get(timerId), "timer not found: %s", timerId);
872+
TimerSpec spec =
873+
(TimerSpec)
874+
checkStateNotNull(
875+
timerDeclaration.field().get(fn),
876+
"Field %s corresponding to timer id %s contained null",
877+
timerDeclaration.field(),
878+
timerId);
820879
return new TimerInternalsTimer(
821880
window, getNamespace(), timerId, spec, timestamp(), stepContext.timerInternals());
822881
} catch (IllegalAccessException e) {
@@ -827,8 +886,18 @@ public Timer timer(String timerId) {
827886
@Override
828887
public TimerMap timerFamily(String timerFamilyId) {
829888
try {
889+
DoFnSignature.TimerFamilyDeclaration timerFamilyDeclaration =
890+
checkStateNotNull(
891+
signature.timerFamilyDeclarations().get(timerFamilyId),
892+
"timer family not found: %s",
893+
timerFamilyId);
830894
TimerSpec spec =
831-
(TimerSpec) signature.timerFamilyDeclarations().get(timerFamilyId).field().get(fn);
895+
(TimerSpec)
896+
checkStateNotNull(
897+
timerFamilyDeclaration.field().get(fn),
898+
"Field %s corresponding to timer family id %s contained null",
899+
timerFamilyDeclaration.field(),
900+
timerFamilyId);
832901
return new TimerInternalsTimerMap(
833902
timerFamilyId,
834903
window(),
@@ -1007,6 +1076,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
10071076

10081077
@Override
10091078
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
1079+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
10101080
return DoFnOutputReceivers.rowReceiver(this, mainOutputTag, mainOutputSchemaCoder);
10111081
}
10121082

@@ -1044,14 +1114,23 @@ public WatermarkEstimator<?> watermarkEstimator() {
10441114
@Override
10451115
public State state(String stateId, boolean alwaysFetched) {
10461116
try {
1117+
DoFnSignature.StateDeclaration stateDeclaration =
1118+
checkStateNotNull(
1119+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
10471120
StateSpec<?> spec =
1048-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
1121+
checkStateNotNull(
1122+
(StateSpec<?>) stateDeclaration.field().get(fn),
1123+
"Field %s corresponding to state id %s contained null",
1124+
stateDeclaration.field(),
1125+
stateId);
1126+
@NonNull
1127+
@Initialized // unclear why checkerframework needs this help
10491128
State state =
10501129
stepContext
10511130
.stateInternals()
1052-
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
1131+
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec<?>) spec));
10531132
if (alwaysFetched) {
1054-
return (State) ((ReadableState) state).readLater();
1133+
return (State) ((ReadableState<?>) state).readLater();
10551134
} else {
10561135
return state;
10571136
}
@@ -1138,7 +1217,7 @@ private class TimerInternalsTimer implements Timer {
11381217
private final String timerId;
11391218
private final String timerFamilyId;
11401219
private final TimerSpec spec;
1141-
private Instant target;
1220+
private @MonotonicNonNull Instant target;
11421221
private @Nullable Instant outputTimestamp;
11431222
private boolean noOutputTimestamp;
11441223
private final Instant elementInputTimestamp;
@@ -1256,23 +1335,26 @@ public Timer withNoOutputTimestamp() {
12561335
* <li>The current element timestamp for other time domains.
12571336
*/
12581337
private void setAndVerifyOutputTimestamp() {
1338+
checkStateNotNull(target, "attempt to set outputTimestamp before setting target firing time");
12591339
if (outputTimestamp != null) {
1340+
// setting to local var so checkerframework knows that method calls will not mutate it
1341+
Instant timestampToValidate = outputTimestamp;
12601342
Instant lowerBound;
12611343
try {
12621344
lowerBound = elementInputTimestamp.minus(fn.getAllowedTimestampSkew());
12631345
} catch (ArithmeticException e) {
12641346
lowerBound = BoundedWindow.TIMESTAMP_MIN_VALUE;
12651347
}
1266-
if (outputTimestamp.isBefore(lowerBound)
1267-
|| outputTimestamp.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
1348+
if (timestampToValidate.isBefore(lowerBound)
1349+
|| timestampToValidate.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
12681350
throw new IllegalArgumentException(
12691351
String.format(
12701352
"Cannot output timer with output timestamp %s. Output timestamps must be no "
12711353
+ "earlier than the timestamp of the current input or timer (%s) minus the "
12721354
+ "allowed skew (%s) and no later than %s. See the "
12731355
+ "DoFn#getAllowedTimestampSkew() Javadoc for details on changing the "
12741356
+ "allowed skew.",
1275-
outputTimestamp,
1357+
timestampToValidate,
12761358
elementInputTimestamp,
12771359
fn.getAllowedTimestampSkew().getMillis() >= Integer.MAX_VALUE
12781360
? fn.getAllowedTimestampSkew()
@@ -1289,6 +1371,9 @@ private void setAndVerifyOutputTimestamp() {
12891371
// the element (or timer) setting this timer.
12901372
outputTimestamp = elementInputTimestamp;
12911373
}
1374+
1375+
// Now it has been set for all cases other than this.noOutputTimestamp == true, and there are
1376+
// further validations
12921377
if (outputTimestamp != null) {
12931378
Instant windowExpiry = LateDataUtils.garbageCollectionTime(window, allowedLateness);
12941379
if (TimeDomain.EVENT_TIME.equals(spec.getTimeDomain())) {
@@ -1323,6 +1408,12 @@ private void setAndVerifyOutputTimestamp() {
13231408
* user has no way to compute a good choice of time.
13241409
*/
13251410
private void setUnderlyingTimer() {
1411+
checkStateNotNull(
1412+
outputTimestamp,
1413+
"internal error: null outputTimestamp: must be populated by setAndVerifyOutputTimestamp()");
1414+
checkStateNotNull(
1415+
target,
1416+
"internal error: attempt to set internal timer when target timestamp not yet set");
13261417
timerInternals.setTimer(
13271418
namespace, timerId, timerFamilyId, target, outputTimestamp, spec.getTimeDomain());
13281419
}
@@ -1339,7 +1430,9 @@ private Instant getCurrentTime(TimeDomain timeDomain) {
13391430
case PROCESSING_TIME:
13401431
return timerInternals.currentProcessingTime();
13411432
case SYNCHRONIZED_PROCESSING_TIME:
1342-
return timerInternals.currentSynchronizedProcessingTime();
1433+
return checkStateNotNull(
1434+
timerInternals.currentSynchronizedProcessingTime(),
1435+
"internal error: requested SYNCHRONIZED_PROCESSING_TIME but it was null");
13431436
default:
13441437
throw new IllegalStateException(
13451438
String.format("Timer created for unknown time domain %s", spec.getTimeDomain()));
@@ -1389,19 +1482,17 @@ public void set(String timerId, Instant absoluteTime) {
13891482

13901483
@Override
13911484
public Timer get(String timerId) {
1392-
if (timers.get(timerId) == null) {
1393-
Timer timer =
1394-
new TimerInternalsTimer(
1395-
window,
1396-
namespace,
1397-
timerId,
1398-
timerFamilyId,
1399-
spec,
1400-
elementInputTimestamp,
1401-
timerInternals);
1402-
timers.put(timerId, timer);
1403-
}
1404-
return timers.get(timerId);
1485+
return timers.computeIfAbsent(
1486+
timerId,
1487+
id ->
1488+
new TimerInternalsTimer(
1489+
window,
1490+
namespace,
1491+
id,
1492+
timerFamilyId,
1493+
spec,
1494+
elementInputTimestamp,
1495+
timerInternals));
14051496
}
14061497
}
14071498
}

0 commit comments

Comments
 (0)