Skip to content

Commit 4a98ded

Browse files
[FLINK-38219][table-runtime] Fix row kind in StreamingMultiJoinOperator
This closes #26886.
1 parent 5494e91 commit 4a98ded

File tree

4 files changed

+78
-19
lines changed

4 files changed

+78
-19
lines changed

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public List<TableTestProgram> programs() {
3737
MultiJoinTestPrograms.MULTI_JOIN_WITH_TIME_ATTRIBUTES_MATERIALIZATION,
3838
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_INNER_JOIN_NO_JOIN_KEY,
3939
MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_NO_COMMON_JOIN_KEY,
40-
MultiJoinTestPrograms.MULTI_JOIN_MIXED_CHANGELOG_MODES);
40+
MultiJoinTestPrograms.MULTI_JOIN_MIXED_CHANGELOG_MODES,
41+
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_WITH_CTE);
4142
}
4243
}

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinTestPrograms.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,4 +967,68 @@ public class MultiJoinTestPrograms {
967967
+ "LEFT JOIN RetractTable r ON a.id = r.ref_id "
968968
+ "LEFT JOIN UpsertTable u ON a.id = u.key_id")
969969
.build();
970+
971+
public static final TableTestProgram MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_WITH_CTE =
972+
TableTestProgram.of(
973+
"left-outer-join-with-cte",
974+
"CTE with three-way left outer join and aggregation")
975+
.setupConfig(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true)
976+
.setupTableSource(USERS_SOURCE)
977+
.setupTableSource(
978+
SourceTestStep.newBuilder("Orders")
979+
.addSchema(
980+
"user_id STRING",
981+
"order_id STRING PRIMARY KEY NOT ENFORCED",
982+
"product STRING")
983+
.addOption("changelog-mode", "I, UA,D")
984+
.producedValues(
985+
Row.ofKind(RowKind.INSERT, "2", "order2", "Product B"),
986+
Row.ofKind(
987+
RowKind.UPDATE_AFTER,
988+
"2",
989+
"order2",
990+
"Product B"),
991+
Row.ofKind(
992+
RowKind.UPDATE_AFTER,
993+
"2",
994+
"order2",
995+
"Product C"),
996+
Row.ofKind(
997+
RowKind.UPDATE_AFTER,
998+
"2",
999+
"order2",
1000+
"Product C"))
1001+
.build())
1002+
.setupTableSource(PAYMENTS_SOURCE)
1003+
.setupTableSink(
1004+
SinkTestStep.newBuilder("sink")
1005+
.addSchema("user_id STRING", "name STRING", "cnt BIGINT")
1006+
.testMaterializedData()
1007+
.consumedValues(Row.of("1", "Gus", 2), Row.of("2", "Bob", 1))
1008+
.build())
1009+
.runSql(
1010+
"INSERT INTO sink WITH "
1011+
+ " order_details AS ( "
1012+
+ " SELECT o.user_id "
1013+
+ " FROM Orders o "
1014+
+ " ), "
1015+
+ " user_elements AS ( "
1016+
+ " SELECT "
1017+
+ " u.id AS user_id "
1018+
+ " FROM ( "
1019+
+ " SELECT '2' AS id, '2' AS order_user_id "
1020+
+ " UNION ALL "
1021+
+ " SELECT '1' AS id, '2' AS order_user_id "
1022+
+ " UNION ALL "
1023+
+ " SELECT '5' AS id, '5' AS order_user_id "
1024+
+ " ) u "
1025+
+ " LEFT JOIN order_details od "
1026+
+ " ON od.user_id = u.order_user_id "
1027+
+ " ) "
1028+
+ "SELECT ue.user_id, us.name, COUNT(*) AS cnt "
1029+
+ "FROM user_elements ue "
1030+
+ "INNER JOIN Users us ON ue.user_id = us.user_id "
1031+
+ "LEFT JOIN Payments p ON ue.user_id = p.user_id "
1032+
+ "GROUP BY ue.user_id, us.name")
1033+
.build();
9701034
}

flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,12 +777,15 @@ private void handleInsertAfterInput(
777777
}
778778

779779
private void addRecordToState(RowData input, int inputId) throws Exception {
780+
final boolean isUpsert = isUpsert(input);
780781
RowData joinKey = keyExtractor.getJoinKey(input, inputId);
781782

782-
if (isRetraction(input)) {
783-
stateHandlers.get(inputId).retractRecord(joinKey, input);
784-
} else {
783+
// Always use insert so we store and retract records correctly from state
784+
input.setRowKind(RowKind.INSERT);
785+
if (isUpsert) {
785786
stateHandlers.get(inputId).addRecord(joinKey, input);
787+
} else {
788+
stateHandlers.get(inputId).retractRecord(joinKey, input);
786789
}
787790
}
788791

flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
3333
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
3434
import org.apache.flink.table.types.logical.RowType;
35-
import org.apache.flink.types.RowKind;
3635
import org.apache.flink.util.IterableIterator;
3736

3837
import javax.annotation.Nonnull;
@@ -238,15 +237,15 @@ private RowData getStateKey(RowData joinKey, RowData uniqueKey) {
238237

239238
@Override
240239
public void addRecord(@Nullable RowData joinKey, RowData record) throws Exception {
241-
RowData uniqueKey = uniqueKeySelector.getKey(record);
242-
RowData stateKey = getStateKey(joinKey, uniqueKey);
240+
final RowData uniqueKey = uniqueKeySelector.getKey(record);
241+
final RowData stateKey = getStateKey(joinKey, uniqueKey);
243242
recordState.put(stateKey, record);
244243
}
245244

246245
@Override
247246
public void retractRecord(@Nullable RowData joinKey, RowData record) throws Exception {
248-
RowData uniqueKey = uniqueKeySelector.getKey(record);
249-
RowData stateKey = getStateKey(joinKey, uniqueKey);
247+
final RowData uniqueKey = uniqueKeySelector.getKey(record);
248+
final RowData stateKey = getStateKey(joinKey, uniqueKey);
250249
recordState.remove(stateKey);
251250
}
252251

@@ -364,25 +363,18 @@ private RowData getStateKey(@Nullable RowData joinKey, RowData record) {
364363

365364
@Override
366365
public void addRecord(@Nullable RowData joinKey, RowData record) throws Exception {
367-
// Normalize RowKind for consistent state representation
368-
RowKind originalKind = record.getRowKind();
369-
record.setRowKind(RowKind.INSERT); // Normalize for key creation
370-
RowData stateKey = getStateKey(joinKey, record);
366+
final RowData stateKey = getStateKey(joinKey, record);
371367

372368
Integer currentCount = recordState.get(stateKey);
373369
if (currentCount == null) {
374370
currentCount = 0;
375371
}
376372
recordState.put(stateKey, currentCount + 1);
377-
record.setRowKind(originalKind); // Restore original RowKind
378373
}
379374

380375
@Override
381376
public void retractRecord(@Nullable RowData joinKey, RowData record) throws Exception {
382-
// Normalize RowKind for consistent state representation and lookup
383-
RowKind originalKind = record.getRowKind();
384-
record.setRowKind(RowKind.INSERT); // Normalize for key lookup
385-
RowData stateKey = getStateKey(joinKey, record);
377+
final RowData stateKey = getStateKey(joinKey, record);
386378

387379
Integer currentCount = recordState.get(stateKey);
388380
if (currentCount != null) {
@@ -392,7 +384,6 @@ public void retractRecord(@Nullable RowData joinKey, RowData record) throws Exce
392384
recordState.remove(stateKey);
393385
}
394386
}
395-
record.setRowKind(originalKind); // Restore original RowKind
396387
}
397388

398389
@Override

0 commit comments

Comments
 (0)