
Commit 46bb2b5

kiszk authored and ueshin committed
[SPARK-23924][SQL] Add element_at function
## What changes were proposed in this pull request?

The PR adds the SQL function `element_at`, whose behavior is based on Presto's function of the same name. If the column is an array, it returns the element of the array at the given (1-based) index; if the column is a map, it returns the value for the given key.

## How was this patch tested?

Added UTs

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#21053 from kiszk/SPARK-23924.
1 parent d5bec48 commit 46bb2b5
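To make the new behavior concrete, here is a minimal usage sketch from the DataFrame API (illustrative only, assuming a `spark-shell` session with `spark.implicits._` in scope; not part of the diff below):

```scala
import org.apache.spark.sql.functions.element_at

val df = Seq(Seq("a", "b", "c")).toDF("data")
df.select(element_at($"data", 1)).show()   // "a":  indices are 1-based
df.select(element_at($"data", -1)).show()  // "c":  negative indices count from the end
df.select(element_at($"data", 4)).show()   // null: out-of-bounds indices return NULL
```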

File tree: 7 files changed, +276 -24 lines changed


python/pyspark/sql/functions.py

Lines changed: 24 additions & 0 deletions
@@ -1862,6 +1862,30 @@ def array_position(col, value):
     return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def element_at(col, extraction):
+    """
+    Collection function: Returns element of array at given index in extraction if col is array.
+    Returns value for the given key in extraction if col is map.
+
+    :param col: name of column containing array or map
+    :param extraction: index to check for in array or key to check for in map
+
+    .. note:: The position is not zero based, but 1 based index.
+
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df.select(element_at(df.data, 1)).collect()
+    [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
+
+    >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
+    >>> df.select(element_at(df.data, "a")).collect()
+    [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
@@ -405,6 +405,7 @@ object FunctionRegistry {
     expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
+    expression[ElementAt]("element_at"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[Size]("size"),
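Registering `ElementAt` here is what exposes the function to the SQL parser; for example (an illustrative snippet, assuming a `SparkSession` named `spark`):

```scala
// Both forms resolve through the FunctionRegistry entry added above.
spark.sql("SELECT element_at(array(1, 2, 3), 2)").show()       // 2
spark.sql("SELECT element_at(map(1, 'a', 2, 'b'), 2)").show()  // b
```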

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 104 additions & 0 deletions
@@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression)
     })
   }
 }
+
+/**
+ * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
+      accesses elements from the last to the first. Returns NULL if the index exceeds the length
+      of the array.
+
+    _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
+       "b"
+  """,
+  since = "2.4.0")
+case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+
+  override def dataType: DataType = left.dataType match {
+    case ArrayType(elementType, _) => elementType
+    case MapType(_, valueType, _) => valueType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    Seq(TypeCollection(ArrayType, MapType),
+      left.dataType match {
+        case _: ArrayType => IntegerType
+        case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+      }
+    )
+  }
+
+  override def nullable: Boolean = true
+
+  override def nullSafeEval(value: Any, ordinal: Any): Any = {
+    left.dataType match {
+      case _: ArrayType =>
+        val array = value.asInstanceOf[ArrayData]
+        val index = ordinal.asInstanceOf[Int]
+        if (array.numElements() < math.abs(index)) {
+          null
+        } else {
+          val idx = if (index == 0) {
+            throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+          } else if (index > 0) {
+            index - 1
+          } else {
+            array.numElements() + index
+          }
+          if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
+            null
+          } else {
+            array.get(idx, dataType)
+          }
+        }
+      case _: MapType =>
+        getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    left.dataType match {
+      case _: ArrayType =>
+        nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+          val index = ctx.freshName("elementAtIndex")
+          val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($eval1.isNullAt($index)) {
+               |  ${ev.isNull} = true;
+               |} else
+             """.stripMargin
+          } else {
+            ""
+          }
+          s"""
+             |int $index = (int) $eval2;
+             |if ($eval1.numElements() < Math.abs($index)) {
+             |  ${ev.isNull} = true;
+             |} else {
+             |  if ($index == 0) {
+             |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
+             |  } else if ($index > 0) {
+             |    $index--;
+             |  } else {
+             |    $index += $eval1.numElements();
+             |  }
+             |  $nullCheck
+             |  {
+             |    ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
+             |  }
+             |}
+           """.stripMargin
+        })
+      case _: MapType =>
+        doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
+    }
+  }
+
+  override def prettyName: String = "element_at"
+}
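Both the interpreted path (`nullSafeEval`) and the generated code apply the same index arithmetic. A standalone sketch of that normalization, as a hypothetical helper for illustration only:

```scala
// Mirrors ElementAt's array branch: indices are 1-based, negative indices
// count back from the end, 0 throws, and out-of-range indices map to NULL.
def resolveIndex(numElements: Int, index: Int): Option[Int] = {
  if (numElements < math.abs(index)) {
    None // out of range => result is NULL
  } else if (index == 0) {
    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
  } else if (index > 0) {
    Some(index - 1) // convert 1-based to 0-based
  } else {
    Some(numElements + index) // e.g. -1 selects the last element
  }
}
```

For a three-element array, `resolveIndex(3, -1)` yields `Some(2)` (the last slot) and `resolveIndex(3, 4)` yields `None`.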

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 40 additions & 24 deletions
@@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 }
 
 /**
- * Returns the value of key `key` in Map `child`.
- *
- * We need to do type checking here as `key` expression maybe unresolved.
+ * Common base class for [[GetMapValue]] and [[ElementAt]].
  */
-case class GetMapValue(child: Expression, key: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
-
-  private def keyType = child.dataType.asInstanceOf[MapType].keyType
-
-  // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
-
-  override def toString: String = s"$child[$key]"
-  override def sql: String = s"${child.sql}[${key.sql}]"
-
-  override def left: Expression = child
-  override def right: Expression = key
-
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
 
+abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
-  protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
+  def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
     val map = value.asInstanceOf[MapData]
     val length = map.numElements()
     val keys = map.keyArray()
@@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression)
     }
   }
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+  def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
     val index = ctx.freshName("index")
     val length = ctx.freshName("length")
     val keys = ctx.freshName("keys")
     val found = ctx.freshName("found")
     val key = ctx.freshName("key")
     val values = ctx.freshName("values")
-    val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) {
+    val keyType = mapType.keyType
+    val nullCheck = if (mapType.valueContainsNull) {
       s" || $values.isNullAt($index)"
     } else {
       ""
@@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression)
     })
   }
 }
+
+/**
+ * Returns the value of key `key` in Map `child`.
+ *
+ * We need to do type checking here as `key` expression maybe unresolved.
+ */
+case class GetMapValue(child: Expression, key: Expression)
+  extends GetMapValueUtil with ExtractValue with NullIntolerant {
+
+  private def keyType = child.dataType.asInstanceOf[MapType].keyType
+
+  // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
+
+  override def toString: String = s"$child[$key]"
+  override def sql: String = s"${child.sql}[${key.sql}]"
+
+  override def left: Expression = child
+  override def right: Expression = key
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+  // todo: current search is O(n), improve it.
+  override def nullSafeEval(value: Any, ordinal: Any): Any = {
+    getValueEval(value, ordinal, keyType)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
+  }
+}
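The shared `getValueEval` walks the map's key array until it finds a match, hence the `// todo: current search is O(n)` note. Semantically it behaves like this plain-collections sketch (illustrative; the real code operates on `MapData` key and value arrays):

```scala
// Linear key scan with null for absent keys, mirroring GetMapValueUtil.getValueEval.
def lookup[K, V >: Null](keys: IndexedSeq[K], values: IndexedSeq[V], key: K): V = {
  var i = 0
  while (i < keys.length) {
    if (keys(i) == key) return values(i)
    i += 1
  }
  null // key not found => NULL (the expression is nullable)
}
```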

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 48 additions & 0 deletions
@@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayPosition(a3, Literal("")), null)
     checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
   }
+
+  test("elementAt") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    intercept[Exception] {
+      checkEvaluation(ElementAt(a0, Literal(0)), null)
+    }.getMessage.contains("SQL array indices start at 1")
+    intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }
+    checkEvaluation(ElementAt(a0, Literal(4)), null)
+    checkEvaluation(ElementAt(a0, Literal(-4)), null)
+
+    checkEvaluation(ElementAt(a0, Literal(1)), 1)
+    checkEvaluation(ElementAt(a0, Literal(2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(3)), 3)
+    checkEvaluation(ElementAt(a0, Literal(-3)), 1)
+    checkEvaluation(ElementAt(a0, Literal(-2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(-1)), 3)
+
+    checkEvaluation(ElementAt(a1, Literal(1)), null)
+    checkEvaluation(ElementAt(a1, Literal(2)), "")
+    checkEvaluation(ElementAt(a1, Literal(-2)), null)
+    checkEvaluation(ElementAt(a1, Literal(-1)), "")
+
+    checkEvaluation(ElementAt(a2, Literal(1)), null)
+
+    checkEvaluation(ElementAt(a3, Literal(1)), null)
+
+    val m0 =
+      Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(ElementAt(m0, Literal(1.0)), null)
+
+    checkEvaluation(ElementAt(m0, Literal("d")), null)
+
+    checkEvaluation(ElementAt(m1, Literal("a")), null)
+
+    checkEvaluation(ElementAt(m0, Literal("a")), "1")
+    checkEvaluation(ElementAt(m0, Literal("b")), "2")
+    checkEvaluation(ElementAt(m0, Literal("c")), null)
+
+    checkEvaluation(ElementAt(m2, Literal("a")), null)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 11 additions & 0 deletions
@@ -3052,6 +3052,17 @@ object functions {
     ArrayPosition(column.expr, Literal(value))
   }
 
+  /**
+   * Returns element of array at given index in value if column is array. Returns value for
+   * the given key in value if column is map.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def element_at(column: Column, value: Any): Column = withExpr {
+    ElementAt(column.expr, Literal(value))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
   *
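Because the second parameter is typed `Any`, this one entry point serves both array indices and map keys; a brief illustrative use (assuming `spark.implicits._` in scope):

```scala
import org.apache.spark.sql.functions.element_at

val mdf = Seq(Map("a" -> 1.0, "b" -> 2.0)).toDF("data")
mdf.select(element_at($"data", "a")).show() // 1.0
mdf.select(element_at($"data", "z")).show() // null: absent keys yield NULL
```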

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 48 additions & 0 deletions
@@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("element_at function") {
+    val df = Seq(
+      (Seq[String]("1", "2", "3")),
+      (Seq[String](null, "")),
+      (Seq[String]())
+    ).toDF("a")
+
+    intercept[Exception] {
+      checkAnswer(
+        df.select(element_at(df("a"), 0)),
+        Seq(Row(null), Row(null), Row(null))
+      )
+    }.getMessage.contains("SQL array indices start at 1")
+    intercept[Exception] {
+      checkAnswer(
+        df.select(element_at(df("a"), 1.1)),
+        Seq(Row(null), Row(null), Row(null))
+      )
+    }
+    checkAnswer(
+      df.select(element_at(df("a"), 4)),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.select(element_at(df("a"), 1)),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(element_at(df("a"), -1)),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 4)"),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 1)"),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.selectExpr("element_at(a, -1)"),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
