Skip to content

Commit 0ecc132

Browse files
committed
[SPARK-23909][SQL] Add filter function.
## What changes were proposed in this pull request?

This PR adds a `filter` function, which filters the input array using the given predicate.

```sql
> SELECT filter(array(1, 2, 3), x -> x % 2 == 1);
 array(1, 3)
```

## How was this patch tested?

Added tests.

Author: Takuya UESHIN <[email protected]>

Closes apache#21965 from ueshin/issues/SPARK-23909/filter.
1 parent 36ea55e commit 0ecc132

File tree

6 files changed

+240
-9
lines changed

6 files changed

+240
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ object FunctionRegistry {
441441
expression[ArrayRemove]("array_remove"),
442442
expression[ArrayDistinct]("array_distinct"),
443443
expression[ArrayTransform]("transform"),
444+
expression[ArrayFilter]("filter"),
444445
CreateStruct.registryEntry,
445446

446447
// misc functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.util.concurrent.atomic.AtomicReference
2121

22+
import scala.collection.mutable
23+
2224
import org.apache.spark.sql.catalyst.InternalRow
2325
import org.apache.spark.sql.catalyst.expressions.codegen._
2426
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -140,6 +142,18 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
140142
@transient lazy val functionForEval: Expression = functionsForEval.head
141143
}
142144

145+
object ArrayBasedHigherOrderFunction {

  /**
   * Describes the element of an array-typed input as a lambda argument.
   *
   * @param dt the data type of the input expression
   * @return the pair (element type, whether elements may be null); when `dt` is not an
   *         [[ArrayType]], the pair is taken from `ArrayType.defaultConcreteType` instead
   */
  def elementArgumentType(dt: DataType): (DataType, Boolean) = {
    // Pick the type to destructure first, so that both cases share a single extraction.
    val arrayLike = dt match {
      case arrayType: ArrayType => arrayType
      case _ => ArrayType.defaultConcreteType
    }
    val ArrayType(elementType, containsNull) = arrayLike
    (elementType, containsNull)
  }
}
156+
143157
/**
144158
* Transform elements in an array using the transform function. This is similar to
145159
* a `map` in functional programming.
@@ -164,17 +178,12 @@ case class ArrayTransform(
164178
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
165179

166180
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
  // Describe the array element once; both lambda shapes take it as their first argument.
  val elementArg = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
  val lambdaArgs: List[(DataType, Boolean)] = function match {
    // A two-argument lambda additionally receives a non-nullable integer
    // (presumably the element index -- TODO confirm against the evaluation code).
    case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
      List(elementArg, (IntegerType, false))
    case _ =>
      List(elementArg)
  }
  copy(function = f(function, lambdaArgs))
}
180189

@@ -210,3 +219,54 @@ case class ArrayTransform(
210219

211220
override def prettyName: String = "transform"
212221
}
222+
223+
/**
 * Filters the input array using the given lambda function.
 *
 * The lambda is bound with a single argument describing the array element and is expected to
 * produce a boolean. Evaluation yields null for a null input array; otherwise it yields a new
 * array holding, in their original order, the elements for which the lambda evaluates to true.
 *
 * @param input expression producing the array to filter
 * @param function the predicate lambda applied to each element
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
       array(1, 3)
  """,
  since = "2.4.0")
case class ArrayFilter(
    input: Expression,
    function: Expression)
  extends ArrayBasedHigherOrderFunction with CodegenFallback {

  /** The result is null exactly when the input array may be null. */
  override def nullable: Boolean = input.nullable

  /** Filtering keeps the element type, so the result has the input's type. */
  override def dataType: DataType = input.dataType

  /** The lambda must be a predicate. */
  override def expectingFunctionType: AbstractDataType = BooleanType

  /** Binds the lambda with one argument: the (type, nullability) of the array element. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
    val elementArg = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
    copy(function = f(function, List(elementArg)))
  }

  // The single variable of the bound lambda; each element is published through it before the
  // predicate is evaluated.
  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

  /**
   * Evaluates the input array and keeps the elements accepted by the predicate.
   *
   * @param input the row against which the array and the predicate are evaluated
   * @return null when the array is null, otherwise a [[GenericArrayData]] of the kept elements
   */
  override def eval(input: InternalRow): Any = {
    val source = this.input.eval(input).asInstanceOf[ArrayData]
    if (source != null) {
      val predicate = functionForEval
      val size = source.numElements
      val kept = new mutable.ArrayBuffer[Any](size)
      var idx = 0
      while (idx < size) {
        elementVar.value.set(source.get(idx, elementVar.dataType))
        // A null predicate result unboxes to false here, so such elements are dropped.
        val accepted = predicate.eval(input).asInstanceOf[Boolean]
        if (accepted) {
          kept += elementVar.value.get
        }
        idx += 1
      }
      new GenericArrayData(kept)
    } else {
      null
    }
  }

  override def prettyName: String = "filter"
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
5454
ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f))
5555
}
5656

