Skip to content

Commit f07ca33

Browse files
committed
Fix nullness errors in SimpleDoFnRunner and DoFnInvoker
1 parent 6c99f7c commit f07ca33

File tree

2 files changed

+149
-44
lines changed

2 files changed

+149
-44
lines changed

runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java

Lines changed: 128 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.beam.runners.core;
1919

20+
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
2021
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
2122
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
2223

@@ -66,6 +67,9 @@
6667
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
6768
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
6869
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
70+
import org.checkerframework.checker.initialization.qual.Initialized;
71+
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
72+
import org.checkerframework.checker.nullness.qual.NonNull;
6973
import org.checkerframework.checker.nullness.qual.Nullable;
7074
import org.joda.time.Duration;
7175
import org.joda.time.Instant;
@@ -81,11 +85,6 @@
8185
* @param <InputT> the type of the {@link DoFn} (main) input elements
8286
* @param <OutputT> the type of the {@link DoFn} (main) output elements
8387
*/
84-
@SuppressWarnings({
85-
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
86-
"nullness",
87-
"keyfor"
88-
}) // TODO(https://github.com/apache/beam/issues/20497)
8988
public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
9089

9190
private final PipelineOptions options;
@@ -527,13 +526,21 @@ public Object key() {
527526
}
528527

529528
@Override
530-
public Object sideInput(String tagId) {
531-
return sideInput(sideInputMapping.get(tagId));
529+
public @Nullable Object sideInput(String tagId) {
530+
PCollectionView<?> view =
531+
checkStateNotNull(sideInputMapping.get(tagId), "Side input tag %s not found", tagId);
532+
return sideInput(view);
532533
}
533534

534535
@Override
535536
public Object schemaElement(int index) {
536-
SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index);
537+
checkStateNotNull(
538+
doFnSchemaInformation,
539+
"attempt to access element via schema when no schema information provided");
540+
541+
SerializableFunction<InputT, Object> converter =
542+
(SerializableFunction<InputT, Object>)
543+
doFnSchemaInformation.getElementConverters().get(index);
537544
return converter.apply(element());
538545
}
539546

@@ -561,6 +568,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
561568

