Skip to content

Commit 00fe67f

Browse files
committed
Add config to disable columnar shuffle for complex types
1 parent dca45ea commit 00fe67f

File tree

3 files changed: +88 −30 lines

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,17 @@ object CometConf extends ShimCometConf {
376376
.intConf
377377
.createWithDefault(1)
378378

379+
val COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED: ConfigEntry[Boolean] =
380+
conf("spark.comet.columnar.shuffle.complexTypes.enabled")
381+
.category(CATEGORY_SHUFFLE)
382+
.doc(
383+
"Whether to enable Comet columnar shuffle for complex types (struct, array, map). " +
384+
"When disabled (default), queries with complex types will fall back to Spark shuffle " +
385+
"for better performance. Enable this only if you need columnar shuffle features for " +
386+
"complex types and accept potential performance tradeoffs.")
387+
.booleanConf
388+
.createWithDefault(false)
389+
379390
val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] =
380391
conf("spark.comet.columnar.shuffle.async.enabled")
381392
.category(CATEGORY_SHUFFLE)

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
4949
import com.google.common.base.Objects
5050

5151
import org.apache.comet.CometConf
52-
import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
52+
import org.apache.comet.CometConf.{COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED, COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
5353
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
5454
import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
5555
import org.apache.comet.serde.operator.CometSink
@@ -410,13 +410,16 @@ object CometShuffleExchangeExec
410410
_: TimestampNTZType | _: DecimalType | _: DateType =>
411411
true
412412
case StructType(fields) =>
413+
COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf) &&
413414
fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
414415
// Java Arrow stream reader cannot work on duplicate field name
415416
fields.map(f => f.name).distinct.length == fields.length &&
416417
fields.nonEmpty
417418
case ArrayType(elementType, _) =>
419+
COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf) &&
418420
supportedSerializableDataType(elementType)
419421
case MapType(keyType, valueType, _) =>
422+
COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf) &&
420423
supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
421424
case _ =>
422425
false

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,42 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
112112
checkSparkAnswer(df)
113113
}
114114