57+
/**
 * Builds an [[ArrayFilter]] over `expr` whose predicate lambda is created from `f`.
 *
 * @param expr an expression whose data type must be an [[ArrayType]] (it is cast to one)
 * @param f builds the predicate body from the lambda's element variable
 */
def filter(expr: Expression, f: Expression => Expression): Expression = {
  // Cast first so that a non-array input fails in the same way as before, then destructure.
  val ArrayType(elementType, containsNull) = expr.dataType.asInstanceOf[ArrayType]
  val predicate = createLambda(elementType, containsNull, f)
  ArrayFilter(expr, predicate)
}
61+
5762
test("ArrayTransform") {
5863
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
5964
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -94,4 +99,36 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
9499
checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
95100
Seq("[1, 3, 5]", null, "[4, 6]"))
96101
}
102+
103+
test("ArrayFilter") {
  // Predicates shared by the cases below.
  val isEven: Expression => Expression = x => x % 2 === 0
  val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
  val startsWithA: Expression => Expression = x => x.startsWith("a")

  // Integer arrays: without nulls, with a null element, and a null array.
  val intsNoNull = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
  val intsWithNull =
    Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
  val nullInts = Literal.create(null, ArrayType(IntegerType, containsNull = false))

  // String arrays: without nulls, with a null element, and a null array.
  val stringsNoNull =
    Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
  val stringsWithNull =
    Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
  val nullStrings = Literal.create(null, ArrayType(StringType, containsNull = false))

  // (array, predicate, expected result), checked in this order.
  val cases: Seq[(Expression, Expression => Expression, Any)] = Seq(
    (intsNoNull, isEven, Seq(2)),
    (intsNoNull, isNullOrOdd, Seq(1, 3)),
    (intsWithNull, isEven, Seq.empty),
    (intsWithNull, isNullOrOdd, Seq(1, null, 3)),
    (nullInts, isEven, null),
    (nullInts, isNullOrOdd, null),
    (stringsNoNull, startsWithA, Seq("a0", "a2")),
    (stringsWithNull, startsWithA, Seq("a")),
    (nullStrings, startsWithA, null))

  cases.foreach { case (array, predicate, expected) =>
    checkEvaluation(filter(array, predicate), expected)
  }

  // Filtering inside a transform: a null inner array stays null.
  val nestedInts = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
    ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
  checkEvaluation(transform(nestedInts, inner => filter(inner, isNullOrOdd)),
    Seq(Seq(1, 3), null, Seq(5)))
}
97134
}

sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,12 @@ select transform(ys, 0) as v from nested;
2424

2525
-- Transform a null array
2626
select transform(cast(null as array<int>), x -> x + 1) as v;
27+
28+
-- Filter.
29+
select filter(ys, y -> y > 30) as v from nested;
30+
31+
-- Filter a null array
32+
select filter(cast(null as array<int>), y -> true) as v;
33+
34+
-- Filter nested arrays
35+
select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;

sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 8
2+
-- Number of queries: 11
33

44

