@@ -75,16 +75,15 @@ class CometParquetWriterSuite extends CometTestBase {
 
     withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
       val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
-      capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeScan = executedPlan.exists {
+      capturedPlan.foreach { plan =>
+        val hasNativeScan = plan.exists {
           case _: CometNativeScanExec => true
           case _ => false
         }
 
         assert(
           hasNativeScan,
-          s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
+          s"Expected CometNativeScanExec in the plan, but got:\n${plan.treeString}")
       }
 
       verifyWrittenFile(outputPath)
@@ -311,6 +310,54 @@ class CometParquetWriterSuite extends CometTestBase {
     }
   }
 
+  test("native write falls back when scan produces non-Arrow data") {
+    // This test verifies that when a native scan (like native_comet) doesn't support
+    // certain data types (complex types), the native write correctly falls back to Spark
+    // instead of failing at runtime with a "Comet execution only takes Arrow Arrays" error.
+    withTempPath { dir =>
+      val inputPath = new File(dir, "input.parquet").getAbsolutePath
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      // Create data with complex types and write it without Comet
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val df = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq(6, 7, 8, 9)))
+          .toDF("id", "values")
+        df.write.parquet(inputPath)
+      }
+
+      // With native Parquet write enabled but the native_comet scan, which doesn't
+      // support complex types, the scan falls back to Spark. The native write should
+      // detect this and also fall back to Spark instead of failing at runtime.
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
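+        // Note: native write replaces DataWritingCommandExec, which sits behind a
+        // per-operator allow-incompat flag, hence the explicit opt-in above (this
+        // reading is inferred from the config key's name).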
+        // Use native_comet, which doesn't support complex types
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_comet") {
+
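+        // captureWritePlan (the helper added below) registers a QueryExecutionListener,
+        // runs the given write, and returns the executed plan once the listener fires.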
+        val plan =
+          captureWritePlan(path => spark.read.parquet(inputPath).write.parquet(path), outputPath)
+
+        // Verify NO CometNativeWriteExec in the plan (should have fallen back to Spark)
+        val hasNativeWrite = plan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists(_.isInstanceOf[CometNativeWriteExec])
+          case _ => false
+        }
+
+        assert(
+          !hasNativeWrite,
+          "Expected fallback to Spark write (no CometNativeWriteExec), but found native write " +
+            s"in plan:\n${plan.treeString}")
+
+        // Verify the data was written correctly
+        val result = spark.read.parquet(outputPath).collect()
+        assert(result.length == 3, "Expected 3 rows to be written")
+      }
+    }
+  }
+
   test("parquet write complex types fuzz test") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
@@ -347,18 +394,21 @@ class CometParquetWriterSuite extends CometTestBase {
     inputPath
   }
 
-  private def writeWithCometNativeWriteExec(
-      inputPath: String,
-      outputPath: String,
-      num_partitions: Option[Int] = None): Option[QueryExecution] = {
-    val df = spark.read.parquet(inputPath)
-
-    // Use a listener to capture the execution plan during write
+  /**
+   * Captures the execution plan produced by a write operation.
+   *
+   * @param writeOp
+   *   The write operation to execute (takes the output path as its parameter)
+   * @param outputPath
+   *   The path to write to
+   * @return
+   *   The captured execution plan, with any AQE wrapper stripped
+   */
+  private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
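+    // The plan arrives asynchronously via a QueryExecutionListener; the polling loop
+    // below waits (with a timeout) until it has been captured.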
     var capturedPlan: Option[QueryExecution] = None
 
     val listener = new org.apache.spark.sql.util.QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        // Capture plans from write operations
         if (funcName == "save" || funcName.contains("command")) {
           capturedPlan = Some(qe)
         }
@@ -373,8 +423,7 @@ class CometParquetWriterSuite extends CometTestBase {
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write with optional partitioning
-      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
+      writeOp(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -387,36 +436,45 @@ class CometParquetWriterSuite extends CometTestBase {
         iterations += 1
       }
 
-      // Verify that CometNativeWriteExec was used
      assert(
        capturedPlan.isDefined,
        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
-      capturedPlan.foreach { qe =>
-        val executedPlan = stripAQEPlan(qe.executedPlan)
+      stripAQEPlan(capturedPlan.get.executedPlan)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 
-        // Count CometNativeWriteExec instances in the plan
-        var nativeWriteCount = 0
-        executedPlan.foreach {
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[SparkPlan] = {
+    val df = spark.read.parquet(inputPath)
+
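+    // Perform the write (with optional repartitioning) and capture the executed plan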
+    val plan = captureWritePlan(
+      path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
+      outputPath)
+
+    // Count CometNativeWriteExec instances in the plan
+    var nativeWriteCount = 0
+    plan.foreach {
+      case _: CometNativeWriteExec =>
+        nativeWriteCount += 1
+      case d: DataWritingCommandExec =>
+        d.child.foreach {
           case _: CometNativeWriteExec =>
             nativeWriteCount += 1
-          case d: DataWritingCommandExec =>
-            d.child.foreach {
-              case _: CometNativeWriteExec =>
-                nativeWriteCount += 1
-              case _ =>
-            }
           case _ =>
         }
-
-        assert(
-          nativeWriteCount == 1,
-          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
-      }
-    } finally {
-      spark.listenerManager.unregister(listener)
+      case _ =>
     }
-    capturedPlan
+
+    assert(
+      nativeWriteCount == 1,
+      s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
+
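+    // Wrapped in an Option so existing callers that foreach over the result keep working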
+    Some(plan)
   }
 
   private def verifyWrittenFile(outputPath: String): Unit = {