Commit 8e55381

[GLUTEN-11088] Fix GlutenDataFrameFunctionsSuite in Spark-4.0 (#11195)

https://github.com/apache/spark/blob/29434ea766b0fc3c3bf6eaadb43a8f931133649e/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2928-L2937 Vanilla Spark throws SparkRuntimeException, while Gluten throws SparkException. This patch rewrites the affected tests to match the Gluten behavior.
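In short, the failing queries surface their runtime error through the scheduler as a SparkException under Gluten, rather than as the SparkRuntimeException that the vanilla suite asserts. A minimal sketch of the assertion change (here `df` is a hypothetical stand-in for the suite's test DataFrames, not the exact diff):

    // Sketch of the adaptation; `df` is a hypothetical DataFrame whose
    // map_concat call fails on duplicate keys under the default dedup policy.

    // Vanilla DataFrameFunctionsSuite expects the unwrapped runtime error:
    //   intercept[SparkRuntimeException](df.selectExpr("map_concat(map1, map2)").collect())

    // Under Gluten the same failure arrives as a scheduler-level SparkException,
    // so the rewritten test asserts that type instead:
    intercept[SparkException](df.selectExpr("map_concat(map1, map2)").collect())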
1 parent c9f6d45 commit 8e55381

File tree

2 files changed: +230 -1 lines changed


gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala

Lines changed: 1 addition & 1 deletion

@@ -758,7 +758,7 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("aggregate function - array for non-primitive type")
     // Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
     .exclude("map_zip_with function - map of primitive types")
-    // TODO: fix in Spark-4.0
+    // Vanilla Spark throws SparkRuntimeException, Gluten throws SparkException.
     .exclude("map_concat function")
     .exclude("transform keys function - primitive data types")
   enableSuite[GlutenDataFrameHintSuite]
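For orientation, the exclude call only switches off the inherited vanilla test; the Gluten suite then registers its own adapted copy. A rough sketch of that pattern, assuming the usual Gluten test-utility conventions (the "Gluten - " name prefix is an assumption, not taken from this diff):

    // Rough sketch of the exclude-and-override pattern (assumed convention,
    // not the exact trait code):

    // 1. VeloxTestSettings drops the inherited vanilla test ...
    enableSuite[GlutenDataFrameFunctionsSuite]
      .exclude("map_concat function")

    // 2. ... and the Gluten suite re-registers an adapted copy, typically
    //    under a "Gluten - "-prefixed name, with SparkException assertions:
    class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
      testGluten("map_concat function") {
        // adapted body, as added in the file below
      }
    }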

gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala

Lines changed: 229 additions & 0 deletions

@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql

+import org.apache.spark.SparkException
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType}

 class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
   import testImplicits._
@@ -49,4 +52,230 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("map_concat function") {
+    val df1 = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+      (null, Map[Int, Int](3 -> 300, 4 -> 400))
+    ).toDF("map1", "map2")
+
+    val expected1a = Seq(
+      Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+      Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+      Row(null)
+    )
+
+    intercept[SparkException](df1.selectExpr("map_concat(map1, map2)").collect())
+    intercept[SparkException](df1.select(map_concat($"map1", $"map2")).collect())
+    withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+      checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+      checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+    }
+
+    val expected1b = Seq(
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+    checkAnswer(df1.select(map_concat($"map1")), expected1b)
+
+    val df2 = Seq(
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
+    ).toDF("map1", "map2")
+
+    val expected2 = Seq(Row(Map()))
+
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
+
+    val df3 = {
+      val schema = StructType(
+        StructField("map1", MapType(StringType, IntegerType, true), false) ::
+          StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil
+      )
+      val data = Seq(
+        Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)),
+        Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4))
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val expected3 = Seq(
+      Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+      Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+    )
+
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3)
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, map2)").collect()
+      },
+      condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(fragment = "map_concat(map1, map2)", start = 0, stop = 21)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", $"map2")).collect()
+      },
+      condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`"),
+      context =
+        ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, 12)").collect()
+      },
+      condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(fragment = "map_concat(map1, 12)", start = 0, stop = 19)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", lit(12))).collect()
+      },
+      condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`"),
+      context =
+        ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern)
+    )
+  }
+
+  testGluten("transform keys function - primitive data types") {
+    val dfExample1 = Seq(
+      Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
+    ).toDF("i")
+
+    val dfExample2 = Seq(
+      Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
+    ).toDF("j")
+
+    val dfExample3 = Seq(
+      Map[Int, Boolean](25 -> true, 26 -> false)
+    ).toDF("x")
+
+    val dfExample4 = Seq(
+      Map[Array[Int], Boolean](Array(1, 2) -> false)
+    ).toDF("y")
+
+    def testMapOfPrimitiveTypesCombination(): Unit = {
+      checkAnswer(
+        dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
+        Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+      checkAnswer(
+        dfExample1.select(transform_keys(col("i"), (k, v) => k + v)),
+        Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+      checkAnswer(
+        dfExample2.selectExpr(
+          "transform_keys(j, " +
+            "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
+        Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+      )
+
+      checkAnswer(
+        dfExample2.select(
+          transform_keys(
+            col("j"),
+            (k, v) =>
+              element_at(
+                map_from_arrays(
+                  array(lit(1), lit(2), lit(3)),
+                  array(lit("one"), lit("two"), lit("three"))
+                ),
+                k
+              )
+          )
+        ),
+        Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+      )
+
+      checkAnswer(
+        dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
+        Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.select(transform_keys(col("j"), (k, v) => (v * 2).cast("bigint") + k)),
+        Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
+        Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
+        Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+      intercept[SparkException] {
+        dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)").collect()
+      }
+      intercept[SparkException] {
+        dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)).collect()
+      }
+      withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+        checkAnswer(
+          dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
+          Seq(Row(Map(true -> true, true -> false))))
+
+        checkAnswer(
+          dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
+          Seq(Row(Map(true -> true, true -> false))))
+      }
+
+      checkAnswer(
+        dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
+        Seq(Row(Map(50 -> true, 78 -> false))))
+
+      checkAnswer(
+        dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k * 2).otherwise(k * 3))),
+        Seq(Row(Map(50 -> true, 78 -> false))))
+
+      checkAnswer(
+        dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
+        Seq(Row(Map(false -> false))))
+
+      checkAnswer(
+        dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k, lit(3)) && v)),
+        Seq(Row(Map(false -> false))))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testMapOfPrimitiveTypesCombination()
+    dfExample1.cache()
+    dfExample2.cache()
+    dfExample3.cache()
+    dfExample4.cache()
+    // Test with cached relation, the Project will be evaluated with codegen
+    testMapOfPrimitiveTypesCombination()
+  }
+
 }
