Skip to content

Commit 14844a6

Browse files
mgaido91 authored and ueshin committed
[SPARK-23918][SQL] Add array_min function
## What changes were proposed in this pull request?

The PR adds the SQL function `array_min`. It takes an array as argument and returns the minimum value in it.

## How was this patch tested?

added UTs

Author: Marco Gaido <[email protected]>

Closes apache#21025 from mgaido91/SPARK-23918.
1 parent fd990a9 commit 14844a6

File tree

8 files changed

+131
-6
lines changed

8 files changed

+131
-6
lines changed

python/pyspark/sql/functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,21 @@ def size(col):
20802080
return Column(sc._jvm.functions.size(_to_java_column(col)))
20812081

20822082

@since(2.4)
def array_min(col):
    """
    Collection function: returns the minimum value of the array.

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
    >>> df.select(array_min(df.data).alias('min')).collect()
    [Row(min=1), Row(min=-1)]
    """
    # Thin wrapper: delegate straight to the JVM-side array_min implementation.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.array_min(_to_java_column(col)))
20832098
@since(2.4)
20842099
def array_max(col):
20852100
"""
@@ -2108,7 +2123,7 @@ def sort_array(col, asc=True):
21082123
[Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
21092124
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
21102125
[Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
2111-
"""
2126+
"""
21122127
sc = SparkContext._active_spark_context
21132128
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
21142129

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ object FunctionRegistry {
409409
expression[MapValues]("map_values"),
410410
expression[Size]("size"),
411411
expression[SortArray]("sort_array"),
412+
expression[ArrayMin]("array_min"),
412413
expression[ArrayMax]("array_max"),
413414
CreateStruct.registryEntry,
414415

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression {
595595
val evals = evalChildren.map(eval =>
596596
s"""
597597
|${eval.code}
598-
|if (!${eval.isNull} && (${ev.isNull} ||
599-
| ${ctx.genGreater(dataType, ev.value, eval.value)})) {
600-
| ${ev.isNull} = false;
601-
| ${ev.value} = ${eval.value};
602-
|}
598+
|${ctx.reassignIfSmaller(dataType, ev, eval)}
603599
""".stripMargin
604600
)
605601

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,23 @@ class CodegenContext {
699699
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
700700
}
701701

  /**
   * Generates code for updating `partialResult` if `item` is smaller than it.
   *
   * The returned snippet assumes `partialResult.isNull` starts out `true`: the first
   * non-null `item` both clears the null flag and seeds the value, and every later
   * non-null `item` replaces the value only when it compares smaller (via `genGreater`
   * on the current partial result). Null items are skipped entirely.
   *
   * This is the "min" counterpart of `reassignIfGreater` (defined just below).
   *
   * @param dataType data type of the expressions
   * @param partialResult `ExprCode` representing the partial result which has to be updated
   * @param item `ExprCode` representing the new expression to evaluate for the result
   */
  def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
    s"""
       |if (!${item.isNull} && (${partialResult.isNull} ||
       |  ${genGreater(dataType, partialResult.value, item.value)})) {
       |  ${partialResult.isNull} = false;
       |  ${partialResult.value} = ${item.value};
       |}
     """.stripMargin
  }
702719
/**
703720
* Generates code for updating `partialResult` if `item` is greater than it.
704721
*

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression)
288288
override def prettyName: String = "array_contains"
289289
}
290290

/**
 * Returns the minimum value in the array.
 *
 * Null handling: NULL elements are skipped; the result is NULL when the input array
 * is NULL, empty, or contains only NULL elements (hence `nullable = true` below).
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 20, null, 3));
       1
  """, since = "2.4.0")
case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  // NULL, empty, or all-NULL arrays yield NULL even when the child itself is non-nullable.
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  // Lazy: `dataType` is derived from the child, which may not be resolved yet at
  // construction time — TODO confirm against Expression resolution lifecycle.
  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  override def checkInputDataTypes(): TypeCheckResult = {
    // First run the standard UnaryExpression/ImplicitCastInputTypes check (array input),
    // then verify the element type actually supports ordering.
    val typeCheckResult = super.checkInputDataTypes()
    if (typeCheckResult.isSuccess) {
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
    } else {
      typeCheckResult
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val javaType = CodeGenerator.javaType(dataType)
    val i = ctx.freshName("i")
    // Synthetic ExprCode for the i-th element: it carries no code of its own, only the
    // null-check and value-access expressions, so it is re-evaluated on every loop pass.
    val item = ExprCode("",
      isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
      value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
    // ev starts as "null" (isNull = true, default value); reassignIfSmaller clears the
    // flag on the first non-null element and keeps the running minimum afterwards.
    ev.copy(code =
      s"""
         |${childGen.code}
         |boolean ${ev.isNull} = true;
         |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |if (!${childGen.isNull}) {
         |  for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
         |    ${ctx.reassignIfSmaller(dataType, ev, item)}
         |  }
         |}
      """.stripMargin)
  }

  // Interpreted path; mirrors the generated code above: skip nulls, track the running
  // minimum via the element type's ordering, return null when no non-null element exists.
  override protected def nullSafeEval(input: Any): Any = {
    var min: Any = null
    input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
      if (item != null && (min == null || ordering.lt(item, min))) {
        min = item
      }
    )
    min
  }

  // Result type is the array's element type; the input-type check guarantees an
  // ArrayType child once resolved, so the catch-all is a defensive invariant guard.
  override def dataType: DataType = child.dataType match {
    case ArrayType(dt, _) => dt
    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
  }

  override def prettyName: String = "array_min"
}
291355

292356
/**
293357
* Returns the maximum value in the array.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
106106
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
107107
}
108108

  test("Array Min") {
    // Plain integers, including a negative minimum.
    checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
    // Strings: nulls are skipped, and the empty string is the smallest non-null value.
    checkEvaluation(
      ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "")
    // All-null array yields null.
    checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null)
    // Null array yields null.
    checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null)
    // Doubles ordered by numeric value, not textual form.
    checkEvaluation(
      ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234)
  }
109119
test("Array max") {
110120
checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
111121
checkEvaluation(

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,6 +3300,14 @@ object functions {
33003300
*/
33013301
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
33023302

  /**
   * Returns the minimum value in the array.
   *
   * @param e array column to take the minimum of
   * @group collection_funcs
   * @since 2.4.0
   */
  def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) }
33033311
/**
33043312
* Returns the maximum value in the array.
33053313
*

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
413413
)
414414
}
415415

  test("array_min function") {
    // One row per case: normal data, empty array, all-null array, mixed null/non-null.
    val df = Seq(
      Seq[Option[Int]](Some(1), Some(3), Some(2)),
      Seq.empty[Option[Int]],
      Seq[Option[Int]](None),
      Seq[Option[Int]](None, Some(1), Some(-100))
    ).toDF("a")

    // Empty and all-null arrays both produce null; nulls are skipped otherwise.
    val answer = Seq(Row(1), Row(null), Row(null), Row(-100))

    // Exercise both the DataFrame DSL function and the SQL expression path.
    checkAnswer(df.select(array_min(df("a"))), answer)
    checkAnswer(df.selectExpr("array_min(a)"), answer)
  }
416430
test("array_max function") {
417431
val df = Seq(
418432
Seq[Option[Int]](Some(1), Some(3), Some(2)),

0 commit comments

Comments
 (0)