Commit d5bec48

kiszk authored and ueshin committed
[SPARK-23919][SQL] Add array_position function
## What changes were proposed in this pull request?

The PR adds the SQL function `array_position`. The behavior of the function is based on Presto's. The function returns the position of the first occurrence of the element in array x (or 0 if not found), using a 1-based index, as a BigInt.

## How was this patch tested?

Added UTs

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#21037 from kiszk/SPARK-23919.
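A quick sketch of the documented semantics, as they could be checked from a spark-shell session on a build that contains this patch (the literal arrays below are illustrative only):

spark.sql("SELECT array_position(array(3, 2, 1), 1)").show()    // 3: 1-based position of the first match
spark.sql("SELECT array_position(array(3, 2, 1), 7)").show()    // 0: element not found in the array
spark.sql("SELECT array_position(array(3, 2, 1), null)").show() // NULL: a null argument yields null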
1 parent 8bb0df2 commit d5bec48

6 files changed: +144 −0 lines changed


python/pyspark/sql/functions.py

Lines changed: 17 additions & 0 deletions
@@ -1845,6 +1845,23 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(2.4)
+def array_position(col, value):
+    """
+    Collection function: Locates the position of the first occurrence of the given value
+    in the given array. Returns null if either of the arguments are null.
+
+    .. note:: The position is not zero based, but 1 based index. Returns 0 if the given
+        value could not be found in the array.
+
+    >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+    >>> df.select(array_position(df.data, "a")).collect()
+    [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
@@ -402,6 +402,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),
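Registering the expression under the name array_position is what makes it resolvable from plain SQL and from selectExpr. A hedged spark-shell sketch of that path (the view name t and column name data are hypothetical, not part of the patch):

import spark.implicits._   // assumes a spark-shell style SparkSession named `spark`

val letters = Seq(Seq("c", "b", "a")).toDF("data")
letters.createOrReplaceTempView("t")
spark.sql("SELECT array_position(data, 'a') AS pos FROM t").show()  // pos = 3
letters.selectExpr("array_position(data, 'b')").show()              // 2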

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 56 additions & 0 deletions
@@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
 
   override def prettyName: String = "array_max"
 }
+
+
+/**
+ * Returns the position of the first occurrence of element in the given array as long.
+ * Returns 0 if the given value could not be found in the array. Returns null if either of
+ * the arguments are null
+ *
+ * NOTE: that this is not zero based, but 1-based index. The first element in the array has
+ * index 1.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(3, 2, 1), 1);
+       3
+  """,
+  since = "2.4.0")
+case class ArrayPosition(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = LongType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
+
+  override def nullSafeEval(arr: Any, value: Any): Any = {
+    arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+      if (v == value) {
+        return (i + 1).toLong
+      }
+    )
+    0L
+  }
+
+  override def prettyName: String = "array_position"
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (arr, value) => {
+      val pos = ctx.freshName("arrayPosition")
+      val i = ctx.freshName("i")
+      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+      s"""
+         |int $pos = 0;
+         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
+         |    $pos = $i + 1;
+         |    break;
+         |  }
+         |}
+         |${ev.value} = (long) $pos;
+       """.stripMargin
+    })
+  }
+}
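Stripped of the Catalyst plumbing, both the interpreted nullSafeEval path and the generated Java loop above perform the same 1-based linear scan that skips null elements. A plain-Scala sketch of that logic (illustration only, not part of the patch):

// Returns the 1-based index of the first element equal to `value`, or 0L if absent.
def positionOf[T](arr: Seq[T], value: T): Long = {
  val idx = arr.indexWhere(v => v != null && v == value)
  if (idx >= 0) (idx + 1).toLong else 0L
}

positionOf(Seq(3, 2, 1), 1)      // 3L
positionOf(Seq(3, 2, 1), 7)      // 0L
positionOf(Seq(1, null, 2), 2)   // 3L: null elements are skipped, not matched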

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 22 additions & 0 deletions
@@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Reverse(as7), null)
     checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
   }
+
+  test("Array Position") {
+    val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
+    checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
+    checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
+    checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)
+
+    checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
+    checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
+    checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)
+
+    checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
+    checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)
+
+    checkEvaluation(ArrayPosition(a3, Literal("")), null)
+    checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 14 additions & 0 deletions
@@ -3038,6 +3038,20 @@ object functions {
     ArrayContains(column.expr, Literal(value))
   }
 
+  /**
+   * Locates the position of the first occurrence of the value in the given array as long.
+   * Returns null if either of the arguments are null.
+   *
+   * @note The position is not zero based, but 1 based index. Returns 0 if value
+   * could not be found in array.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_position(column: Column, value: Any): Column = withExpr {
+    ArrayPosition(column.expr, Literal(value))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
    *
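With the new entry in org.apache.spark.sql.functions, the same lookup is available from the typed DataFrame API. A short usage sketch (the DataFrame contents are illustrative only):

import spark.implicits._   // assumes a SparkSession named `spark`
import org.apache.spark.sql.functions.{array_position, col}

val df = Seq(Seq("c", "b", "a"), Seq.empty[String]).toDF("data")
df.select(array_position(col("data"), "a")).show()
// 3 for the first row (1-based index), 0 for the empty array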

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("array position function") {
+    val df = Seq(
+      (Seq[Int](1, 2), "x"),
+      (Seq[Int](), "x")
+    ).toDF("a", "b")
+
+    checkAnswer(
+      df.select(array_position(df("a"), 1)),
+      Seq(Row(1L), Row(0L))
+    )
+    checkAnswer(
+      df.selectExpr("array_position(a, 1)"),
+      Seq(Row(1L), Row(0L))
+    )
+
+    checkAnswer(
+      df.select(array_position(df("a"), null)),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.selectExpr("array_position(a, null)"),
+      Seq(Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("array_position(array(array(1), null)[0], 1)"),
+      Seq(Row(1L), Row(1L))
+    )
+    checkAnswer(
+      df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
+      Seq(Row(1L), Row(1L))
+    )
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