562569
@Override
563570
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
571+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
564572
return DoFnOutputReceivers.rowReceiver(
565573
this, builderSupplier, mainOutputTag, mainOutputSchemaCoder);
566574
}
@@ -601,14 +609,25 @@ public WatermarkEstimator<?> watermarkEstimator() {
601609
@Override
602610
public State state(String stateId, boolean alwaysFetched) {
603611
try {
612+
DoFnSignature.StateDeclaration stateDeclaration =
613+
checkStateNotNull(
614+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
615+
604616
StateSpec<?> spec =
605-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
617+
checkStateNotNull(
618+
(StateSpec<?>) stateDeclaration.field().get(fn),
619+
"Field %s corresponding to state id %s contained null",
620+
stateDeclaration.field(),
621+
stateId);
622+
623+
@NonNull
624+
@Initialized // unclear why checkerframework needs this help
606625
State state =
607626
stepContext
608627
.stateInternals()
609-
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
628+
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec<?>) spec));
610629
if (alwaysFetched) {
611-
return (State) ((ReadableState) state).readLater();
630+
return (State) ((ReadableState<?>) state).readLater();
612631
} else {
613632
return state;
614633
}
@@ -620,7 +639,16 @@ public State state(String stateId, boolean alwaysFetched) {
620639
@Override
621640
public Timer timer(String timerId) {
622641
try {
623-
TimerSpec spec = (TimerSpec) signature.timerDeclarations().get(timerId).field().get(fn);
642+
DoFnSignature.TimerDeclaration timerDeclaration =
643+
checkStateNotNull(
644+
signature.timerDeclarations().get(timerId), "timer not found: %s", timerId);
645+
TimerSpec spec =
646+
(TimerSpec)
647+
checkStateNotNull(
648+
timerDeclaration.field().get(fn),
649+
"Field %s corresponding to timer id %s contained null",
650+
timerDeclaration.field(),
651+
timerId);
624652
return new TimerInternalsTimer(
625653
window(), getNamespace(), timerId, spec, timestamp(), stepContext.timerInternals());
626654
} catch (IllegalAccessException e) {
@@ -631,8 +659,19 @@ public Timer timer(String timerId) {
631659
@Override
632660
public TimerMap timerFamily(String timerFamilyId) {
633661
try {
662+
DoFnSignature.TimerFamilyDeclaration timerFamilyDeclaration =
663+
checkStateNotNull(
664+
signature.timerFamilyDeclarations().get(timerFamilyId),
665+
"timer family not found: %s",
666+
timerFamilyId);
667+
634668
TimerSpec spec =
635-
(TimerSpec) signature.timerFamilyDeclarations().get(timerFamilyId).field().get(fn);
669+
(TimerSpec)
670+
checkStateNotNull(
671+
timerFamilyDeclaration.field().get(fn),
672+
"Field %s corresponding to timer family id %s contained null",
673+
timerFamilyDeclaration.field(),
674+
timerFamilyId);
636675
return new TimerInternalsTimerMap(
637676
timerFamilyId,
638677
window(),
@@ -794,6 +833,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
794833

795834
@Override
796835
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
836+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
797837
return DoFnOutputReceivers.rowReceiver(
798838
this, builderSupplier, mainOutputTag, mainOutputSchemaCoder);
799839
}
@@ -833,8 +873,18 @@ public WatermarkEstimator<?> watermarkEstimator() {
833873
@Override
834874
public State state(String stateId, boolean alwaysFetched) {
835875
try {
876+
DoFnSignature.StateDeclaration stateDeclaration =
877+
checkStateNotNull(
878+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
879+
836880
StateSpec<?> spec =
837-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
881+
checkStateNotNull(
882+
(StateSpec<?>) stateDeclaration.field().get(fn),
883+
"Field %s corresponding to state id %s contained null",
884+
stateDeclaration.field(),
885+
stateId);
886+
887+
@NonNull
838888
State state =
839889
stepContext
840890
.stateInternals()
@@ -852,7 +902,16 @@ public State state(String stateId, boolean alwaysFetched) {
852902
@Override
853903
public Timer timer(String timerId) {
854904
try {
855-
TimerSpec spec = (TimerSpec) signature.timerDeclarations().get(timerId).field().get(fn);
905+
DoFnSignature.TimerDeclaration timerDeclaration =
906+
checkStateNotNull(
907+
signature.timerDeclarations().get(timerId), "timer not found: %s", timerId);
908+
TimerSpec spec =
909+
(TimerSpec)
910+
checkStateNotNull(
911+
timerDeclaration.field().get(fn),
912+
"Field %s corresponding to timer id %s contained null",
913+
timerDeclaration.field(),
914+
timerId);
856915
return new TimerInternalsTimer(
857916
window, getNamespace(), timerId, spec, timestamp(), stepContext.timerInternals());
858917
} catch (IllegalAccessException e) {
@@ -863,8 +922,18 @@ public Timer timer(String timerId) {
863922
@Override
864923
public TimerMap timerFamily(String timerFamilyId) {
865924
try {
925+
DoFnSignature.TimerFamilyDeclaration timerFamilyDeclaration =
926+
checkStateNotNull(
927+
signature.timerFamilyDeclarations().get(timerFamilyId),
928+
"timer family not found: %s",
929+
timerFamilyId);
866930
TimerSpec spec =
867-
(TimerSpec) signature.timerFamilyDeclarations().get(timerFamilyId).field().get(fn);
931+
(TimerSpec)
932+
checkStateNotNull(
933+
timerFamilyDeclaration.field().get(fn),
934+
"Field %s corresponding to timer family id %s contained null",
935+
timerFamilyDeclaration.field(),
936+
timerFamilyId);
868937
return new TimerInternalsTimerMap(
869938
timerFamilyId,
870939
window(),
@@ -1058,6 +1127,7 @@ public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
10581127

10591128
@Override
10601129
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
1130+
checkStateNotNull(mainOutputSchemaCoder, "cannot provide row receiver without schema coder");
10611131
return DoFnOutputReceivers.rowReceiver(
10621132
this, builderSupplier, mainOutputTag, mainOutputSchemaCoder);
10631133
}
@@ -1096,14 +1166,23 @@ public WatermarkEstimator<?> watermarkEstimator() {
10961166
@Override
10971167
public State state(String stateId, boolean alwaysFetched) {
10981168
try {
1169+
DoFnSignature.StateDeclaration stateDeclaration =
1170+
checkStateNotNull(
1171+
signature.stateDeclarations().get(stateId), "state not found: %s", stateId);
10991172
StateSpec<?> spec =
1100-
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
1173+
checkStateNotNull(
1174+
(StateSpec<?>) stateDeclaration.field().get(fn),
1175+
"Field %s corresponding to state id %s contained null",
1176+
stateDeclaration.field(),
1177+
stateId);
1178+
@NonNull
1179+
@Initialized // unclear why checkerframework needs this help
11011180
State state =
11021181
stepContext
11031182
.stateInternals()
1104-
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
1183+
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec<?>) spec));
11051184
if (alwaysFetched) {
1106-
return (State) ((ReadableState) state).readLater();
1185+
return (State) ((ReadableState<?>) state).readLater();
11071186
} else {
11081187
return state;
11091188
}
@@ -1195,7 +1274,7 @@ private class TimerInternalsTimer implements Timer {
11951274
private final String timerId;
11961275
private final String timerFamilyId;
11971276
private final TimerSpec spec;
1198-
private Instant target;
1277+
private @MonotonicNonNull Instant target;
11991278
private @Nullable Instant outputTimestamp;
12001279
private boolean noOutputTimestamp;
12011280
private final Instant elementInputTimestamp;
@@ -1313,23 +1392,26 @@ public Timer withNoOutputTimestamp() {
13131392
* <li>The current element timestamp for other time domains.
13141393
*/
13151394
private void setAndVerifyOutputTimestamp() {
1395+
checkStateNotNull(target, "attempt to set outputTimestamp before setting target firing time");
13161396
if (outputTimestamp != null) {
1397+
// setting to local var so checkerframework knows that method calls will not mutate it
1398+
Instant timestampToValidate = outputTimestamp;
13171399
Instant lowerBound;
13181400
try {
13191401
lowerBound = elementInputTimestamp.minus(fn.getAllowedTimestampSkew());
13201402
} catch (ArithmeticException e) {
13211403
lowerBound = BoundedWindow.TIMESTAMP_MIN_VALUE;
13221404
}
1323-
if (outputTimestamp.isBefore(lowerBound)
1324-
|| outputTimestamp.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
1405+
if (timestampToValidate.isBefore(lowerBound)
1406+
|| timestampToValidate.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
13251407
throw new IllegalArgumentException(
13261408
String.format(
13271409
"Cannot output timer with output timestamp %s. Output timestamps must be no "
13281410
+ "earlier than the timestamp of the current input or timer (%s) minus the "
13291411
+ "allowed skew (%s) and no later than %s. See the "
13301412
+ "DoFn#getAllowedTimestampSkew() Javadoc for details on changing the "
13311413
+ "allowed skew.",
1332-
outputTimestamp,
1414+
timestampToValidate,
13331415
elementInputTimestamp,
13341416
fn.getAllowedTimestampSkew().getMillis() >= Integer.MAX_VALUE
13351417
? fn.getAllowedTimestampSkew()
@@ -1346,6 +1428,9 @@ private void setAndVerifyOutputTimestamp() {
13461428
// the element (or timer) setting this timer.
13471429
outputTimestamp = elementInputTimestamp;
13481430
}
1431+
1432+
// Now it has been set for all cases other than this.noOutputTimestamp == true, and there are
1433+
// further validations
13491434
if (outputTimestamp != null) {
13501435
Instant windowExpiry = LateDataUtils.garbageCollectionTime(window, allowedLateness);
13511436
if (TimeDomain.EVENT_TIME.equals(spec.getTimeDomain())) {
@@ -1380,6 +1465,12 @@ private void setAndVerifyOutputTimestamp() {
13801465
* user has no way to compute a good choice of time.
13811466
*/
13821467
private void setUnderlyingTimer() {
1468+
checkStateNotNull(
1469+
outputTimestamp,
1470+
"internal error: null outputTimestamp: must be populated by setAndVerifyOutputTimestamp()");
1471+
checkStateNotNull(
1472+
target,
1473+
"internal error: attempt to set internal timer when target timestamp not yet set");
13831474
timerInternals.setTimer(
13841475
namespace, timerId, timerFamilyId, target, outputTimestamp, spec.getTimeDomain());
13851476
}
@@ -1396,7 +1487,9 @@ private Instant getCurrentTime(TimeDomain timeDomain) {
13961487
case PROCESSING_TIME:
13971488
return timerInternals.currentProcessingTime();
13981489
case SYNCHRONIZED_PROCESSING_TIME:
1399-
return timerInternals.currentSynchronizedProcessingTime();
1490+
return checkStateNotNull(
1491+
timerInternals.currentSynchronizedProcessingTime(),
1492+
"internal error: requested SYNCHRONIZED_PROCESSING_TIME but it was null");
14001493
default:
14011494
throw new IllegalStateException(
14021495
String.format("Timer created for unknown time domain %s", spec.getTimeDomain()));
@@ -1446,19 +1539,17 @@ public void set(String timerId, Instant absoluteTime) {
14461539

14471540
@Override
14481541
public Timer get(String timerId) {
1449-
if (timers.get(timerId) == null) {
1450-
Timer timer =
1451-
new TimerInternalsTimer(
1452-
window,
1453-
namespace,
1454-
timerId,
1455-
timerFamilyId,
1456-
spec,
1457-
elementInputTimestamp,
1458-
timerInternals);
1459-
timers.put(timerId, timer);
1460-
}
1461-
return timers.get(timerId);
1542+
return timers.computeIfAbsent(
1543+
timerId,
1544+
id ->
1545+
new TimerInternalsTimer(
1546+
window,
1547+
namespace,
1548+
id,
1549+
timerFamilyId,
1550+
spec,
1551+
elementInputTimestamp,
1552+
timerInternals));
14621553
}
14631554
}
14641555
}

0 commit comments

Comments
 (0)