Merged
@@ -269,8 +269,14 @@ object Utils extends CometTypeShim {
 
         case c =>
           throw new SparkException(
-            "Comet execution only takes Arrow Arrays, but got " +
-              s"${c.getClass}")
+            s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
+              "This typically happens when a Comet scan falls back to Spark due to unsupported " +
+              "data types (e.g., complex types like structs, arrays, or maps with native_comet). " +
+              "To resolve this, you can: " +
+              "(1) enable spark.comet.scan.allowIncompatible=true to use a compatible native " +
+              "scan variant, or " +
+              "(2) enable spark.comet.convert.parquet.enabled=true to convert Spark Parquet " +
+              "data to Arrow format automatically.")
       }
     }
     (fieldVectors, provider)
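The two workarounds named in the new error message can be set per session. A minimal sketch, assuming a stock SparkSession with Comet on the classpath; only the two config keys are taken verbatim from the message above, everything else is illustrative:

import org.apache.spark.sql.SparkSession

// Illustrative session setup; the two keys below come from the new error message.
val spark = SparkSession.builder().appName("comet-fallback-workarounds").getOrCreate()

// Workaround (1): allow the native scan variant that accepts complex types.
spark.conf.set("spark.comet.scan.allowIncompatible", "true")

// Workaround (2): have Comet convert Spark Parquet scans to Arrow automatically.
spark.conf.set("spark.comet.convert.parquet.enabled", "true")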
15 changes: 14 additions & 1 deletion spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -207,7 +207,20 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         w.child
 
       case op: DataWritingCommandExec =>
-        convertToComet(op, CometDataWritingCommand).getOrElse(op)
+        // Get the actual data source child that will feed data to the native writer
Contributor:
Would it be more generic to do this in CometExecRule.convertToComet?

if (op.children.nonEmpty && !op.children.forall(_.isInstanceOf[CometNativeExec])) {
  // For operators like writes, require all children to be native
  if (requiresNativeChildren(handler)) {
    return None // Fallback to Spark
  }
  ...
}

Member (author):

Yes, thanks. I updated this.

+        val dataSourceChild = op.child match {
+          case writeFiles: WriteFilesExec => writeFiles.child
+          case other => other
+        }
+        // Only convert to native write if the data source produces Arrow data.
+        // If it's a Spark scan (not CometNativeExec), the native writer will fail at runtime
+        // because it expects Arrow arrays but will receive OnHeapColumnVector.
+        if (dataSourceChild.isInstanceOf[CometNativeExec]) {
+          convertToComet(op, CometDataWritingCommand).getOrElse(op)
+        } else {
+          withInfo(op, "Cannot perform native write because input is not in Arrow format")
+          op
+        }
 
       // For AQE broadcast stage on a Comet broadcast exchange
       case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
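To make the new control flow easy to see in isolation, here is a toy sketch of the same decision logic; the case classes are stand-ins for Comet's real operators (WriteFilesExec, CometNativeExec, DataWritingCommandExec), not the project's API:

// Toy model: NativeLeaf plays CometNativeExec, SparkLeaf a Spark scan,
// WriteFiles the WriteFilesExec wrapper that the rule unwraps first.
sealed trait Plan
case object NativeLeaf extends Plan
case object SparkLeaf extends Plan
case class WriteFiles(child: Plan) extends Plan
case class WriteCommand(child: Plan) extends Plan

// Mirrors the hunk: unwrap WriteFilesExec, then go native only when the data
// source child already produces Arrow data; otherwise keep the Spark operator.
def planWrite(op: WriteCommand): String = {
  val dataSourceChild = op.child match {
    case WriteFiles(inner) => inner
    case other => other
  }
  dataSourceChild match {
    case NativeLeaf => "CometNativeWriteExec" // input is Arrow, convert
    case _ => "DataWritingCommandExec" // fall back, reason recorded via withInfo
  }
}

// planWrite(WriteCommand(WriteFiles(SparkLeaf))) == "DataWritingCommandExec"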
@@ -75,16 +75,15 @@ class CometParquetWriterSuite extends CometTestBase {
 
     withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
       val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
-      capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeScan = executedPlan.exists {
+      capturedPlan.foreach { plan =>
+        val hasNativeScan = plan.exists {
          case _: CometNativeScanExec => true
          case _ => false
        }
 
        assert(
          hasNativeScan,
-          s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
+          s"Expected CometNativeScanExec in the plan, but got:\n${plan.treeString}")
      }
 
      verifyWrittenFile(outputPath)
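The refactor drops the intermediate QueryExecution and asserts directly on the plan tree with exists. As a hedged aside, the same traversal pattern on a plain Spark plan (the Range predicate is purely illustrative; the suite matches on CometNativeScanExec, a Comet class):

import org.apache.spark.sql.SparkSession

// Sketch: walk an executed plan with the TreeNode `exists` traversal, as the
// updated assertion does; assumes a Spark version that provides `exists`.
val spark = SparkSession.builder().master("local[1]").appName("plan-walk").getOrCreate()
val plan = spark.range(10).queryExecution.executedPlan
val hasRange = plan.exists(_.nodeName.contains("Range")) // illustrative predicate
println(s"found Range node: $hasRange\n${plan.treeString}")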
@@ -311,6 +310,54 @@
     }
   }
 
+  test("native write falls back when scan produces non-Arrow data") {
+    // This test verifies that when a native scan (like native_comet) doesn't support
+    // certain data types (complex types), the native write correctly falls back to Spark
+    // instead of failing at runtime with "Comet execution only takes Arrow Arrays" error.
+    withTempPath { dir =>
+      val inputPath = new File(dir, "input.parquet").getAbsolutePath
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      // Create data with complex types and write without Comet
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val df = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq(6, 7, 8, 9)))
+          .toDF("id", "values")
+        df.write.parquet(inputPath)
+      }
+
+      // With native Parquet write enabled but using native_comet scan which doesn't
+      // support complex types, the scan falls back to Spark. The native write should
+      // detect this and also fall back to Spark instead of failing at runtime.
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        // Use native_comet which doesn't support complex types
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_comet") {
+
+        val plan =
+          captureWritePlan(path => spark.read.parquet(inputPath).write.parquet(path), outputPath)
+
+        // Verify NO CometNativeWriteExec in the plan (should have fallen back to Spark)
+        val hasNativeWrite = plan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists(_.isInstanceOf[CometNativeWriteExec])
+          case _ => false
+        }
+
+        assert(
+          !hasNativeWrite,
+          "Expected fallback to Spark write (no CometNativeWriteExec), but found native write " +
+            s"in plan:\n${plan.treeString}")
+
+        // Verify the data was written correctly
+        val result = spark.read.parquet(outputPath).collect()
+        assert(result.length == 3, "Expected 3 rows to be written")
+      }
+    }
+  }
+
   test("parquet write complex types fuzz test") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
@@ -347,18 +394,21 @@
     inputPath
   }
 
-  private def writeWithCometNativeWriteExec(
-      inputPath: String,
-      outputPath: String,
-      num_partitions: Option[Int] = None): Option[QueryExecution] = {
-    val df = spark.read.parquet(inputPath)
-
-    // Use a listener to capture the execution plan during write
+  /**
+   * Captures the execution plan during a write operation.
+   *
+   * @param writeOp
+   *   The write operation to execute (takes output path as parameter)
+   * @param outputPath
+   *   The path to write to
+   * @return
+   *   The captured execution plan
+   */
+  private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
     var capturedPlan: Option[QueryExecution] = None
 
     val listener = new org.apache.spark.sql.util.QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        // Capture plans from write operations
        if (funcName == "save" || funcName.contains("command")) {
          capturedPlan = Some(qe)
        }
@@ -373,8 +423,7 @@
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write with optional partitioning
-      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
+      writeOp(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -387,36 +436,45 @@
         iterations += 1
       }
 
-      // Verify that CometNativeWriteExec was used
       assert(
         capturedPlan.isDefined,
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
-      capturedPlan.foreach { qe =>
-        val executedPlan = stripAQEPlan(qe.executedPlan)
+      stripAQEPlan(capturedPlan.get.executedPlan)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 
-        // Count CometNativeWriteExec instances in the plan
-        var nativeWriteCount = 0
-        executedPlan.foreach {
-          case _: CometNativeWriteExec =>
-            nativeWriteCount += 1
-          case d: DataWritingCommandExec =>
-            d.child.foreach {
-              case _: CometNativeWriteExec =>
-                nativeWriteCount += 1
-              case _ =>
-            }
-          case _ =>
-        }
-
-        assert(
-          nativeWriteCount == 1,
-          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
-      }
-    } finally {
-      spark.listenerManager.unregister(listener)
-    }
-    capturedPlan
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[SparkPlan] = {
+    val df = spark.read.parquet(inputPath)
+
+    val plan = captureWritePlan(
+      path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
+      outputPath)
+
+    // Count CometNativeWriteExec instances in the plan
+    var nativeWriteCount = 0
+    plan.foreach {
+      case _: CometNativeWriteExec =>
+        nativeWriteCount += 1
+      case d: DataWritingCommandExec =>
+        d.child.foreach {
+          case _: CometNativeWriteExec =>
+            nativeWriteCount += 1
+          case _ =>
+        }
+      case _ =>
+    }
+
+    assert(
+      nativeWriteCount == 1,
+      s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
+
+    Some(plan)
   }
 
   private def verifyWrittenFile(outputPath: String): Unit = {
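Both helpers rest on Spark's QueryExecutionListener, which is worth seeing outside the suite. A stripped-down sketch of the same capture technique, assuming only stock Spark; the path, session setup, and polling constants are illustrative, while the funcName filter mirrors the suite's:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object CaptureWritePlanExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("capture-example").getOrCreate()

    @volatile var captured: Option[QueryExecution] = None
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        // Write operations surface as "save" or command-style function names.
        if (funcName == "save" || funcName.contains("command")) captured = Some(qe)
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
    }

    spark.listenerManager.register(listener)
    try {
      spark.range(100).write.mode("overwrite").parquet("/tmp/capture-example")
      // The listener fires asynchronously, so poll briefly, as the suite does.
      val deadline = System.currentTimeMillis() + 15000
      while (captured.isEmpty && System.currentTimeMillis() < deadline) Thread.sleep(100)
      captured.foreach(qe => println(qe.executedPlan.treeString))
    } finally {
      spark.listenerManager.unregister(listener)
      spark.stop()
    }
  }
}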