1919
2020package org .apache .comet .exec
2121
22- import java .nio .file .Files
23- import java .nio .file .Paths
22+ import java .nio .file .{Files , Paths }
2423
2524import scala .reflect .runtime .universe ._
2625import scala .util .Random
@@ -168,8 +167,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
168167 .repartition(numPartitions, $" _1" , $" _2" )
169168 .sortWithinPartitions($" _2" )
170169
171- // Array map key array element fallback to Spark shuffle for now
172- checkShuffleAnswer(df, 0 )
170+ if (isSpark40Plus) {
171+ // https://github.com/apache/datafusion-comet/issues/1941
172+ // Spark 4.0 introduces a mapsort, which falls back to Spark shuffle
173+ checkShuffleAnswer(df, 0 )
174+ } else {
175+ checkShuffleAnswer(df, 1 )
176+ }
173177 }
174178
175179 withParquetTable((0 until 50 ).map(i => (Map (i -> Seq (i, i + 1 )), i + 1 )), " tbl" ) {
@@ -178,8 +182,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
178182 .repartition(numPartitions, $" _1" , $" _2" )
179183 .sortWithinPartitions($" _2" )
180184
181- // Array map value array element fallback to Spark shuffle for now
182- checkShuffleAnswer(df, 0 )
185+ if (isSpark40Plus) {
186+ // https://github.com/apache/datafusion-comet/issues/1941
187+ // Spark 4.0 introduces a mapsort, which falls back to Spark shuffle
188+ checkShuffleAnswer(df, 0 )
189+ } else {
190+ checkShuffleAnswer(df, 1 )
191+ }
183192 }
184193
185194 withParquetTable((0 until 50 ).map(i => (Map ((i, i.toString) -> i), i + 1 )), " tbl" ) {
@@ -188,8 +197,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
188197 .repartition(numPartitions, $" _1" , $" _2" )
189198 .sortWithinPartitions($" _2" )
190199
191- // Struct map key array element fallback to Spark shuffle for now
192- checkShuffleAnswer(df, 0 )
200+ if (isSpark40Plus) {
201+ // https://github.com/apache/datafusion-comet/issues/1941
202+ // Spark 4.0 introduces a mapsort, which falls back to Spark shuffle
203+ checkShuffleAnswer(df, 0 )
204+ } else {
205+ checkShuffleAnswer(df, 1 )
206+ }
193207 }
194208
195209 withParquetTable((0 until 50 ).map(i => (Map (i -> (i, i.toString)), i + 1 )), " tbl" ) {
@@ -198,8 +212,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
198212 .repartition(numPartitions, $" _1" , $" _2" )
199213 .sortWithinPartitions($" _2" )
200214
201- // Struct map value array element fallback to Spark shuffle for now
202- checkShuffleAnswer(df, 0 )
215+ if (isSpark40Plus) {
216+ // https://github.com/apache/datafusion-comet/issues/1941
217+ // Spark 4.0 introduces a mapsort, which falls back to Spark shuffle
218+ checkShuffleAnswer(df, 0 )
219+ } else {
220+ checkShuffleAnswer(df, 1 )
221+ }
203222 }
204223 }
205224 }
@@ -222,8 +241,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
222241 .repartition(numPartitions, $" _1" , $" _2" )
223242 .sortWithinPartitions($" _2" )
224243
225- // Map array element fallback to Spark shuffle for now
226- checkShuffleAnswer(df, 0 )
244+ if (isSpark40Plus) {
245+ // https://github.com/apache/datafusion-comet/issues/1941
246+ // Spark 4.0 introduces a mapsort, which falls back to Spark shuffle
247+ checkShuffleAnswer(df, 0 )
248+ } else {
249+ checkShuffleAnswer(df, 1 )
250+ }
227251 }
228252 }
229253 }
@@ -469,8 +493,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
469493 .repartition(numPartitions, $" _1" , $" _2" )
470494 .sortWithinPartitions($" _1" )
471495
472- // Nested array fallback to Spark shuffle for now
473- checkShuffleAnswer(df, 0 )
496+ checkShuffleAnswer(df, 1 )
474497 }
475498 }
476499 }
@@ -983,6 +1006,7 @@ class DisableAQECometAsyncShuffleSuite extends CometColumnarShuffleSuite {
9831006}
9841007
9851008class CometShuffleEncryptionSuite extends CometTestBase {
1009+
9861010 import testImplicits ._
9871011
9881012 override protected def sparkConf : SparkConf = {
@@ -1034,6 +1058,7 @@ class CometShuffleManagerSuite extends CometTestBase {
10341058 shuffleWriterProcessor = null ,
10351059 partitioner = new Partitioner {
10361060 override def numPartitions : Int = 50
1061+
10371062 override def getPartition (key : Any ): Int = key.asInstanceOf [Int ]
10381063 },
10391064 decodeTime = null )
0 commit comments