Skip to content

Commit fbc7d0a

Browse files
[FLINK-38230][table-planner] Add visitMultiJoin to StreamNDUPlanVisitor
This closes #26894.
1 parent 00241a6 commit fbc7d0a

File tree

7 files changed

+461
-222
lines changed

7 files changed

+461
-222
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMultiJoin.java

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ public class StreamPhysicalMultiJoin extends AbstractRelNode implements StreamPh
7979
private final @Nullable RexNode postJoinFilter;
8080
private final List<RelHint> hints;
8181

82+
// Cached derived properties to avoid recomputation
83+
private @Nullable RexNode multiJoinCondition;
84+
private @Nullable List<List<int[]>> inputUniqueKeys;
85+
8286
public StreamPhysicalMultiJoin(
8387
final RelOptCluster cluster,
8488
final RelTraitSet traitSet,
@@ -101,6 +105,8 @@ public StreamPhysicalMultiJoin(
101105
this.postJoinFilter = postJoinFilter;
102106
this.hints = hints;
103107
this.keyExtractor = keyExtractor;
108+
this.multiJoinCondition = getMultiJoinCondition();
109+
this.inputUniqueKeys = getUniqueKeysForInputs();
104110
}
105111

106112
@Override
@@ -119,6 +125,9 @@ public void replaceInput(final int ordinalInParent, final RelNode p) {
119125
final List<RelNode> newInputs = new ArrayList<>(inputs);
120126
newInputs.set(ordinalInParent, p);
121127
this.inputs = List.copyOf(newInputs);
128+
// Invalidate cached derived properties since inputs changed
129+
this.multiJoinCondition = null;
130+
this.inputUniqueKeys = null;
122131
recomputeDigest();
123132
}
124133

@@ -166,18 +175,18 @@ protected RelDataType deriveRowType() {
166175

167176
@Override
168177
public ExecNode<?> translateToExecNode() {
169-
final RexNode multiJoinCondition = createMultiJoinCondition();
170-
final List<List<int[]>> inputUniqueKeys = getUniqueKeysForInputs();
178+
final RexNode multijoinCondition = getMultiJoinCondition();
179+
final List<List<int[]>> localInputUniqueKeys = getUniqueKeysForInputs();
171180
final List<FlinkJoinType> execJoinTypes = getExecJoinTypes();
172181
final List<InputProperty> inputProperties = createInputProperties();
173182

174183
return new StreamExecMultiJoin(
175184
unwrapTableConfig(this),
176185
execJoinTypes,
177186
joinConditions,
178-
multiJoinCondition,
187+
multijoinCondition,
179188
joinAttributeMap,
180-
inputUniqueKeys,
189+
localInputUniqueKeys,
181190
Collections.emptyMap(), // TODO Enable hint-based state ttl. See ticket
182191
// TODO https://issues.apache.org/jira/browse/FLINK-37936
183192
inputProperties,
@@ -187,28 +196,43 @@ public ExecNode<?> translateToExecNode() {
187196

188197
private RexNode createMultiJoinCondition() {
189198
final List<RexNode> conjunctions = new ArrayList<>();
199+
200+
for (RexNode joinCondition : joinConditions) {
201+
if (joinCondition != null) {
202+
conjunctions.add(joinCondition);
203+
}
204+
}
205+
190206
conjunctions.add(joinFilter);
207+
191208
if (postJoinFilter != null) {
192209
conjunctions.add(postJoinFilter);
193210
}
211+
194212
return RexUtil.composeConjunction(getCluster().getRexBuilder(), conjunctions, true);
195213
}
196214

197-
private List<List<int[]>> getUniqueKeysForInputs() {
198-
return inputs.stream()
199-
.map(
200-
input -> {
201-
final Set<ImmutableBitSet> uniqueKeys = getUniqueKeys(input);
202-
203-
if (uniqueKeys == null) {
204-
return Collections.<int[]>emptyList();
205-
}
206-
207-
return uniqueKeys.stream()
208-
.map(ImmutableBitSet::toArray)
209-
.collect(Collectors.toList());
210-
})
211-
.collect(Collectors.toList());
215+
public List<List<int[]>> getUniqueKeysForInputs() {
216+
if (inputUniqueKeys == null) {
217+
final List<List<int[]>> computed =
218+
inputs.stream()
219+
.map(
220+
input -> {
221+
final Set<ImmutableBitSet> uniqueKeys =
222+
getUniqueKeys(input);
223+
224+
if (uniqueKeys == null) {
225+
return Collections.<int[]>emptyList();
226+
}
227+
228+
return uniqueKeys.stream()
229+
.map(ImmutableBitSet::toArray)
230+
.collect(Collectors.toList());
231+
})
232+
.collect(Collectors.toList());
233+
inputUniqueKeys = Collections.unmodifiableList(computed);
234+
}
235+
return inputUniqueKeys;
212236
}
213237

214238
private @Nullable Set<ImmutableBitSet> getUniqueKeys(RelNode input) {
@@ -217,6 +241,13 @@ private List<List<int[]>> getUniqueKeysForInputs() {
217241
return fmq.getUniqueKeys(input);
218242
}
219243

244+
public RexNode getMultiJoinCondition() {
245+
if (multiJoinCondition == null) {
246+
multiJoinCondition = createMultiJoinCondition();
247+
}
248+
return multiJoinCondition;
249+
}
250+
220251
private List<FlinkJoinType> getExecJoinTypes() {
221252
return joinTypes.stream()
222253
.map(
@@ -256,8 +287,8 @@ public List<JoinRelType> getJoinTypes() {
256287
*/
257288
public boolean inputUniqueKeyContainsCommonJoinKey(int inputId) {
258289
final RelNode input = getInputs().get(inputId);
259-
final Set<ImmutableBitSet> inputUniqueKeys = getUniqueKeys(input);
260-
if (inputUniqueKeys == null || inputUniqueKeys.isEmpty()) {
290+
final Set<ImmutableBitSet> inputUniqueKeysSet = getUniqueKeys(input);
291+
if (inputUniqueKeysSet == null || inputUniqueKeysSet.isEmpty()) {
261292
return false;
262293
}
263294

@@ -267,7 +298,8 @@ public boolean inputUniqueKeyContainsCommonJoinKey(int inputId) {
267298
}
268299

269300
final ImmutableBitSet commonJoinKeys = ImmutableBitSet.of(commonJoinKeyIndices);
270-
return inputUniqueKeys.stream().anyMatch(uniqueKey -> uniqueKey.contains(commonJoinKeys));
301+
return inputUniqueKeysSet.stream()
302+
.anyMatch(uniqueKey -> uniqueKey.contains(commonJoinKeys));
271303
}
272304

273305
private List<InputProperty> createInputProperties() {

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLookupJoin;
4646
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMatch;
4747
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMiniBatchAssigner;
48+
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
4849
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalOverAggregateBase;
4950
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRank;
5051
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel;
@@ -77,6 +78,7 @@
7778

7879
import org.apache.calcite.rel.RelNode;
7980
import org.apache.calcite.rel.core.AggregateCall;
81+
import org.apache.calcite.rel.core.JoinRelType;
8082
import org.apache.calcite.rel.type.RelDataType;
8183
import org.apache.calcite.rex.RexNode;
8284
import org.apache.calcite.rex.RexProgram;
@@ -180,6 +182,8 @@ public StreamPhysicalRel visit(
180182
return visitExpand((StreamPhysicalExpand) rel, requireDeterminism);
181183
} else if (rel instanceof CommonPhysicalJoin) {
182184
return visitJoin((CommonPhysicalJoin) rel, requireDeterminism);
185+
} else if (rel instanceof StreamPhysicalMultiJoin) {
186+
return visitMultiJoin((StreamPhysicalMultiJoin) rel, requireDeterminism);
183187
} else if (rel instanceof StreamPhysicalOverAggregateBase) {
184188
return visitOverAggregate((StreamPhysicalOverAggregateBase) rel, requireDeterminism);
185189
} else if (rel instanceof StreamPhysicalRank) {
@@ -621,6 +625,102 @@ private StreamPhysicalRel visitJoin(
621625
join.isSemiJoin());
622626
}
623627

628+
/**
629+
* Multi-join determinism handling, mirroring the binary join logic:
630+
*
631+
* <p>If all inputs are insert-only and every join is INNER, the output is insert-only → no
632+
* determinism required downstream.
633+
*
634+
* <p>Otherwise the combined join condition must be deterministic, and we propagate per-input
635+
* determinism:
636+
*
637+
* <ul>
638+
* <li>If an input can produce updates, and we cannot guarantee uniqueness, we must require
639+
* determinism for the entire input row (retract-by-row correctness).
640+
* <li>If uniqueness is guaranteed, we pass through the part of the requirement that belongs
641+
* to that input.
642+
* </ul>
643+
*/
644+
private StreamPhysicalRel visitMultiJoin(
645+
final StreamPhysicalMultiJoin multiJoin, final ImmutableBitSet requireDeterminism) {
646+
final List<RelNode> inputs = multiJoin.getInputs();
647+
final boolean allInputsInsertOnly =
648+
inputs.stream().allMatch(in -> inputInsertOnly((StreamPhysicalRel) in));
649+
final boolean allInner =
650+
multiJoin.getJoinTypes().stream().allMatch(t -> t == JoinRelType.INNER);
651+
652+
// Fast path: pure insert-only inner join produces insert-only output -> nothing to require.
653+
if (allInputsInsertOnly && allInner) {
654+
return transmitDeterminismRequirement(multiJoin, NO_REQUIRED_DETERMINISM);
655+
}
656+
657+
// Output may carry updates (some input updates or some non-inner join): condition must be
658+
// deterministic.
659+
final RexNode multiJoinCondition = multiJoin.getMultiJoinCondition();
660+
if (multiJoinCondition != null) {
661+
final Optional<String> ndCall =
662+
FlinkRexUtil.getNonDeterministicCallName(multiJoinCondition);
663+
ndCall.ifPresent(
664+
s -> throwNonDeterministicConditionError(s, multiJoinCondition, multiJoin));
665+
}
666+
667+
// Output may carry updates: we need to propagate determinism requirements to inputs.
668+
final List<RelNode> newInputs = rewriteMultiJoinInputs(multiJoin, requireDeterminism);
669+
670+
return (StreamPhysicalRel) multiJoin.copy(multiJoin.getTraitSet(), newInputs);
671+
}
672+
673+
private ImmutableBitSet projectToInput(
674+
final ImmutableBitSet globalRequired, final int inputStart, final int inputFieldCount) {
675+
final List<Integer> local =
676+
globalRequired.toList().stream()
677+
.filter(idx -> idx >= inputStart && idx < inputStart + inputFieldCount)
678+
.map(idx -> idx - inputStart)
679+
.collect(Collectors.toList());
680+
return ImmutableBitSet.of(local);
681+
}
682+
683+
private ImmutableBitSet requiredForUpdatingMultiJoinInput(
684+
final StreamPhysicalMultiJoin multiJoin,
685+
final int inputIndex,
686+
final ImmutableBitSet localRequired,
687+
final int inputFieldCount) {
688+
final List<int[]> uniqueKeys = multiJoin.getUniqueKeysForInputs().get(inputIndex);
689+
final boolean hasUniqueKey = !uniqueKeys.isEmpty();
690+
691+
if (hasUniqueKey) {
692+
return localRequired;
693+
}
694+
// Without uniqueness guarantees we must retract by entire row for correctness.
695+
return ImmutableBitSet.range(inputFieldCount);
696+
}
697+
698+
private List<RelNode> rewriteMultiJoinInputs(
699+
final StreamPhysicalMultiJoin multiJoin, final ImmutableBitSet requireDeterminism) {
700+
final List<RelNode> inputs = multiJoin.getInputs();
701+
final List<RelNode> newInputs = new ArrayList<>(inputs.size());
702+
int fieldStartOffset = 0;
703+
for (int i = 0; i < inputs.size(); i++) {
704+
final StreamPhysicalRel input = (StreamPhysicalRel) inputs.get(i);
705+
final int inputFieldCount = input.getRowType().getFieldCount();
706+
707+
final ImmutableBitSet localRequired =
708+
projectToInput(requireDeterminism, fieldStartOffset, inputFieldCount);
709+
710+
final ImmutableBitSet inputRequired =
711+
inputInsertOnly(input)
712+
? NO_REQUIRED_DETERMINISM
713+
: requiredForUpdatingMultiJoinInput(
714+
multiJoin, i, localRequired, inputFieldCount);
715+
716+
final ImmutableBitSet finalRequired =
717+
requireDeterminismExcludeUpsertKey(input, inputRequired);
718+
newInputs.add(visit(input, finalRequired));
719+
fieldStartOffset += inputFieldCount;
720+
}
721+
return newInputs;
722+
}
723+
624724
private StreamPhysicalRel visitOverAggregate(
625725
final StreamPhysicalOverAggregateBase overAgg,
626726
final ImmutableBitSet requireDeterminism) {

0 commit comments

Comments
 (0)