2525import org .apache .paimon .flink .LogicalTypeConversion ;
2626import org .apache .paimon .flink .dataevolution .DataEvolutionPartialWriteOperator ;
2727import org .apache .paimon .flink .dataevolution .FirstRowIdAssigner ;
28+ import org .apache .paimon .flink .dataevolution .MergeIntoUpdateChecker ;
2829import org .apache .paimon .flink .sink .Committable ;
2930import org .apache .paimon .flink .sink .CommittableTypeInfo ;
3031import org .apache .paimon .flink .sink .CommitterOperatorFactory ;
3132import org .apache .paimon .flink .sink .NoopCommittableStateManager ;
3233import org .apache .paimon .flink .sink .StoreCommitter ;
3334import org .apache .paimon .flink .sorter .SortOperator ;
35+ import org .apache .paimon .flink .utils .FlinkCalciteClasses ;
3436import org .apache .paimon .flink .utils .InternalTypeInfo ;
37+ import org .apache .paimon .manifest .ManifestCommittable ;
3538import org .apache .paimon .table .FileStoreTable ;
3639import org .apache .paimon .table .SpecialFields ;
3740import org .apache .paimon .types .DataField ;
4144import org .apache .paimon .types .DataTypeRoot ;
4245import org .apache .paimon .types .RowType ;
4346import org .apache .paimon .utils .Preconditions ;
44- import org .apache .paimon .utils .StringUtils ;
4547
4648import org .apache .flink .api .common .typeinfo .BasicTypeInfo ;
4749import org .apache .flink .api .dag .Transformation ;
4850import org .apache .flink .api .java .tuple .Tuple2 ;
4951import org .apache .flink .api .java .typeutils .TupleTypeInfo ;
5052import org .apache .flink .streaming .api .datastream .DataStream ;
5153import org .apache .flink .streaming .api .functions .sink .v2 .DiscardingSink ;
52- import org .apache .flink .streaming .api .operators .OneInputStreamOperatorFactory ;
53- import org .apache .flink .streaming .api .operators .StreamMap ;
54+ import org .apache .flink .streaming .api .operators .StreamFlatMap ;
5455import org .apache .flink .streaming .api .transformations .OneInputTransformation ;
5556import org .apache .flink .table .api .Table ;
5657import org .apache .flink .table .api .TableResult ;
6970import java .util .Collections ;
7071import java .util .List ;
7172import java .util .Map ;
73+ import java .util .Optional ;
7274import java .util .Set ;
7375import java .util .function .Function ;
74- import java .util .regex .Matcher ;
75- import java .util .regex .Pattern ;
7676import java .util .stream .Collectors ;
7777
7878import static org .apache .paimon .format .blob .BlobFileFormat .isBlobFile ;
9595public class DataEvolutionMergeIntoAction extends TableActionBase {
9696
9797 private static final Logger LOG = LoggerFactory .getLogger (DataEvolutionMergeIntoAction .class );
98- public static final String IDENTIFIER_QUOTE = "`" ;
9998
10099 private final CoreOptions coreOptions ;
101100
@@ -120,6 +119,7 @@ public class DataEvolutionMergeIntoAction extends TableActionBase {
120119
121120 // merge condition
122121 private String mergeCondition ;
122+ private MergeConditionParser mergeConditionParser ;
123123
124124 // set statement
125125 private String matchedUpdateSet ;
@@ -137,6 +137,17 @@ public DataEvolutionMergeIntoAction(
137137 table .getClass ().getName ()));
138138 }
139139
140+ Long latestSnapshotId = ((FileStoreTable ) table ).snapshotManager ().latestSnapshotId ();
141+ if (latestSnapshotId == null ) {
142+ throw new UnsupportedOperationException (
143+ "merge-into action doesn't support updating an empty table." );
144+ }
145+ table =
146+ table .copy (
147+ Collections .singletonMap (
148+ CoreOptions .COMMIT_STRICT_MODE_LAST_SAFE_SNAPSHOT .key (),
149+ latestSnapshotId .toString ()));
150+
140151 this .coreOptions = ((FileStoreTable ) table ).coreOptions ();
141152
142153 if (!coreOptions .dataEvolutionEnabled ()) {
@@ -168,6 +179,12 @@ public DataEvolutionMergeIntoAction withTargetAlias(String targetAlias) {
168179
169180 public DataEvolutionMergeIntoAction withMergeCondition (String mergeCondition ) {
170181 this .mergeCondition = mergeCondition ;
182+ try {
183+ this .mergeConditionParser = new MergeConditionParser (mergeCondition );
184+ } catch (Exception e ) {
185+ LOG .error ("Failed to parse merge condition: {}" , mergeCondition , e );
186+ throw new RuntimeException ("Failed to parse merge condition " + mergeCondition , e );
187+ }
171188 return this ;
172189 }
173190
@@ -196,7 +213,12 @@ public TableResult runInternal() {
196213 DataStream <Committable > written =
197214 writePartialColumns (shuffled , sourceWithType .f1 , sinkParallelism );
198215 // 4. commit
199- DataStream <?> committed = commit (written );
216+ Set <String > updatedColumns =
217+ sourceWithType .f1 .getFields ().stream ()
218+ .map (DataField ::name )
219+ .filter (name -> !SpecialFields .ROW_ID .name ().equals (name ))
220+ .collect (Collectors .toSet ());
221+ DataStream <?> committed = commit (written , updatedColumns );
200222
201223 // execute internal
202224 Transformation <?> transformations =
@@ -219,8 +241,7 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
219241 List <String > project ;
220242 if (matchedUpdateSet .equals ("*" )) {
221243 // if sourceName is qualified like 'default.S', we should build a project like S.*
222- String [] splits = sourceTable .split ("\\ ." );
223- project = Collections .singletonList (splits [splits .length - 1 ] + ".*" );
244+ project = Collections .singletonList (sourceTableName () + ".*" );
224245 } else {
225246 // validate upsert changes
226247 Map <String , String > changes = parseCommaSeparatedKeyValues (matchedUpdateSet );
@@ -245,16 +266,38 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
245266 .collect (Collectors .toList ());
246267 }
247268
248- // use join to find matched rows and assign row id for each source row.
249- // _ROW_ID is the first field of joined table.
250- String query =
251- String .format (
252- "SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s" ,
253- "`RT`.`_ROW_ID` as `_ROW_ID`" ,
254- String .join ("," , project ),
255- escapedSourceName (),
256- escapedRowTrackingTargetName (),
257- rewriteMergeCondition (mergeCondition ));
269+ String query ;
270+ Optional <String > sourceRowIdField ;
271+ try {
272+ sourceRowIdField = mergeConditionParser .extractRowIdFieldFromSource (targetTableName ());
273+ } catch (Exception e ) {
274+ LOG .error ("Error happened when extract row id field from source table." , e );
275+ throw new RuntimeException (
276+ "Error happened when extract row id field from source table." , e );
277+ }
278+
279+ // if source table already contains _ROW_ID field, we could avoid join
280+ if (sourceRowIdField .isPresent ()) {
281+ query =
282+ String .format (
283+ // cast _ROW_ID to BIGINT
284+ "SELECT CAST(`%s`.`%s` AS BIGINT) AS `_ROW_ID`, %s FROM %s" ,
285+ sourceTableName (),
286+ sourceRowIdField .get (),
287+ String .join ("," , project ),
288+ escapedSourceName ());
289+ } else {
290+ // use join to find matched rows and assign row id for each source row.
291+ // _ROW_ID is the first field of joined table.
292+ query =
293+ String .format (
294+ "SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s" ,
295+ "`RT`.`_ROW_ID` as `_ROW_ID`" ,
296+ String .join ("," , project ),
297+ escapedSourceName (),
298+ escapedRowTrackingTargetName (),
299+ rewriteMergeCondition (mergeCondition ));
300+ }
258301
259302 LOG .info ("Source query: {}" , query );
260303
@@ -286,11 +329,15 @@ public DataStream<Tuple2<Long, RowData>> shuffleByFirstRowId(
286329 Preconditions .checkState (
287330 !firstRowIds .isEmpty (), "Should not MERGE INTO an empty target table." );
288331
332+ // if firstRowIds is not empty, there must be a valid nextRowId
333+ long maxRowId = table .latestSnapshot ().get ().nextRowId () - 1 ;
334+
289335 OneInputTransformation <RowData , Tuple2 <Long , RowData >> assignedFirstRowId =
290336 new OneInputTransformation <>(
291337 sourceTransformation ,
292338 "ASSIGN FIRST_ROW_ID" ,
293- new StreamMap <>(new FirstRowIdAssigner (firstRowIds , sourceType )),
339+ new StreamFlatMap <>(
340+ new FirstRowIdAssigner (firstRowIds , maxRowId , sourceType )),
294341 new TupleTypeInfo <>(
295342 BasicTypeInfo .LONG_TYPE_INFO , sourceTransformation .getOutputType ()),
296343 sourceTransformation .getParallelism (),
@@ -334,9 +381,20 @@ public DataStream<Committable> writePartialColumns(
334381 .setParallelism (sinkParallelism );
335382 }
336383
337- public DataStream <Committable > commit (DataStream <Committable > written ) {
384+ public DataStream <Committable > commit (
385+ DataStream <Committable > written , Set <String > updatedColumns ) {
338386 FileStoreTable storeTable = (FileStoreTable ) table ;
339- OneInputStreamOperatorFactory <Committable , Committable > committerOperator =
387+
388+ // Check if some global-indexed columns are updated
389+ DataStream <Committable > checked =
390+ written .transform (
391+ "Updated Column Check" ,
392+ new CommittableTypeInfo (),
393+ new MergeIntoUpdateChecker (storeTable , updatedColumns ))
394+ .setParallelism (1 )
395+ .setMaxParallelism (1 );
396+
397+ CommitterOperatorFactory <Committable , ManifestCommittable > committerOperator =
340398 new CommitterOperatorFactory <>(
341399 false ,
342400 true ,
@@ -348,7 +406,7 @@ public DataStream<Committable> commit(DataStream<Committable> written) {
348406 context ),
349407 new NoopCommittableStateManager ());
350408
351- return written .transform ("COMMIT OPERATOR" , new CommittableTypeInfo (), committerOperator )
409+ return checked .transform ("COMMIT OPERATOR" , new CommittableTypeInfo (), committerOperator )
352410 .setParallelism (1 )
353411 .setMaxParallelism (1 );
354412 }
@@ -382,28 +440,13 @@ private DataStream<RowData> toDataStream(Table source) {
382440 */
383441 @ VisibleForTesting
384442 public String rewriteMergeCondition (String mergeCondition ) {
385- // skip single and double-quoted chunks
386- String skipQuoted = "'(?:''|[^'])*'" + "|\" (?:\" \" |[^\" ])*\" " ;
387- String targetTableRegex =
388- "(?i)(?:\\ b"
389- + Pattern .quote (targetTableName ())
390- + "\\ b|`"
391- + Pattern .quote (targetTableName ())
392- + "`)\\ s*\\ ." ;
393-
394- Pattern pattern = Pattern .compile (skipQuoted + "|(" + targetTableRegex + ")" );
395- Matcher matcher = pattern .matcher (mergeCondition );
396-
397- StringBuffer sb = new StringBuffer ();
398- while (matcher .find ()) {
399- if (matcher .group (1 ) != null ) {
400- matcher .appendReplacement (sb , Matcher .quoteReplacement ("`RT`." ));
401- } else {
402- matcher .appendReplacement (sb , Matcher .quoteReplacement (matcher .group (0 )));
403- }
443+ try {
444+ Object rewrittenNode = mergeConditionParser .rewriteSqlNode (targetTableName (), "RT" );
445+ return rewrittenNode .toString ();
446+ } catch (Exception e ) {
447+ LOG .error ("Failed to rewrite merge condition: {}" , mergeCondition , e );
448+ throw new RuntimeException ("Failed to rewrite merge condition " + mergeCondition , e );
404449 }
405- matcher .appendTail (sb );
406- return sb .toString ();
407450 }
408451
409452 /**
@@ -432,7 +475,8 @@ private void checkSchema(Table source) {
432475 foundRowIdColumn = true ;
433476 Preconditions .checkState (
434477 flinkColumn .getDataType ().getLogicalType ().getTypeRoot ()
435- == LogicalTypeRoot .BIGINT );
478+ == LogicalTypeRoot .BIGINT ,
479+ "_ROW_ID field should be BIGINT type." );
436480 } else {
437481 DataField targetField = targetFields .get (flinkColumn .getName ());
438482 if (targetField == null ) {
@@ -497,6 +541,11 @@ private String targetTableName() {
497541 return targetAlias == null ? identifier .getObjectName () : targetAlias ;
498542 }
499543
544+ private String sourceTableName () {
545+ String [] splits = sourceTable .split ("\\ ." );
546+ return splits [splits .length - 1 ];
547+ }
548+
500549 private String escapedSourceName () {
501550 return Arrays .stream (sourceTable .split ("\\ ." ))
502551 .map (s -> String .format ("`%s`" , s ))
@@ -514,28 +563,108 @@ private String escapedRowTrackingTargetName() {
514563 catalogName , identifier .getDatabaseName (), identifier .getObjectName ());
515564 }
516565
517- private List <String > normalizeFieldName (List <String > fieldNames ) {
518- return fieldNames .stream ().map (this ::normalizeFieldName ).collect (Collectors .toList ());
519- }
566+ /** The parser to parse merge condition through calcite sql parser. */
567+ static class MergeConditionParser {
568+
569+ private final FlinkCalciteClasses calciteClasses ;
570+ private final Object sqlNode ;
571+
572+ MergeConditionParser (String mergeCondition ) throws Exception {
573+ this .calciteClasses = new FlinkCalciteClasses ();
574+ this .sqlNode = initializeSqlNode (mergeCondition );
575+ }
520576
521- private String normalizeFieldName (String fieldName ) {
522- if (StringUtils .isNullOrWhitespaceOnly (fieldName ) || fieldName .endsWith (IDENTIFIER_QUOTE )) {
523- return fieldName ;
577+ private Object initializeSqlNode (String mergeCondition ) throws Exception {
578+ Object config =
579+ calciteClasses
580+ .configDelegate ()
581+ .withLex (
582+ calciteClasses .sqlParserDelegate ().config (),
583+ calciteClasses .lexDelegate ().java ());
584+ Object sqlParser = calciteClasses .sqlParserDelegate ().create (mergeCondition , config );
585+ return calciteClasses .sqlParserDelegate ().parseExpression (sqlParser );
524586 }
525587
526- String [] splitFieldNames = fieldName .split ("\\ ." );
527- if (!targetFieldNames .contains (splitFieldNames [splitFieldNames .length - 1 ])) {
528- return fieldName ;
588+ /**
589+ * Rewrite the SQL node, replacing all references from the 'from' table to the 'to' table.
590+ */
591+ public Object rewriteSqlNode (String from , String to ) throws Exception {
592+ return rewriteNode (sqlNode , from , to );
529593 }
530594
531- return String .join (
532- "." ,
533- Arrays .stream (splitFieldNames )
534- .map (
535- part ->
536- part .endsWith (IDENTIFIER_QUOTE )
537- ? part
538- : IDENTIFIER_QUOTE + part + IDENTIFIER_QUOTE )
539- .toArray (String []::new ));
595+ private Object rewriteNode (Object node , String from , String to ) throws Exception {
596+ // It's a SqlBasicCall, recursively rewrite children operands
597+ if (calciteClasses .sqlBasicCallDelegate ().instanceOfSqlBasicCall (node )) {
598+ List <?> operandList = calciteClasses .sqlBasicCallDelegate ().getOperandList (node );
599+ List <Object > newNodes = new java .util .ArrayList <>();
600+ for (Object operand : operandList ) {
601+ newNodes .add (rewriteNode (operand , from , to ));
602+ }
603+
604+ Object operator = calciteClasses .sqlBasicCallDelegate ().getOperator (node );
605+ Object parserPos = calciteClasses .sqlBasicCallDelegate ().getParserPosition (node );
606+ Object functionQuantifier =
607+ calciteClasses .sqlBasicCallDelegate ().getFunctionQuantifier (node );
608+ return calciteClasses
609+ .sqlBasicCallDelegate ()
610+ .create (operator , newNodes , parserPos , functionQuantifier );
611+ } else if (calciteClasses .sqlIndentifierDelegate ().instanceOfSqlIdentifier (node )) {
612+ // It's a sql identifier, try to replace the table name
613+ List <String > names = calciteClasses .sqlIndentifierDelegate ().getNames (node );
614+ Preconditions .checkState (
615+ names .size () >= 2 , "Please specify the table name for the column: " + node );
616+ int nameLen = names .size ();
617+ if (names .get (nameLen - 2 ).equals (from )) {
618+ return calciteClasses .sqlIndentifierDelegate ().setName (node , nameLen - 2 , to );
619+ }
620+ return node ;
621+ } else {
622+ return node ;
623+ }
624+ }
625+
626+ /**
627+ * Find the row id field in source table. This method looks for an equality condition like
628+ * `target_table._ROW_ID = source_table.some_field` or `source_table.some_field =
629+ * target_table._ROW_ID`, and returns the field name that is paired with _ROW_ID.
630+ */
631+ public Optional <String > extractRowIdFieldFromSource (String targetTable ) throws Exception {
632+ Object operator = calciteClasses .sqlBasicCallDelegate ().getOperator (sqlNode );
633+ Object kind = calciteClasses .sqlOperatorDelegate ().getKind (operator );
634+
635+ if (kind == calciteClasses .sqlKindDelegate ().equals ()) {
636+ List <?> operandList = calciteClasses .sqlBasicCallDelegate ().getOperandList (sqlNode );
637+
638+ Object left = operandList .get (0 );
639+ Object right = operandList .get (1 );
640+
641+ if (calciteClasses .sqlIndentifierDelegate ().instanceOfSqlIdentifier (left )
642+ && calciteClasses .sqlIndentifierDelegate ().instanceOfSqlIdentifier (right )) {
643+
644+ List <String > leftNames = calciteClasses .sqlIndentifierDelegate ().getNames (left );
645+ List <String > rightNames =
646+ calciteClasses .sqlIndentifierDelegate ().getNames (right );
647+ Preconditions .checkState (
648+ leftNames .size () >= 2 ,
649+ "Please specify the table name for the column: " + left );
650+ Preconditions .checkState (
651+ rightNames .size () >= 2 ,
652+ "Please specify the table name for the column: " + right );
653+
654+ if (leftNames .get (leftNames .size () - 1 ).equals (SpecialFields .ROW_ID .name ())
655+ && leftNames .get (leftNames .size () - 2 ).equals (targetTable )) {
656+ return Optional .of (rightNames .get (rightNames .size () - 1 ));
657+ } else if (rightNames
658+ .get (rightNames .size () - 1 )
659+ .equals (SpecialFields .ROW_ID .name ())
660+ && rightNames .get (rightNames .size () - 2 ).equals (targetTable )) {
661+ return Optional .of (leftNames .get (leftNames .size () - 1 ));
662+ }
663+ return Optional .empty ();
664+ }
665+ }
666+
667+ return Optional .empty ();
668+ }
540669 }
541670}
0 commit comments