
Commit f40a396

fix: Prevent native write when input is not Arrow format (#3227)
1 parent: deaec6f

5 files changed (+126 / -36 lines)

common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 8 additions & 2 deletions
@@ -269,8 +269,14 @@ object Utils extends CometTypeShim {
 
       case c =>
         throw new SparkException(
-          "Comet execution only takes Arrow Arrays, but got " +
-            s"${c.getClass}")
+          s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
+            "This typically happens when a Comet scan falls back to Spark due to unsupported " +
+            "data types (e.g., complex types like structs, arrays, or maps with native_comet). " +
+            "To resolve this, you can: " +
+            "(1) enable spark.comet.scan.allowIncompatible=true to use a compatible native " +
+            "scan variant, or " +
+            "(2) enable spark.comet.convert.parquet.enabled=true to convert Spark Parquet " +
+            "data to Arrow format automatically.")
     }
   }
   (fieldVectors, provider)
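
Note: the two workarounds named in the new error message are ordinary session configs. A minimal sketch of enabling them, assuming an existing SparkSession named `spark` (not part of this patch):

    // Hedged sketch: set the configs referenced by the error message above.
    // (1) Allow a native scan variant that is marked incompatible but produces Arrow data.
    spark.conf.set("spark.comet.scan.allowIncompatible", "true")
    // (2) Or convert Spark-produced Parquet data to Arrow automatically.
    spark.conf.set("spark.comet.convert.parquet.enabled", "true")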

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 15 additions & 0 deletions
@@ -482,6 +482,21 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
     val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
     if (isOperatorEnabled(serde, op)) {
+      // For operators that require native children (like writes), check if all data-producing
+      // children are CometNativeExec. This prevents runtime failures when the native operator
+      // expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
+      if (serde.requiresNativeChildren && op.children.nonEmpty) {
+        // Get the actual data-producing children (unwrap WriteFilesExec if present)
+        val dataProducingChildren = op.children.flatMap {
+          case writeFiles: WriteFilesExec => Seq(writeFiles.child)
+          case other => Seq(other)
+        }
+        if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
+          withInfo(op, "Cannot perform native operation because input is not in Arrow format")
+          return None
+        }
+      }
+
       val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
       if (op.children.nonEmpty && op.children.forall(_.isInstanceOf[CometNativeExec])) {
         val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
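
Read in isolation, the added guard amounts to a small predicate over the plan: unwrap WriteFilesExec to reach the data-producing child, then require every such child to be a Comet native operator. A standalone sketch of that predicate (helper name hypothetical, package locations assumed; in the patch the check lives inline in convertToComet as shown above):

    import org.apache.spark.sql.comet.CometNativeExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.datasources.WriteFilesExec

    // Hypothetical helper illustrating the check added to convertToComet.
    def allChildrenProduceArrow(op: SparkPlan): Boolean = {
      // Writes wrap the query under WriteFilesExec, so look through it to the real child.
      val dataProducingChildren = op.children.map {
        case writeFiles: WriteFilesExec => writeFiles.child
        case other => other
      }
      // Only Comet native operators emit Arrow arrays; anything else would fail at runtime.
      dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])
    }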

spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,13 @@ trait CometOperatorSerde[T <: SparkPlan] {
    */
   def enabledConfig: Option[ConfigEntry[Boolean]]
 
+  /**
+   * Indicates whether this operator requires all of its children to be CometNativeExec. If true
+   * and any child is not a native exec, conversion will be skipped and the operator will fall
+   * back to Spark. This is useful for operators like writes that require Arrow-formatted input.
+   */
+  def requiresNativeChildren: Boolean = false
+
   /**
    * Determine the support level of the operator based on its attributes.
    *

spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,10 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
   override def enabledConfig: Option[ConfigEntry[Boolean]] =
     Some(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED)
 
+  // Native writes require Arrow-formatted input data. If the scan falls back to Spark
+  // (e.g., due to unsupported complex types), the write must also fall back.
+  override def requiresNativeChildren: Boolean = true
+
   override def getSupportLevel(op: DataWritingCommandExec): SupportLevel = {
     op.cmd match {
       case cmd: InsertIntoHadoopFsRelationCommand =>

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 92 additions & 34 deletions
@@ -75,16 +75,15 @@ class CometParquetWriterSuite extends CometTestBase {
 
     withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
       val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
-      capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeScan = executedPlan.exists {
+      capturedPlan.foreach { plan =>
+        val hasNativeScan = plan.exists {
           case _: CometNativeScanExec => true
           case _ => false
         }
 
         assert(
           hasNativeScan,
-          s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
+          s"Expected CometNativeScanExec in the plan, but got:\n${plan.treeString}")
       }
 
       verifyWrittenFile(outputPath)
@@ -311,6 +310,54 @@
     }
   }
 
+  test("native write falls back when scan produces non-Arrow data") {
+    // This test verifies that when a native scan (like native_comet) doesn't support
+    // certain data types (complex types), the native write correctly falls back to Spark
+    // instead of failing at runtime with "Comet execution only takes Arrow Arrays" error.
+    withTempPath { dir =>
+      val inputPath = new File(dir, "input.parquet").getAbsolutePath
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      // Create data with complex types and write without Comet
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val df = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq(6, 7, 8, 9)))
+          .toDF("id", "values")
+        df.write.parquet(inputPath)
+      }
+
+      // With native Parquet write enabled but using native_comet scan which doesn't
+      // support complex types, the scan falls back to Spark. The native write should
+      // detect this and also fall back to Spark instead of failing at runtime.
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        // Use native_comet which doesn't support complex types
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_comet") {
+
+        val plan =
+          captureWritePlan(path => spark.read.parquet(inputPath).write.parquet(path), outputPath)
+
+        // Verify NO CometNativeWriteExec in the plan (should have fallen back to Spark)
+        val hasNativeWrite = plan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists(_.isInstanceOf[CometNativeWriteExec])
+          case _ => false
+        }
+
+        assert(
+          !hasNativeWrite,
+          "Expected fallback to Spark write (no CometNativeWriteExec), but found native write " +
+            s"in plan:\n${plan.treeString}")
+
+        // Verify the data was written correctly
+        val result = spark.read.parquet(outputPath).collect()
+        assert(result.length == 3, "Expected 3 rows to be written")
+      }
+    }
+  }
+
   test("parquet write complex types fuzz test") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
@@ -347,18 +394,21 @@
     inputPath
   }
 
-  private def writeWithCometNativeWriteExec(
-      inputPath: String,
-      outputPath: String,
-      num_partitions: Option[Int] = None): Option[QueryExecution] = {
-    val df = spark.read.parquet(inputPath)
-
-    // Use a listener to capture the execution plan during write
+  /**
+   * Captures the execution plan during a write operation.
+   *
+   * @param writeOp
+   *   The write operation to execute (takes output path as parameter)
+   * @param outputPath
+   *   The path to write to
+   * @return
+   *   The captured execution plan
+   */
+  private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
     var capturedPlan: Option[QueryExecution] = None
 
     val listener = new org.apache.spark.sql.util.QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        // Capture plans from write operations
         if (funcName == "save" || funcName.contains("command")) {
           capturedPlan = Some(qe)
         }
@@ -373,8 +423,7 @@ class CometParquetWriterSuite extends CometTestBase {
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write with optional partitioning
-      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
+      writeOp(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -387,36 +436,45 @@
         iterations += 1
       }
 
-      // Verify that CometNativeWriteExec was used
       assert(
         capturedPlan.isDefined,
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
-      capturedPlan.foreach { qe =>
-        val executedPlan = stripAQEPlan(qe.executedPlan)
+      stripAQEPlan(capturedPlan.get.executedPlan)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 
-        // Count CometNativeWriteExec instances in the plan
-        var nativeWriteCount = 0
-        executedPlan.foreach {
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[SparkPlan] = {
+    val df = spark.read.parquet(inputPath)
+
+    val plan = captureWritePlan(
+      path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
+      outputPath)
+
+    // Count CometNativeWriteExec instances in the plan
+    var nativeWriteCount = 0
+    plan.foreach {
+      case _: CometNativeWriteExec =>
+        nativeWriteCount += 1
+      case d: DataWritingCommandExec =>
+        d.child.foreach {
           case _: CometNativeWriteExec =>
             nativeWriteCount += 1
-          case d: DataWritingCommandExec =>
-            d.child.foreach {
-              case _: CometNativeWriteExec =>
-                nativeWriteCount += 1
-              case _ =>
-            }
           case _ =>
         }
-
-        assert(
-          nativeWriteCount == 1,
-          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
-      }
-    } finally {
-      spark.listenerManager.unregister(listener)
+      case _ =>
     }
-    capturedPlan
+
+    assert(
+      nativeWriteCount == 1,
+      s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
+
+    Some(plan)
   }
 
   private def verifyWrittenFile(outputPath: String): Unit = {
