Skip to content

Commit f81fa47

Browse files
mn-mikke authored and ueshin committed
[SPARK-23926][SQL] Extending reverse function to support ArrayType arguments
## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ 
} /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke <mrkAha12346github> Closes apache#21034 from mn-mikke/feature/array-api-reverse-to-master.
1 parent cce4694 commit f81fa47

File tree

8 files changed

+256
-33
lines changed

8 files changed

+256
-33
lines changed

python/pyspark/sql/functions.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1414,7 +1414,6 @@ def hash(*cols):
14141414
'uppercase. Words are delimited by whitespace.',
14151415
'lower': 'Converts a string column to lower case.',
14161416
'upper': 'Converts a string column to upper case.',
1417-
'reverse': 'Reverses the string column and returns it as a new string column.',
14181417
'ltrim': 'Trim the spaces from left end for the specified string value.',
14191418
'rtrim': 'Trim the spaces from right end for the specified string value.',
14201419
'trim': 'Trim the spaces from both ends for the specified string column.',
@@ -2128,6 +2127,25 @@ def sort_array(col, asc=True):
21282127
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
21292128

21302129

2130+
@since(1.5)
2131+
@ignore_unicode_prefix
2132+
def reverse(col):
2133+
"""
2134+
Collection function: returns a reversed string or an array with reverse order of elements.
2135+
2136+
:param col: name of column or expression
2137+
2138+
>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
2139+
>>> df.select(reverse(df.data).alias('s')).collect()
2140+
[Row(s=u'LQS krapS')]
2141+
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
2142+
>>> df.select(reverse(df.data).alias('r')).collect()
2143+
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
2144+
"""
2145+
sc = SparkContext._active_spark_context
2146+
return Column(sc._jvm.functions.reverse(_to_java_column(col)))
2147+
2148+
21312149
@since(2.3)
21322150
def map_keys(col):
21332151
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,6 @@ object FunctionRegistry {
336336
expression[RegExpReplace]("regexp_replace"),
337337
expression[StringRepeat]("repeat"),
338338
expression[StringReplace]("replace"),
339-
expression[StringReverse]("reverse"),
340339
expression[RLike]("rlike"),
341340
expression[StringRPad]("rpad"),
342341
expression[StringTrimRight]("rtrim"),
@@ -411,6 +410,7 @@ object FunctionRegistry {
411410
expression[SortArray]("sort_array"),
412411
expression[ArrayMin]("array_min"),
413412
expression[ArrayMax]("array_max"),
413+
expression[Reverse]("reverse"),
414414
CreateStruct.registryEntry,
415415

416416
// misc functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.expressions.codegen._
2424
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
2525
import org.apache.spark.sql.types._
26+
import org.apache.spark.unsafe.types.UTF8String
2627

2728
/**
2829
* Given an array or map, returns its size. Returns -1 if null.
@@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
212213
override def prettyName: String = "sort_array"
213214
}
214215

216+
/**
217+
* Returns a reversed string or an array with reverse order of elements.
218+
*/
219+
@ExpressionDescription(
220+
usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
221+
examples = """
222+
Examples:
223+
> SELECT _FUNC_('Spark SQL');
224+
LQS krapS
225+
> SELECT _FUNC_(array(2, 1, 4, 3));
226+
[3, 4, 1, 2]
227+
""",
228+
since = "1.5.0",
229+
note = "Reverse logic for arrays is available since 2.4.0."
230+
)
231+
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
232+
233+
// Input types are utilized by type coercion in ImplicitTypeCasts.
234+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
235+
236+
override def dataType: DataType = child.dataType
237+
238+
lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
239+
240+
override def nullSafeEval(input: Any): Any = input match {
241+
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
242+
case s: UTF8String => s.reverse()
243+
}
244+
245+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
246+
nullSafeCodeGen(ctx, ev, c => dataType match {
247+
case _: StringType => stringCodeGen(ev, c)
248+
case _: ArrayType => arrayCodeGen(ctx, ev, c)
249+
})
250+
}
251+
252+
private def stringCodeGen(ev: ExprCode, childName: String): String = {
253+
s"${ev.value} = ($childName).reverse();"
254+
}
255+
256+
private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
257+
val length = ctx.freshName("length")
258+
val javaElementType = CodeGenerator.javaType(elementType)
259+
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
260+
261+
val initialization = if (isPrimitiveType) {
262+
s"$childName.copy()"
263+
} else {
264+
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
265+
}
266+
267+
val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
268+
269+
val swapAssigments = if (isPrimitiveType) {
270+
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
271+
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
272+
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
273+
|boolean isNullAtL = ${ev.value}.isNullAt(l);
274+
|if(!isNullAtK) {
275+
| $javaElementType el = ${getCall("k")};
276+
| if(!isNullAtL) {
277+
| ${ev.value}.$setFunc(k, ${getCall("l")});
278+
| } else {
279+
| ${ev.value}.setNullAt(k);
280+
| }
281+
| ${ev.value}.$setFunc(l, el);
282+
|} else if (!isNullAtL) {
283+
| ${ev.value}.$setFunc(k, ${getCall("l")});
284+
| ${ev.value}.setNullAt(l);
285+
|}""".stripMargin
286+
} else {
287+
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
288+
}
289+
290+
s"""
291+
|final int $length = $childName.numElements();
292+
|${ev.value} = $initialization;
293+
|for(int k = 0; k < $numberOfIterations; k++) {
294+
| int l = $length - k - 1;
295+
| $swapAssigments
296+
|}
297+
""".stripMargin
298+
}
299+
300+
override def prettyName: String = "reverse"
301+
}
302+
215303
/**
216304
* Checks if the array (left) has the element (right)
217305
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression)
15041504
}
15051505
}
15061506

1507-
/**
1508-
* Returns the reversed given string.
1509-
*/
1510-
@ExpressionDescription(
1511-
usage = "_FUNC_(str) - Returns the reversed given string.",
1512-
examples = """
1513-
Examples:
1514-
> SELECT _FUNC_('Spark SQL');
1515-
LQS krapS
1516-
""")
1517-
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
1518-
override def convert(v: UTF8String): UTF8String = v.reverse()
1519-
1520-
override def prettyName: String = "reverse"
1521-
1522-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1523-
defineCodeGen(ctx, ev, c => s"($c).reverse()")
1524-
}
1525-
}
1526-
15271507
/**
15281508
* Returns a string consisting of n spaces.
15291509
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
125125
checkEvaluation(
126126
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
127127
}
128+
129+
test("Reverse") {
130+
// Primitive-type elements
131+
val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
132+
val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
133+
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
134+
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
135+
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
136+
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
137+
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
138+
val ai7 = Literal.create(null, ArrayType(IntegerType))
139+
140+
checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
141+
checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
142+
checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
143+
checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
144+
checkEvaluation(Reverse(ai4), Seq(null, null, null))
145+
checkEvaluation(Reverse(ai5), Seq(1))
146+
checkEvaluation(Reverse(ai6), Seq.empty)
147+
checkEvaluation(Reverse(ai7), null)
148+
149+
// Non-primitive-type elements
150+
val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
151+
val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
152+
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
153+
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
154+
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
155+
val as5 = Literal.create(Seq("a"), ArrayType(StringType))
156+
val as6 = Literal.create(Seq.empty, ArrayType(StringType))
157+
val as7 = Literal.create(null, ArrayType(StringType))
158+
val aa = Literal.create(
159+
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
160+
ArrayType(ArrayType(StringType)))
161+
162+
checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
163+
checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
164+
checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
165+
checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
166+
checkEvaluation(Reverse(as4), Seq(null, null, null))
167+
checkEvaluation(Reverse(as5), Seq("a"))
168+
checkEvaluation(Reverse(as6), Seq.empty)
169+
checkEvaluation(Reverse(as7), null)
170+
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
171+
}
128172
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
629629
test("REVERSE") {
630630
val s = 'a.string.at(0)
631631
val row1 = create_row("abccc")
632-
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
633-
checkEvaluation(StringReverse(s), "cccba", row1)
634-
checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
632+
checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
633+
checkEvaluation(Reverse(s), "cccba", row1)
634+
checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
635635
}
636636

637637
test("SPACE") {

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2464,14 +2464,6 @@ object functions {
24642464
StringRepeat(str.expr, lit(n).expr)
24652465
}
24662466

2467-
/**
2468-
* Reverses the string column and returns it as a new string column.
2469-
*
2470-
* @group string_funcs
2471-
* @since 1.5.0
2472-
*/
2473-
def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }
2474-
24752467
/**
24762468
* Trim the spaces from right end for the specified string value.
24772469
*
@@ -3316,6 +3308,13 @@ object functions {
33163308
*/
33173309
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
33183310

3311+
/**
3312+
* Returns a reversed string or an array with reverse order of elements.
3313+
* @group collection_funcs
3314+
* @since 1.5.0
3315+
*/
3316+
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }
3317+
33193318
/**
33203319
* Returns an unordered array containing the keys of the map.
33213320
* @group collection_funcs

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 94 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
441441
checkAnswer(df.selectExpr("array_max(a)"), answer)
442442
}
443443

444+
test("reverse function") {
445+
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on
446+
447+
// String test cases
448+
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
449+
450+
checkAnswer(
451+
oneRowDF.select(reverse('s)),
452+
Seq(Row("krapS"))
453+
)
454+
checkAnswer(
455+
oneRowDF.selectExpr("reverse(s)"),
456+
Seq(Row("krapS"))
457+
)
458+
checkAnswer(
459+
oneRowDF.select(reverse('i)),
460+
Seq(Row("5123"))
461+
)
462+
checkAnswer(
463+
oneRowDF.selectExpr("reverse(i)"),
464+
Seq(Row("5123"))
465+
)
466+
checkAnswer(
467+
oneRowDF.selectExpr("reverse(null)"),
468+
Seq(Row(null))
469+
)
470+
471+
// Array test cases (primitive-type elements)
472+
val idf = Seq(
473+
Seq(1, 9, 8, 7),
474+
Seq(5, 8, 9, 7, 2),
475+
Seq.empty,
476+
null
477+
).toDF("i")
478+
479+
checkAnswer(
480+
idf.select(reverse('i)),
481+
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
482+
)
483+
checkAnswer(
484+
idf.filter(dummyFilter('i)).select(reverse('i)),
485+
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
486+
)
487+
checkAnswer(
488+
idf.selectExpr("reverse(i)"),
489+
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
490+
)
491+
checkAnswer(
492+
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
493+
Seq(Row(Seq(null, 2, null, 1)))
494+
)
495+
checkAnswer(
496+
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
497+
Seq(Row(Seq(null, 2, null, 1)))
498+
)
499+
500+
// Array test cases (non-primitive-type elements)
501+
val sdf = Seq(
502+
Seq("c", "a", "b"),
503+
Seq("b", null, "c", null),
504+
Seq.empty,
505+
null
506+
).toDF("s")
507+
508+
checkAnswer(
509+
sdf.select(reverse('s)),
510+
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
511+
)
512+
checkAnswer(
513+
sdf.filter(dummyFilter('s)).select(reverse('s)),
514+
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
515+
)
516+
checkAnswer(
517+
sdf.selectExpr("reverse(s)"),
518+
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
519+
)
520+
checkAnswer(
521+
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
522+
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
523+
)
524+
checkAnswer(
525+
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
526+
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
527+
)
528+
529+
// Error test cases
530+
intercept[AnalysisException] {
531+
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
532+
}
533+
intercept[AnalysisException] {
534+
oneRowDF.selectExpr("reverse(map(1, 'a'))")
535+
}
536+
}
537+
444538
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
445539
import DataFrameFunctionsSuite.CodegenFallbackExpr
446540
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments (0)