@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {
5454
5555 private def writeWithCometNativeWriteExec (
5656 inputPath : String ,
57- outputPath : String ): Option [QueryExecution ] = {
57+ outputPath : String ,
58+ num_partitions : Option [Int ] = None ): Option [QueryExecution ] = {
5859 val df = spark.read.parquet(inputPath)
5960
6061 // Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@ class CometParquetWriterSuite extends CometTestBase {
7778 spark.listenerManager.register(listener)
7879
7980 try {
80- // Perform native write
81- df .write.parquet(outputPath)
81+ // Perform native write with optional partitioning
82+ num_partitions.fold(df)(n => df.repartition(n)) .write.parquet(outputPath)
8283
8384 // Wait for listener to be called with timeout
8485 val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@ class CometParquetWriterSuite extends CometTestBase {
9798 s " Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured " )
9899
99100 capturedPlan.foreach { qe =>
100- val executedPlan = qe.executedPlan
101- val hasNativeWrite = executedPlan.exists {
102- case _ : CometNativeWriteExec => true
101+ val executedPlan = stripAQEPlan(qe.executedPlan)
102+
103+ // Count CometNativeWriteExec instances in the plan
104+ var nativeWriteCount = 0
105+ executedPlan.foreach {
106+ case _ : CometNativeWriteExec =>
107+ nativeWriteCount += 1
103108 case d : DataWritingCommandExec =>
104- d.child.exists {
105- case _ : CometNativeWriteExec => true
106- case _ => false
109+ d.child.foreach {
110+ case _ : CometNativeWriteExec =>
111+ nativeWriteCount += 1
112+ case _ =>
107113 }
108- case _ => false
114+ case _ =>
109115 }
110116
111117 assert(
112- hasNativeWrite ,
113- s " Expected CometNativeWriteExec in the plan, but got : \n ${executedPlan.treeString}" )
118+ nativeWriteCount == 1 ,
119+ s " Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount : \n ${executedPlan.treeString}" )
114120 }
115121 } finally {
116122 spark.listenerManager.unregister(listener)
@@ -197,4 +203,29 @@ class CometParquetWriterSuite extends CometTestBase {
197203 }
198204 }
199205 }
206+
207+ test(" basic parquet write with repartition" ) {
208+ withTempPath { dir =>
209+ // Create test data and write it to a temp parquet file first
210+ withTempPath { inputDir =>
211+ val inputPath = createTestData(inputDir)
212+ Seq (true , false ).foreach(adaptive => {
213+ // Create a new output path for each AQE value
214+ val outputPath = new File (dir, s " output_aqe_ $adaptive.parquet " ).getAbsolutePath
215+
216+ withSQLConf(
217+ CometConf .COMET_NATIVE_PARQUET_WRITE_ENABLED .key -> " true" ,
218+ " spark.sql.adaptive.enabled" -> adaptive.toString,
219+ SQLConf .SESSION_LOCAL_TIMEZONE .key -> " America/Halifax" ,
220+ CometConf .getOperatorAllowIncompatConfigKey(
221+ classOf [DataWritingCommandExec ]) -> " true" ,
222+ CometConf .COMET_EXEC_ENABLED .key -> " true" ) {
223+
224+ writeWithCometNativeWriteExec(inputPath, outputPath, Some (10 ))
225+ verifyWrittenFile(outputPath)
226+ }
227+ })
228+ }
229+ }
230+ }
200231}
0 commit comments