@@ -22,12 +22,11 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Join, LogicalPlan, MergeRows, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 abstract class RowTrackingTestBase extends PaimonSparkTestBase {
 
@@ -397,13 +396,20 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
     sql(
       "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-    val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+    var findSplitsPlan: LogicalPlan = null
+    val latch = new CountDownLatch(1)
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans += qe.analyzed
+        if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+          findSplitsPlan = qe.analyzed
+          latch.countDown()
+        }
       }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        capturedPlans += qe.analyzed
+        if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+          findSplitsPlan = qe.analyzed
+          latch.countDown()
+        }
       }
     }
     spark.listenerManager.register(listener)
@@ -416,9 +422,10 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
         |WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (target_ROW_ID, b * 1.1, c)
         |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (target_ROW_ID, b, c)
         |""".stripMargin)
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
     // Assert that no Join operator was used during
     // `org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.targetRelatedSplits`
-    assert(capturedPlans.head.collect { case plan: Join => plan }.isEmpty)
+    assert(findSplitsPlan != null && findSplitsPlan.collect { case plan: Join => plan }.isEmpty)
     spark.listenerManager.unregister(listener)
 
     checkAnswer(
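
The `collect`/`collectFirst` assertions above work because Catalyst's `LogicalPlan` is a `TreeNode`, so a partial function can be matched against every operator in the analyzed tree. A minimal standalone sketch of the same detection idiom, assuming only a local Spark session; the `containsJoin` helper and `PlanInspectionSketch` object are illustrative names, not part of this patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}

object PlanInspectionSketch {
  // True if any operator in the plan tree is a Join node, mirroring the
  // `collect { case plan: Join => plan }.isEmpty` check in the test above.
  def containsJoin(plan: LogicalPlan): Boolean =
    plan.collectFirst { case j: Join => j }.isDefined

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    val joined = spark.range(5).join(spark.range(5), "id")
    assert(containsJoin(joined.queryExecution.analyzed))
    assert(!containsJoin(spark.range(5).queryExecution.analyzed))
    spark.stop()
  }
}

The same traversal is what lets the test single out the internal "find splits" query by probing for its `Deduplicate` marker node.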
@@ -442,13 +449,20 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
     sql(
       "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-    val capturedPlans = new java.util.concurrent.CopyOnWriteArrayList[LogicalPlan]()
+    var updatePlan: LogicalPlan = null
+    val latch = new CountDownLatch(1)
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans.add(qe.analyzed)
+        if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+          updatePlan = qe.analyzed
+          latch.countDown()
+        }
       }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        capturedPlans.add(qe.analyzed)
+        if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+          updatePlan = qe.analyzed
+          latch.countDown()
+        }
       }
     }
     spark.listenerManager.register(listener)
@@ -460,17 +474,17 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
         |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b * 3,
         |c = concat(target.c, source.c)
         |""".stripMargin).collect()
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
     // Assert no shuffle/join/sort was used in
     // 'org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.updateActionInvoke'
     assert(
-      capturedPlans.asScala.forall(
-        plan =>
-          plan.collectFirst {
-            case p: Join => p
-            case p: Sort => p
-            case p: RepartitionByExpression => p
-          }.isEmpty),
-      s"Found unexpected Join/Sort/Exchange in plan:\n$capturedPlans"
+      updatePlan != null &&
+        updatePlan.collectFirst {
+          case p: Join => p
+          case p: Sort => p
+          case p: RepartitionByExpression => p
+        }.isEmpty,
+      s"Found unexpected Join/Sort/Exchange in plan: $updatePlan"
     )
     spark.listenerManager.unregister(listener)
 
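
Both tests now repeat the same register/run/await/capture choreography, so the pattern can be distilled into a reusable helper. The sketch below is illustrative only: the `awaitPlanMatching` name, its signature, and the `ListenerSketch` object are assumptions, not part of this patch. It deliberately publishes the captured plan before calling `countDown()`, so the latch's happens-before edge guarantees the awaiting test thread observes the plan once `await` returns.

import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ListenerSketch {
  // Hypothetical helper: run `action`, then block until a query whose
  // analyzed plan satisfies `predicate` completes (successfully or not),
  // returning that plan, or None on timeout.
  def awaitPlanMatching(spark: SparkSession, timeoutSec: Long)(
      predicate: LogicalPlan => Boolean)(action: => Unit): Option[LogicalPlan] = {
    val captured = new AtomicReference[LogicalPlan]()
    val latch = new CountDownLatch(1)
    val listener = new QueryExecutionListener {
      private def check(qe: QueryExecution): Unit = {
        if (predicate(qe.analyzed)) {
          captured.set(qe.analyzed) // publish the plan before releasing the latch
          latch.countDown()
        }
      }
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        check(qe)
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        check(qe)
    }
    spark.listenerManager.register(listener)
    try {
      action
      if (latch.await(timeoutSec, TimeUnit.SECONDS)) Option(captured.get) else None
    } finally {
      spark.listenerManager.unregister(listener) // always detach, even on failure
    }
  }
}

A test would then call, for example, `awaitPlanMatching(spark, 10)(_.collectFirst { case _: MergeRows => true }.nonEmpty) { sql(mergeStmt).collect() }` and assert on the returned plan, which also keeps listener cleanup in one place.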