Commit 55e4d28

[spark] fix Merge Into unstable tests (#6912)
1 parent f15bcfe commit 55e4d28

File tree

1 file changed (+32 −18 lines)

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala

Lines changed: 32 additions & 18 deletions
@@ -22,12 +22,11 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Join, LogicalPlan, MergeRows, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
+import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 abstract class RowTrackingTestBase extends PaimonSparkTestBase {
 
@@ -397,13 +396,20 @@
     sql(
       "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-    val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+    var findSplitsPlan: LogicalPlan = null
+    val latch = new CountDownLatch(1)
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans += qe.analyzed
+        if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+          latch.countDown()
+          findSplitsPlan = qe.analyzed
+        }
       }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        capturedPlans += qe.analyzed
+        if (qe.analyzed.collectFirst { case _: Deduplicate => true }.nonEmpty) {
+          latch.countDown()
+          findSplitsPlan = qe.analyzed
+        }
       }
     }
     spark.listenerManager.register(listener)
@@ -416,9 +422,10 @@
         |WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (target_ROW_ID, b * 1.1, c)
         |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (target_ROW_ID, b, c)
         |""".stripMargin)
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
     // Assert that no Join operator was used during
     // `org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.targetRelatedSplits`
-    assert(capturedPlans.head.collect { case plan: Join => plan }.isEmpty)
+    assert(findSplitsPlan != null && findSplitsPlan.collect { case plan: Join => plan }.isEmpty)
     spark.listenerManager.unregister(listener)
 
     checkAnswer(
@@ -442,13 +449,20 @@
     sql(
       "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
 
-    val capturedPlans = new java.util.concurrent.CopyOnWriteArrayList[LogicalPlan]()
+    var updatePlan: LogicalPlan = null
+    val latch = new CountDownLatch(1)
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans.add(qe.analyzed)
+        if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+          latch.countDown()
+          updatePlan = qe.analyzed
+        }
      }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-        capturedPlans.add(qe.analyzed)
+        if (qe.analyzed.collectFirst { case _: MergeRows => true }.nonEmpty) {
+          latch.countDown()
+          updatePlan = qe.analyzed
+        }
       }
     }
     spark.listenerManager.register(listener)
@@ -460,17 +474,17 @@
         |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b * 3,
         |c = concat(target.c, source.c)
         |""".stripMargin).collect()
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
     // Assert no shuffle/join/sort was used in
     // 'org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.updateActionInvoke'
     assert(
-      capturedPlans.asScala.forall(
-        plan =>
-          plan.collectFirst {
-            case p: Join => p
-            case p: Sort => p
-            case p: RepartitionByExpression => p
-          }.isEmpty),
-      s"Found unexpected Join/Sort/Exchange in plan:\n$capturedPlans"
+      updatePlan != null &&
+        updatePlan.collectFirst {
+          case p: Join => p
+          case p: Sort => p
+          case p: RepartitionByExpression => p
+        }.isEmpty,
+      s"Found unexpected Join/Sort/Exchange in plan: $updatePlan"
     )
     spark.listenerManager.unregister(listener)
 
