Skip to content

Commit 7236e75

Browse files
chongguangHyukjinKwon
authored andcommitted
[SPARK-24574][SQL] array_contains, array_position, array_remove and element_at functions deal with Column type
## What changes were proposed in this pull request? For the function ```def array_contains(column: Column, value: Any): Column ``` , if we pass the `value` parameter as a Column type, it will yield a runtime exception. This PR proposes a pattern matching to detect if `value` is of type Column. If yes, it will use the .expr of the column, otherwise it will work as it used to. Same thing for ```array_position, array_remove and element_at``` functions ## How was this patch tested? Unit test modified to cover this code change. Ping ueshin Author: Chongguang LIU <[email protected]> Closes apache#21581 from chongguang/SPARK-24574.
1 parent 54fcaaf commit 7236e75

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,7 +3093,7 @@ object functions {
30933093
* @since 1.5.0
30943094
*/
30953095
def array_contains(column: Column, value: Any): Column = withExpr {
3096-
ArrayContains(column.expr, Literal(value))
3096+
ArrayContains(column.expr, lit(value).expr)
30973097
}
30983098

30993099
/**
@@ -3157,7 +3157,7 @@ object functions {
31573157
* @since 2.4.0
31583158
*/
31593159
def array_position(column: Column, value: Any): Column = withExpr {
3160-
ArrayPosition(column.expr, Literal(value))
3160+
ArrayPosition(column.expr, lit(value).expr)
31613161
}
31623162

31633163
/**
@@ -3168,7 +3168,7 @@ object functions {
31683168
* @since 2.4.0
31693169
*/
31703170
def element_at(column: Column, value: Any): Column = withExpr {
3171-
ElementAt(column.expr, Literal(value))
3171+
ElementAt(column.expr, lit(value).expr)
31723172
}
31733173

31743174
/**
@@ -3186,7 +3186,7 @@ object functions {
31863186
* @since 2.4.0
31873187
*/
31883188
def array_remove(column: Column, element: Any): Column = withExpr {
3189-
ArrayRemove(column.expr, Literal(element))
3189+
ArrayRemove(column.expr, lit(element).expr)
31903190
}
31913191

31923192
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -635,9 +635,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
635635

636636
test("array contains function") {
637637
val df = Seq(
638-
(Seq[Int](1, 2), "x"),
639-
(Seq[Int](), "x")
640-
).toDF("a", "b")
638+
(Seq[Int](1, 2), "x", 1),
639+
(Seq[Int](), "x", 1)
640+
).toDF("a", "b", "c")
641641

642642
// Simple test cases
643643
checkAnswer(
@@ -648,6 +648,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
648648
df.selectExpr("array_contains(a, 1)"),
649649
Seq(Row(true), Row(false))
650650
)
651+
checkAnswer(
652+
df.select(array_contains(df("a"), df("c"))),
653+
Seq(Row(true), Row(false))
654+
)
655+
checkAnswer(
656+
df.selectExpr("array_contains(a, c)"),
657+
Seq(Row(true), Row(false))
658+
)
651659

652660
// In hive, this errors because null has no type information
653661
intercept[AnalysisException] {
@@ -862,9 +870,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
862870

863871
test("array position function") {
864872
val df = Seq(
865-
(Seq[Int](1, 2), "x"),
866-
(Seq[Int](), "x")
867-
).toDF("a", "b")
873+
(Seq[Int](1, 2), "x", 1),
874+
(Seq[Int](), "x", 1)
875+
).toDF("a", "b", "c")
868876

869877
checkAnswer(
870878
df.select(array_position(df("a"), 1)),
@@ -874,7 +882,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
874882
df.selectExpr("array_position(a, 1)"),
875883
Seq(Row(1L), Row(0L))
876884
)
877-
885+
checkAnswer(
886+
df.selectExpr("array_position(a, c)"),
887+
Seq(Row(1L), Row(0L))
888+
)
889+
checkAnswer(
890+
df.select(array_position(df("a"), df("c"))),
891+
Seq(Row(1L), Row(0L))
892+
)
878893
checkAnswer(
879894
df.select(array_position(df("a"), null)),
880895
Seq(Row(null), Row(null))
@@ -901,10 +916,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
901916

902917
test("element_at function") {
903918
val df = Seq(
904-
(Seq[String]("1", "2", "3")),
905-
(Seq[String](null, "")),
906-
(Seq[String]())
907-
).toDF("a")
919+
(Seq[String]("1", "2", "3"), 1),
920+
(Seq[String](null, ""), -1),
921+
(Seq[String](), 2)
922+
).toDF("a", "b")
908923

909924
intercept[Exception] {
910925
checkAnswer(
@@ -922,6 +937,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
922937
df.select(element_at(df("a"), 4)),
923938
Seq(Row(null), Row(null), Row(null))
924939
)
940+
checkAnswer(
941+
df.select(element_at(df("a"), df("b"))),
942+
Seq(Row("1"), Row(""), Row(null))
943+
)
944+
checkAnswer(
945+
df.selectExpr("element_at(a, b)"),
946+
Seq(Row("1"), Row(""), Row(null))
947+
)
925948

926949
checkAnswer(
927950
df.select(element_at(df("a"), 1)),
@@ -1189,10 +1212,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
11891212

11901213
test("array remove") {
11911214
val df = Seq(
1192-
(Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")),
1193-
(Array.empty[Int], Array.empty[String], Array.empty[String]),
1194-
(null, null, null)
1195-
).toDF("a", "b", "c")
1215+
(Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2),
1216+
(Array.empty[Int], Array.empty[String], Array.empty[String], 2),
1217+
(null, null, null, 2)
1218+
).toDF("a", "b", "c", "d")
11961219
checkAnswer(
11971220
df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")),
11981221
Seq(
@@ -1201,6 +1224,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
12011224
Row(null, null, null))
12021225
)
12031226

1227+
checkAnswer(
1228+
df.select(array_remove($"a", $"d")),
1229+
Seq(
1230+
Row(Seq(1, 3)),
1231+
Row(Seq.empty[Int]),
1232+
Row(null))
1233+
)
1234+
1235+
checkAnswer(
1236+
df.selectExpr("array_remove(a, d)"),
1237+
Seq(
1238+
Row(Seq(1, 3)),
1239+
Row(Seq.empty[Int]),
1240+
Row(null))
1241+
)
1242+
12041243
checkAnswer(
12051244
df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")",
12061245
"array_remove(c, \"\")"),

0 commit comments

Comments
 (0)