115+
test("Fallback to Spark for complex types when config is disabled (default)") {
116+
// https://github.com/apache/datafusion-comet/issues/2904
117+
// By default, complex types should fall back to Spark shuffle for better performance
118+
withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "false") {
119+
// Test struct type
120+
withParquetTable(Seq((1, (0, "1")), (2, (3, "3"))), "tbl") {
121+
val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2")
122+
// Should have 0 Comet shuffle exchanges since complex types are disabled
123+
checkCometExchange(df, 0, false)
124+
checkSparkAnswer(df)
125+
}
126+
127+
// Test array type
128+
withParquetTable((0 until 10).map(i => (Seq(i, i + 1), i + 1)), "tbl2") {
129+
val df = sql("SELECT * FROM tbl2").repartition(10, $"_1", $"_2")
130+
checkCometExchange(df, 0, false)
131+
checkSparkAnswer(df)
132+
}
133+
134+
// Test map type
135+
withParquetTable((0 until 10).map(i => (Map(i -> i.toString), i + 1)), "tbl3") {
136+
val df = sql("SELECT * FROM tbl3").repartition(10, $"_1", $"_2")
137+
checkCometExchange(df, 0, false)
138+
checkSparkAnswer(df)
139+
}
140+
}
141+
}
142+
115143
test("columnar shuffle on nested struct including nulls") {
116144
// https://github.com/apache/datafusion-comet/issues/1538
117145
assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
118146
Seq(10, 201).foreach { numPartitions =>
119147
Seq("1.0", "10.0").foreach { ratio =>
120-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
148+
withSQLConf(
149+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
150+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
121151
withParquetTable(
122152
(0 until 50).map(i =>
123153
(i, Seq((i + 1, i.toString), null, (i + 3, (i + 3).toString)), i + 1)),
@@ -137,7 +167,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
137167
test("columnar shuffle on struct including nulls") {
138168
Seq(10, 201).foreach { numPartitions =>
139169
Seq("1.0", "10.0").foreach { ratio =>
140-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
170+
withSQLConf(
171+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
172+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
141173
val data: Seq[(Int, (Int, String))] =
142174
Seq((1, (0, "1")), (2, (3, "3")), (3, null))
143175
withParquetTable(data, "tbl") {
@@ -158,6 +190,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
158190
Seq(10, 201).foreach { numPartitions =>
159191
Seq("1.0", "10.0").foreach { ratio =>
160192
withSQLConf(
193+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
161194
CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
162195
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
163196
withParquetTable((0 until 50).map(i => (Map(Seq(i, i + 1) -> i), i + 1)), "tbl") {
@@ -230,6 +263,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
230263
Seq(10, 201).foreach { numPartitions =>
231264
Seq("1.0", "10.0").foreach { ratio =>
232265
withSQLConf(
266+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
233267
CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
234268
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
235269
withParquetTable(
@@ -336,7 +370,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
336370
def columnarShuffleOnMapTest[K: TypeTag](num: Int, keys: Seq[K]): Unit = {
337371
Seq(10, 201).foreach { numPartitions =>
338372
Seq("1.0", "10.0").foreach { ratio =>
339-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
373+
withSQLConf(
374+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
375+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
340376
withParquetTable(genTuples(num, keys), "tbl") {
341377
repartitionAndSort(numPartitions)
342378
}
@@ -451,7 +487,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
451487

452488
Seq(10, 201).foreach { numPartitions =>
453489
Seq("1.0", "10.0").foreach { ratio =>
454-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
490+
withSQLConf(
491+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
492+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
455493
withParquetTable(
456494
(0 until 50).map(i =>
457495
(
@@ -483,7 +521,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
483521
Seq("false", "true").foreach { _ =>
484522
Seq(10, 201).foreach { numPartitions =>
485523
Seq("1.0", "10.0").foreach { ratio =>
486-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
524+
withSQLConf(
525+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
526+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
487527
withParquetTable(
488528
(0 until 50).map(i => (Seq(Seq(i + 1), Seq(i + 2), Seq(i + 3)), i + 1)),
489529
"tbl") {
@@ -503,7 +543,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
503543
test("columnar shuffle on nested struct") {
504544
Seq(10, 201).foreach { numPartitions =>
505545
Seq("1.0", "10.0").foreach { ratio =>
506-
withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
546+
withSQLConf(
547+
CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
548+
CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
507549
withParquetTable(
508550
(0 until 50).map(i =>
509551
((i, 2.toString, (i + 1).toLong, (3.toString, i + 1, (i + 2).toLong)), i + 1)),
@@ -871,29 +913,31 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
871913
}
872914

873915
test("columnar shuffle on null struct fields") {
874-
withTempDir { dir =>
875-
val testData = "{}\n"
876-
val path = Paths.get(dir.toString, "test.json")
877-
Files.write(path, testData.getBytes)
878-
879-
// Define the nested struct schema
880-
val readSchema = StructType(
881-
Array(
882-
StructField(
883-
"metaData",
884-
StructType(
885-
Array(StructField(
886-
"format",
887-
StructType(Array(StructField("provider", StringType, nullable = true))),
888-
nullable = true))),
889-
nullable = true)))
890-
891-
// Read JSON with custom schema and repartition, this will repartition rows that contain
892-
// null struct fields.
893-
val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
894-
assert(df.count() == 1)
895-
val row = df.collect()(0)
896-
assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
916+
withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true") {
917+
withTempDir { dir =>
918+
val testData = "{}\n"
919+
val path = Paths.get(dir.toString, "test.json")
920+
Files.write(path, testData.getBytes)
921+
922+
// Define the nested struct schema
923+
val readSchema = StructType(
924+
Array(
925+
StructField(
926+
"metaData",
927+
StructType(
928+
Array(StructField(
929+
"format",
930+
StructType(Array(StructField("provider", StringType, nullable = true))),
931+
nullable = true))),
932+
nullable = true)))
933+
934+
// Read JSON with custom schema and repartition, this will repartition rows that contain
935+
// null struct fields.
936+
val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
937+
assert(df.count() == 1)
938+
val row = df.collect()(0)
939+
assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
940+
}
897941
}
898942
}
899943

0 commit comments

Comments (0)