16 | 16 | */ |
17 | 17 | package org.apache.spark.sql |
18 | 18 |
| 19 | +import org.apache.spark.SparkException |
19 | 20 | import org.apache.spark.sql.functions._ |
| 21 | +import org.apache.spark.sql.internal.SQLConf |
| 22 | +import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType} |
20 | 23 |
21 | 24 | class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait { |
22 | 25 | import testImplicits._ |
@@ -49,4 +52,230 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS |
49 | 52 | false |
50 | 53 | ) |
51 | 54 | } |
| 55 | + |
| 56 | + testGluten("map_concat function") { |
| 57 | + val df1 = Seq( |
| 58 | + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), |
| 59 | + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), |
| 60 | + (null, Map[Int, Int](3 -> 300, 4 -> 400)) |
| 61 | + ).toDF("map1", "map2") |
| 62 | + |
| 63 | + val expected1a = Seq( |
| 64 | + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), |
| 65 | + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), |
| 66 | + Row(null) |
| 67 | + ) |
| 68 | + |
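| | + // The second row has key 1 in both maps: under the default EXCEPTION dedup policy map_concat throws, while LAST_WIN keeps the value from the last map.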
| 69 | + intercept[SparkException](df1.selectExpr("map_concat(map1, map2)").collect()) |
| 70 | + intercept[SparkException](df1.select(map_concat($"map1", $"map2")).collect()) |
| 71 | + withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { |
| 72 | + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) |
| 73 | + checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a) |
| 74 | + } |
| 75 | + |
| 76 | + val expected1b = Seq( |
| 77 | + Row(Map(1 -> 100, 2 -> 200)), |
| 78 | + Row(Map(1 -> 100, 2 -> 200)), |
| 79 | + Row(null) |
| 80 | + ) |
| 81 | + |
| 82 | + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) |
| 83 | + checkAnswer(df1.select(map_concat($"map1")), expected1b) |
| 84 | + |
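| | + // df2 pairs maps with incompatible key types (ARRAY<INT> vs STRING); it only feeds the zero-argument map_concat call and the DATATYPE_MISMATCH error checks below.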
| 85 | + val df2 = Seq( |
| 86 | + ( |
| 87 | + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), |
| 88 | + Map[String, Int]("3" -> 300, "4" -> 400) |
| 89 | + ) |
| 90 | + ).toDF("map1", "map2") |
| 91 | + |
| 92 | + val expected2 = Seq(Row(Map())) |
| 93 | + |
| 94 | + checkAnswer(df2.selectExpr("map_concat()"), expected2) |
| 95 | + checkAnswer(df2.select(map_concat()), expected2) |
| 96 | + |
| 97 | + val df3 = { |
| 98 | + val schema = StructType( |
| 99 | + StructField("map1", MapType(StringType, IntegerType, true), false) :: |
| 100 | + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil |
| 101 | + ) |
| 102 | + val data = Seq( |
| 103 | + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), |
| 104 | + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) |
| 105 | + ) |
| 106 | + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) |
| 107 | + } |
| 108 | + |
| 109 | + val expected3 = Seq( |
| 110 | + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), |
| 111 | + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) |
| 112 | + ) |
| 113 | + |
| 114 | + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) |
| 115 | + checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3) |
| 116 | + |
| 117 | + checkError( |
| 118 | + exception = intercept[AnalysisException] { |
| 119 | + df2.selectExpr("map_concat(map1, map2)").collect() |
| 120 | + }, |
| 121 | + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", |
| 122 | + sqlState = None, |
| 123 | + parameters = Map( |
| 124 | + "sqlExpr" -> "\"map_concat(map1, map2)\"", |
| 125 | + "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")", |
| 126 | + "functionName" -> "`map_concat`"), |
| 127 | + context = ExpectedContext(fragment = "map_concat(map1, map2)", start = 0, stop = 21) |
| 128 | + ) |
| 129 | + |
| 130 | + checkError( |
| 131 | + exception = intercept[AnalysisException] { |
| 132 | + df2.select(map_concat($"map1", $"map2")).collect() |
| 133 | + }, |
| 134 | + condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", |
| 135 | + sqlState = None, |
| 136 | + parameters = Map( |
| 137 | + "sqlExpr" -> "\"map_concat(map1, map2)\"", |
| 138 | + "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")", |
| 139 | + "functionName" -> "`map_concat`"), |
| 140 | + context = |
| 141 | + ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern) |
| 142 | + ) |
| 143 | + |
| 144 | + checkError( |
| 145 | + exception = intercept[AnalysisException] { |
| 146 | + df2.selectExpr("map_concat(map1, 12)").collect() |
| 147 | + }, |
| 148 | + condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", |
| 149 | + sqlState = None, |
| 150 | + parameters = Map( |
| 151 | + "sqlExpr" -> "\"map_concat(map1, 12)\"", |
| 152 | + "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]", |
| 153 | + "functionName" -> "`map_concat`"), |
| 154 | + context = ExpectedContext(fragment = "map_concat(map1, 12)", start = 0, stop = 19) |
| 155 | + ) |
| 156 | + |
| 157 | + checkError( |
| 158 | + exception = intercept[AnalysisException] { |
| 159 | + df2.select(map_concat($"map1", lit(12))).collect() |
| 160 | + }, |
| 161 | + condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", |
| 162 | + sqlState = None, |
| 163 | + parameters = Map( |
| 164 | + "sqlExpr" -> "\"map_concat(map1, 12)\"", |
| 165 | + "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]", |
| 166 | + "functionName" -> "`map_concat`"), |
| 167 | + context = |
| 168 | + ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern) |
| 169 | + ) |
| 170 | + } |
| 171 | + |
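| | + // transform_keys rewrites each key with the lambda and keeps the values; the maps below cover int, double, boolean and array element types.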
| 172 | + testGluten("transform keys function - primitive data types") { |
| 173 | + val dfExample1 = Seq( |
| 174 | + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) |
| 175 | + ).toDF("i") |
| 176 | + |
| 177 | + val dfExample2 = Seq( |
| 178 | + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) |
| 179 | + ).toDF("j") |
| 180 | + |
| 181 | + val dfExample3 = Seq( |
| 182 | + Map[Int, Boolean](25 -> true, 26 -> false) |
| 183 | + ).toDF("x") |
| 184 | + |
| 185 | + val dfExample4 = Seq( |
| 186 | + Map[Array[Int], Boolean](Array(1, 2) -> false) |
| 187 | + ).toDF("y") |
| 188 | + |
| 189 | + def testMapOfPrimitiveTypesCombination(): Unit = { |
| 190 | + checkAnswer( |
| 191 | + dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), |
| 192 | + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) |
| 193 | + |
| 194 | + checkAnswer( |
| 195 | + dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), |
| 196 | + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) |
| 197 | + |
| 198 | + checkAnswer( |
| 199 | + dfExample2.selectExpr( |
| 200 | + "transform_keys(j, " + |
| 201 | + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), |
| 202 | + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))) |
| 203 | + ) |
| 204 | + |
| 205 | + checkAnswer( |
| 206 | + dfExample2.select( |
| 207 | + transform_keys( |
| 208 | + col("j"), |
| 209 | + (k, v) => |
| 210 | + element_at( |
| 211 | + map_from_arrays( |
| 212 | + array(lit(1), lit(2), lit(3)), |
| 213 | + array(lit("one"), lit("two"), lit("three")) |
| 214 | + ), |
| 215 | + k |
| 216 | + ) |
| 217 | + ) |
| 218 | + ), |
| 219 | + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))) |
| 220 | + ) |
| 221 | + |
| 222 | + checkAnswer( |
| 223 | + dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), |
| 224 | + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) |
| 225 | + |
| 226 | + checkAnswer( |
| 227 | + dfExample2.select(transform_keys(col("j"), (k, v) => (v * 2).cast("bigint") + k)), |
| 228 | + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) |
| 229 | + |
| 230 | + checkAnswer( |
| 231 | + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), |
| 232 | + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) |
| 233 | + |
| 234 | + checkAnswer( |
| 235 | + dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), |
| 236 | + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) |
| 237 | + |
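| | + // The lambda turns both keys (25 and 26) into the same boolean key, so the default EXCEPTION dedup policy throws; LAST_WIN deduplicates instead.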
| 238 | + intercept[SparkException] { |
| 239 | + dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)").collect() |
| 240 | + } |
| 241 | + intercept[SparkException] { |
| 242 | + dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)).collect() |
| 243 | + } |
| 244 | + withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { |
| 245 | + checkAnswer( |
| 246 | + dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), |
| 247 | + Seq(Row(Map(true -> true, true -> false)))) |
| 248 | + |
| 249 | + checkAnswer( |
| 250 | + dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), |
| 251 | + Seq(Row(Map(true -> true, true -> false)))) |
| 252 | + } |
| 253 | + |
| 254 | + checkAnswer( |
| 255 | + dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), |
| 256 | + Seq(Row(Map(50 -> true, 78 -> false)))) |
| 257 | + |
| 258 | + checkAnswer( |
| 259 | + dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k * 2).otherwise(k * 3))), |
| 260 | + Seq(Row(Map(50 -> true, 78 -> false)))) |
| 261 | + |
| 262 | + checkAnswer( |
| 263 | + dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), |
| 264 | + Seq(Row(Map(false -> false)))) |
| 265 | + |
| 266 | + checkAnswer( |
| 267 | + dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k, lit(3)) && v)), |
| 268 | + Seq(Row(Map(false -> false)))) |
| 269 | + } |
| 270 | + |
| 271 | + // Test with local relation, the Project will be evaluated without codegen |
| 272 | + testMapOfPrimitiveTypesCombination() |
| 273 | + dfExample1.cache() |
| 274 | + dfExample2.cache() |
| 275 | + dfExample3.cache() |
| 276 | + dfExample4.cache() |
| 277 | + // Test with cached relation, the Project will be evaluated with codegen |
| 278 | + testMapOfPrimitiveTypesCombination() |
| 279 | + } |
| 280 | + |
52 | 281 | } |