Skip to content

Commit e102abd

Browse files
authored
[flink] optimize flink Data-Evolution-Merge-Into (apache#7324)
This PR optimizes Data-Evolution-Merge-Into in several aspects: 1. Introduce a MergeIntoUpdateChecker to check whether some global-indexed columns are updated (this is the same as Spark's implementation). 2. Use Calcite to rename the target table (the current implementation is based on regex, which is very unstable). 3. Use Calcite to find the _row_id field (if it exists) in the source table, which lets us eliminate the join process.
1 parent be4d1a3 commit e102abd

File tree

6 files changed

+598
-96
lines changed

6 files changed

+598
-96
lines changed

paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/action/DataEvolutionMergeIntoAction.java

Lines changed: 192 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525
import org.apache.paimon.flink.LogicalTypeConversion;
2626
import org.apache.paimon.flink.dataevolution.DataEvolutionPartialWriteOperator;
2727
import org.apache.paimon.flink.dataevolution.FirstRowIdAssigner;
28+
import org.apache.paimon.flink.dataevolution.MergeIntoUpdateChecker;
2829
import org.apache.paimon.flink.sink.Committable;
2930
import org.apache.paimon.flink.sink.CommittableTypeInfo;
3031
import org.apache.paimon.flink.sink.CommitterOperatorFactory;
3132
import org.apache.paimon.flink.sink.NoopCommittableStateManager;
3233
import org.apache.paimon.flink.sink.StoreCommitter;
3334
import org.apache.paimon.flink.sorter.SortOperator;
35+
import org.apache.paimon.flink.utils.FlinkCalciteClasses;
3436
import org.apache.paimon.flink.utils.InternalTypeInfo;
37+
import org.apache.paimon.manifest.ManifestCommittable;
3538
import org.apache.paimon.table.FileStoreTable;
3639
import org.apache.paimon.table.SpecialFields;
3740
import org.apache.paimon.types.DataField;
@@ -41,16 +44,14 @@
4144
import org.apache.paimon.types.DataTypeRoot;
4245
import org.apache.paimon.types.RowType;
4346
import org.apache.paimon.utils.Preconditions;
44-
import org.apache.paimon.utils.StringUtils;
4547

4648
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
4749
import org.apache.flink.api.dag.Transformation;
4850
import org.apache.flink.api.java.tuple.Tuple2;
4951
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
5052
import org.apache.flink.streaming.api.datastream.DataStream;
5153
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
52-
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
53-
import org.apache.flink.streaming.api.operators.StreamMap;
54+
import org.apache.flink.streaming.api.operators.StreamFlatMap;
5455
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
5556
import org.apache.flink.table.api.Table;
5657
import org.apache.flink.table.api.TableResult;
@@ -69,10 +70,9 @@
6970
import java.util.Collections;
7071
import java.util.List;
7172
import java.util.Map;
73+
import java.util.Optional;
7274
import java.util.Set;
7375
import java.util.function.Function;
74-
import java.util.regex.Matcher;
75-
import java.util.regex.Pattern;
7676
import java.util.stream.Collectors;
7777

7878
import static org.apache.paimon.format.blob.BlobFileFormat.isBlobFile;
@@ -95,7 +95,6 @@
9595
public class DataEvolutionMergeIntoAction extends TableActionBase {
9696

9797
private static final Logger LOG = LoggerFactory.getLogger(DataEvolutionMergeIntoAction.class);
98-
public static final String IDENTIFIER_QUOTE = "`";
9998

10099
private final CoreOptions coreOptions;
101100

@@ -120,6 +119,7 @@ public class DataEvolutionMergeIntoAction extends TableActionBase {
120119

121120
// merge condition
122121
private String mergeCondition;
122+
private MergeConditionParser mergeConditionParser;
123123

124124
// set statement
125125
private String matchedUpdateSet;
@@ -137,6 +137,17 @@ public DataEvolutionMergeIntoAction(
137137
table.getClass().getName()));
138138
}
139139

140+
Long latestSnapshotId = ((FileStoreTable) table).snapshotManager().latestSnapshotId();
141+
if (latestSnapshotId == null) {
142+
throw new UnsupportedOperationException(
143+
"merge-into action doesn't support updating an empty table.");
144+
}
145+
table =
146+
table.copy(
147+
Collections.singletonMap(
148+
CoreOptions.COMMIT_STRICT_MODE_LAST_SAFE_SNAPSHOT.key(),
149+
latestSnapshotId.toString()));
150+
140151
this.coreOptions = ((FileStoreTable) table).coreOptions();
141152

142153
if (!coreOptions.dataEvolutionEnabled()) {
@@ -168,6 +179,12 @@ public DataEvolutionMergeIntoAction withTargetAlias(String targetAlias) {
168179

169180
public DataEvolutionMergeIntoAction withMergeCondition(String mergeCondition) {
170181
this.mergeCondition = mergeCondition;
182+
try {
183+
this.mergeConditionParser = new MergeConditionParser(mergeCondition);
184+
} catch (Exception e) {
185+
LOG.error("Failed to parse merge condition: {}", mergeCondition, e);
186+
throw new RuntimeException("Failed to parse merge condition " + mergeCondition, e);
187+
}
171188
return this;
172189
}
173190

@@ -196,7 +213,12 @@ public TableResult runInternal() {
196213
DataStream<Committable> written =
197214
writePartialColumns(shuffled, sourceWithType.f1, sinkParallelism);
198215
// 4. commit
199-
DataStream<?> committed = commit(written);
216+
Set<String> updatedColumns =
217+
sourceWithType.f1.getFields().stream()
218+
.map(DataField::name)
219+
.filter(name -> !SpecialFields.ROW_ID.name().equals(name))
220+
.collect(Collectors.toSet());
221+
DataStream<?> committed = commit(written, updatedColumns);
200222

201223
// execute internal
202224
Transformation<?> transformations =
@@ -219,8 +241,7 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
219241
List<String> project;
220242
if (matchedUpdateSet.equals("*")) {
221243
// if sourceName is qualified like 'default.S', we should build a project like S.*
222-
String[] splits = sourceTable.split("\\.");
223-
project = Collections.singletonList(splits[splits.length - 1] + ".*");
244+
project = Collections.singletonList(sourceTableName() + ".*");
224245
} else {
225246
// validate upsert changes
226247
Map<String, String> changes = parseCommaSeparatedKeyValues(matchedUpdateSet);
@@ -245,16 +266,38 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
245266
.collect(Collectors.toList());
246267
}
247268

248-
// use join to find matched rows and assign row id for each source row.
249-
// _ROW_ID is the first field of joined table.
250-
String query =
251-
String.format(
252-
"SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s",
253-
"`RT`.`_ROW_ID` as `_ROW_ID`",
254-
String.join(",", project),
255-
escapedSourceName(),
256-
escapedRowTrackingTargetName(),
257-
rewriteMergeCondition(mergeCondition));
269+
String query;
270+
Optional<String> sourceRowIdField;
271+
try {
272+
sourceRowIdField = mergeConditionParser.extractRowIdFieldFromSource(targetTableName());
273+
} catch (Exception e) {
274+
LOG.error("Error happened when extract row id field from source table.", e);
275+
throw new RuntimeException(
276+
"Error happened when extract row id field from source table.", e);
277+
}
278+
279+
// if source table already contains _ROW_ID field, we could avoid join
280+
if (sourceRowIdField.isPresent()) {
281+
query =
282+
String.format(
283+
// cast _ROW_ID to BIGINT
284+
"SELECT CAST(`%s`.`%s` AS BIGINT) AS `_ROW_ID`, %s FROM %s",
285+
sourceTableName(),
286+
sourceRowIdField.get(),
287+
String.join(",", project),
288+
escapedSourceName());
289+
} else {
290+
// use join to find matched rows and assign row id for each source row.
291+
// _ROW_ID is the first field of joined table.
292+
query =
293+
String.format(
294+
"SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s",
295+
"`RT`.`_ROW_ID` as `_ROW_ID`",
296+
String.join(",", project),
297+
escapedSourceName(),
298+
escapedRowTrackingTargetName(),
299+
rewriteMergeCondition(mergeCondition));
300+
}
258301

259302
LOG.info("Source query: {}", query);
260303

@@ -286,11 +329,15 @@ public DataStream<Tuple2<Long, RowData>> shuffleByFirstRowId(
286329
Preconditions.checkState(
287330
!firstRowIds.isEmpty(), "Should not MERGE INTO an empty target table.");
288331

332+
// if firstRowIds is not empty, there must be a valid nextRowId
333+
long maxRowId = table.latestSnapshot().get().nextRowId() - 1;
334+
289335
OneInputTransformation<RowData, Tuple2<Long, RowData>> assignedFirstRowId =
290336
new OneInputTransformation<>(
291337
sourceTransformation,
292338
"ASSIGN FIRST_ROW_ID",
293-
new StreamMap<>(new FirstRowIdAssigner(firstRowIds, sourceType)),
339+
new StreamFlatMap<>(
340+
new FirstRowIdAssigner(firstRowIds, maxRowId, sourceType)),
294341
new TupleTypeInfo<>(
295342
BasicTypeInfo.LONG_TYPE_INFO, sourceTransformation.getOutputType()),
296343
sourceTransformation.getParallelism(),
@@ -334,9 +381,20 @@ public DataStream<Committable> writePartialColumns(
334381
.setParallelism(sinkParallelism);
335382
}
336383

337-
public DataStream<Committable> commit(DataStream<Committable> written) {
384+
public DataStream<Committable> commit(
385+
DataStream<Committable> written, Set<String> updatedColumns) {
338386
FileStoreTable storeTable = (FileStoreTable) table;
339-
OneInputStreamOperatorFactory<Committable, Committable> committerOperator =
387+
388+
// Check if some global-indexed columns are updated
389+
DataStream<Committable> checked =
390+
written.transform(
391+
"Updated Column Check",
392+
new CommittableTypeInfo(),
393+
new MergeIntoUpdateChecker(storeTable, updatedColumns))
394+
.setParallelism(1)
395+
.setMaxParallelism(1);
396+
397+
CommitterOperatorFactory<Committable, ManifestCommittable> committerOperator =
340398
new CommitterOperatorFactory<>(
341399
false,
342400
true,
@@ -348,7 +406,7 @@ public DataStream<Committable> commit(DataStream<Committable> written) {
348406
context),
349407
new NoopCommittableStateManager());
350408

351-
return written.transform("COMMIT OPERATOR", new CommittableTypeInfo(), committerOperator)
409+
return checked.transform("COMMIT OPERATOR", new CommittableTypeInfo(), committerOperator)
352410
.setParallelism(1)
353411
.setMaxParallelism(1);
354412
}
@@ -382,28 +440,13 @@ private DataStream<RowData> toDataStream(Table source) {
382440
*/
383441
@VisibleForTesting
384442
public String rewriteMergeCondition(String mergeCondition) {
385-
// skip single and double-quoted chunks
386-
String skipQuoted = "'(?:''|[^'])*'" + "|\"(?:\"\"|[^\"])*\"";
387-
String targetTableRegex =
388-
"(?i)(?:\\b"
389-
+ Pattern.quote(targetTableName())
390-
+ "\\b|`"
391-
+ Pattern.quote(targetTableName())
392-
+ "`)\\s*\\.";
393-
394-
Pattern pattern = Pattern.compile(skipQuoted + "|(" + targetTableRegex + ")");
395-
Matcher matcher = pattern.matcher(mergeCondition);
396-
397-
StringBuffer sb = new StringBuffer();
398-
while (matcher.find()) {
399-
if (matcher.group(1) != null) {
400-
matcher.appendReplacement(sb, Matcher.quoteReplacement("`RT`."));
401-
} else {
402-
matcher.appendReplacement(sb, Matcher.quoteReplacement(matcher.group(0)));
403-
}
443+
try {
444+
Object rewrittenNode = mergeConditionParser.rewriteSqlNode(targetTableName(), "RT");
445+
return rewrittenNode.toString();
446+
} catch (Exception e) {
447+
LOG.error("Failed to rewrite merge condition: {}", mergeCondition, e);
448+
throw new RuntimeException("Failed to rewrite merge condition " + mergeCondition, e);
404449
}
405-
matcher.appendTail(sb);
406-
return sb.toString();
407450
}
408451

409452
/**
@@ -432,7 +475,8 @@ private void checkSchema(Table source) {
432475
foundRowIdColumn = true;
433476
Preconditions.checkState(
434477
flinkColumn.getDataType().getLogicalType().getTypeRoot()
435-
== LogicalTypeRoot.BIGINT);
478+
== LogicalTypeRoot.BIGINT,
479+
"_ROW_ID field should be BIGINT type.");
436480
} else {
437481
DataField targetField = targetFields.get(flinkColumn.getName());
438482
if (targetField == null) {
@@ -497,6 +541,11 @@ private String targetTableName() {
497541
return targetAlias == null ? identifier.getObjectName() : targetAlias;
498542
}
499543

544+
private String sourceTableName() {
545+
String[] splits = sourceTable.split("\\.");
546+
return splits[splits.length - 1];
547+
}
548+
500549
private String escapedSourceName() {
501550
return Arrays.stream(sourceTable.split("\\."))
502551
.map(s -> String.format("`%s`", s))
@@ -514,28 +563,108 @@ private String escapedRowTrackingTargetName() {
514563
catalogName, identifier.getDatabaseName(), identifier.getObjectName());
515564
}
516565

517-
private List<String> normalizeFieldName(List<String> fieldNames) {
518-
return fieldNames.stream().map(this::normalizeFieldName).collect(Collectors.toList());
519-
}
566+
/** The parser to parse merge condition through calcite sql parser. */
567+
static class MergeConditionParser {
568+
569+
private final FlinkCalciteClasses calciteClasses;
570+
private final Object sqlNode;
571+
572+
MergeConditionParser(String mergeCondition) throws Exception {
573+
this.calciteClasses = new FlinkCalciteClasses();
574+
this.sqlNode = initializeSqlNode(mergeCondition);
575+
}
520576

521-
private String normalizeFieldName(String fieldName) {
522-
if (StringUtils.isNullOrWhitespaceOnly(fieldName) || fieldName.endsWith(IDENTIFIER_QUOTE)) {
523-
return fieldName;
577+
private Object initializeSqlNode(String mergeCondition) throws Exception {
578+
Object config =
579+
calciteClasses
580+
.configDelegate()
581+
.withLex(
582+
calciteClasses.sqlParserDelegate().config(),
583+
calciteClasses.lexDelegate().java());
584+
Object sqlParser = calciteClasses.sqlParserDelegate().create(mergeCondition, config);
585+
return calciteClasses.sqlParserDelegate().parseExpression(sqlParser);
524586
}
525587

526-
String[] splitFieldNames = fieldName.split("\\.");
527-
if (!targetFieldNames.contains(splitFieldNames[splitFieldNames.length - 1])) {
528-
return fieldName;
588+
/**
589+
* Rewrite the SQL node, replacing all references from the 'from' table to the 'to' table.
590+
*/
591+
public Object rewriteSqlNode(String from, String to) throws Exception {
592+
return rewriteNode(sqlNode, from, to);
529593
}
530594

531-
return String.join(
532-
".",
533-
Arrays.stream(splitFieldNames)
534-
.map(
535-
part ->
536-
part.endsWith(IDENTIFIER_QUOTE)
537-
? part
538-
: IDENTIFIER_QUOTE + part + IDENTIFIER_QUOTE)
539-
.toArray(String[]::new));
595+
private Object rewriteNode(Object node, String from, String to) throws Exception {
596+
// It's a SqlBasicCall, recursively rewrite children operands
597+
if (calciteClasses.sqlBasicCallDelegate().instanceOfSqlBasicCall(node)) {
598+
List<?> operandList = calciteClasses.sqlBasicCallDelegate().getOperandList(node);
599+
List<Object> newNodes = new java.util.ArrayList<>();
600+
for (Object operand : operandList) {
601+
newNodes.add(rewriteNode(operand, from, to));
602+
}
603+
604+
Object operator = calciteClasses.sqlBasicCallDelegate().getOperator(node);
605+
Object parserPos = calciteClasses.sqlBasicCallDelegate().getParserPosition(node);
606+
Object functionQuantifier =
607+
calciteClasses.sqlBasicCallDelegate().getFunctionQuantifier(node);
608+
return calciteClasses
609+
.sqlBasicCallDelegate()
610+
.create(operator, newNodes, parserPos, functionQuantifier);
611+
} else if (calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(node)) {
612+
// It's a sql identifier, try to replace the table name
613+
List<String> names = calciteClasses.sqlIndentifierDelegate().getNames(node);
614+
Preconditions.checkState(
615+
names.size() >= 2, "Please specify the table name for the column: " + node);
616+
int nameLen = names.size();
617+
if (names.get(nameLen - 2).equals(from)) {
618+
return calciteClasses.sqlIndentifierDelegate().setName(node, nameLen - 2, to);
619+
}
620+
return node;
621+
} else {
622+
return node;
623+
}
624+
}
625+
626+
/**
627+
* Find the row id field in source table. This method looks for an equality condition like
628+
* `target_table._ROW_ID = source_table.some_field` or `source_table.some_field =
629+
* target_table._ROW_ID`, and returns the field name that is paired with _ROW_ID.
630+
*/
631+
public Optional<String> extractRowIdFieldFromSource(String targetTable) throws Exception {
632+
Object operator = calciteClasses.sqlBasicCallDelegate().getOperator(sqlNode);
633+
Object kind = calciteClasses.sqlOperatorDelegate().getKind(operator);
634+
635+
if (kind == calciteClasses.sqlKindDelegate().equals()) {
636+
List<?> operandList = calciteClasses.sqlBasicCallDelegate().getOperandList(sqlNode);
637+
638+
Object left = operandList.get(0);
639+
Object right = operandList.get(1);
640+
641+
if (calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(left)
642+
&& calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(right)) {
643+
644+
List<String> leftNames = calciteClasses.sqlIndentifierDelegate().getNames(left);
645+
List<String> rightNames =
646+
calciteClasses.sqlIndentifierDelegate().getNames(right);
647+
Preconditions.checkState(
648+
leftNames.size() >= 2,
649+
"Please specify the table name for the column: " + left);
650+
Preconditions.checkState(
651+
rightNames.size() >= 2,
652+
"Please specify the table name for the column: " + right);
653+
654+
if (leftNames.get(leftNames.size() - 1).equals(SpecialFields.ROW_ID.name())
655+
&& leftNames.get(leftNames.size() - 2).equals(targetTable)) {
656+
return Optional.of(rightNames.get(rightNames.size() - 1));
657+
} else if (rightNames
658+
.get(rightNames.size() - 1)
659+
.equals(SpecialFields.ROW_ID.name())
660+
&& rightNames.get(rightNames.size() - 2).equals(targetTable)) {
661+
return Optional.of(leftNames.get(leftNames.size() - 1));
662+
}
663+
return Optional.empty();
664+
}
665+
}
666+
667+
return Optional.empty();
668+
}
540669
}
541670
}

0 commit comments

Comments
 (0)