55
-- !query 0
@@ -79,3 +79,31 @@ select transform(cast(null as array<int>), x -> x + 1) as v
7979
struct<v:array<int>>
8080
-- !query 7 output
8181
NULL
82+
83+
84+
-- !query 8
85+
select filter(ys, y -> y > 30) as v from nested
86+
-- !query 8 schema
87+
struct<v:array<int>>
88+
-- !query 8 output
89+
[32,97]
90+
[77]
91+
[]
92+
93+
94+
-- !query 9
95+
select filter(cast(null as array<int>), y -> true) as v
96+
-- !query 9 schema
97+
struct<v:array<int>>
98+
-- !query 9 output
99+
NULL
100+
101+
102+
-- !query 10
103+
select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested
104+
-- !query 10 schema
105+
struct<v:array<array<int>>>
106+
-- !query 10 output
107+
[[96,65],[]]
108+
[[99],[123],[]]
109+
[[]]

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,6 +1800,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
18001800
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
18011801
}
18021802

1803+
test("filter function - array for primitive type not containing null") {
  val input = Seq(
    Seq(1, 9, 8, 7),
    Seq(5, 8, 9, 7, 2),
    Seq.empty,
    null
  ).toDF("i")

  // Only the even elements survive; empty and null arrays pass through unchanged.
  val expected = Seq(
    Row(Seq(8)),
    Row(Seq(8, 2)),
    Row(Seq.empty),
    Row(null))

  def verifyEvenElementsKept(): Unit =
    checkAnswer(input.selectExpr("filter(i, x -> x % 2 == 0)"), expected)

  // Local relation: the Project will be evaluated without codegen.
  verifyEvenElementsKept()
  // Cached relation: the Project will be evaluated with codegen.
  input.cache()
  verifyEvenElementsKept()
}
1826+
1827+
test("filter function - array for primitive type containing null") {
  val input = Seq[Seq[Integer]](
    Seq(1, 9, 8, null, 7),
    Seq(5, null, 8, 9, 7, 2),
    Seq.empty,
    null
  ).toDF("i")

  // Null elements are dropped along with the odd ones; empty and null arrays are unchanged.
  val expected = Seq(
    Row(Seq(8)),
    Row(Seq(8, 2)),
    Row(Seq.empty),
    Row(null))

  def verifyEvenElementsKept(): Unit =
    checkAnswer(input.selectExpr("filter(i, x -> x % 2 == 0)"), expected)

  // Local relation: the Project will be evaluated without codegen.
  verifyEvenElementsKept()
  // Cached relation: the Project will be evaluated with codegen.
  input.cache()
  verifyEvenElementsKept()
}
1850+
1851+
test("filter function - array for non-primitive type") {
  val input = Seq(
    Seq("c", "a", "b"),
    Seq("b", null, "c", null),
    Seq.empty,
    null
  ).toDF("s")

  // Null strings are removed; empty and null arrays are unchanged.
  val expected = Seq(
    Row(Seq("c", "a", "b")),
    Row(Seq("b", "c")),
    Row(Seq.empty),
    Row(null))

  def verifyNonNullElementsKept(): Unit =
    checkAnswer(input.selectExpr("filter(s, x -> x is not null)"), expected)

  // Local relation: the Project will be evaluated without codegen.
  verifyNonNullElementsKept()
  // Cached relation: the Project will be evaluated with codegen.
  input.cache()
  verifyNonNullElementsKept()
}
1874+
1875+
test("filter function - invalid") {
  val df = Seq(
    (Seq("c", "a", "b"), 1),
    (Seq("b", null, "c", null), 2),
    (Seq.empty, 3),
    (null, 4)
  ).toDF("s", "i")

  // Runs the expression and returns the message of the AnalysisException it must raise.
  def analysisErrorOf(sqlExpr: String): String = {
    val error = intercept[AnalysisException] {
      df.selectExpr(sqlExpr)
    }
    error.getMessage
  }

  // A lambda with two arguments is rejected.
  val wrongArity = analysisErrorOf("filter(s, (x, y) -> x + y)")
  assert(wrongArity.contains("The number of lambda function arguments '2' does not match"))

  // The first argument must be an array.
  val notAnArray = analysisErrorOf("filter(i, x -> x)")
  assert(notAnArray.contains("data type mismatch: argument 1 requires array type"))

  // The lambda must produce a boolean.
  val notAPredicate = analysisErrorOf("filter(s, x -> x)")
  assert(notAPredicate.contains("data type mismatch: argument 2 requires boolean type"))
}
18031899
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
18041900
import DataFrameFunctionsSuite.CodegenFallbackExpr
18051901
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)