Skip to content

Commit ef6c839

Browse files
huizhiluueshin
authored andcommitted
[SPARK-23928][SQL] Add shuffle collection function.
## What changes were proposed in this pull request? This PR adds a new collection function: shuffle. It generates a random permutation of the given array. This implementation uses the "inside-out" version of Fisher-Yates algorithm. ## How was this patch tested? New tests are added to CollectionExpressionsSuite.scala and DataFrameFunctionsSuite.scala. Author: Takuya UESHIN <[email protected]> Author: pkuwm <[email protected]> Closes apache#21802 from ueshin/issues/SPARK-23928/shuffle.
1 parent 21fcac1 commit ef6c839

File tree

8 files changed

+317
-3
lines changed

8 files changed

+317
-3
lines changed

python/pyspark/sql/functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,23 @@ def array_sort(col):
23822382
return Column(sc._jvm.functions.array_sort(_to_java_column(col)))
23832383

23842384

2385+
@since(2.4)
2386+
def shuffle(col):
2387+
"""
2388+
Collection function: Generates a random permutation of the given array.
2389+
2390+
.. note:: The function is non-deterministic.
2391+
2392+
:param col: name of column or expression
2393+
2394+
>>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
2395+
>>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP
2396+
[Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]
2397+
"""
2398+
sc = SparkContext._active_spark_context
2399+
return Column(sc._jvm.functions.shuffle(_to_java_column(col)))
2400+
2401+
23852402
@since(1.5)
23862403
@ignore_unicode_prefix
23872404
def reverse(col):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class Analyzer(
181181
TimeWindowing ::
182182
ResolveInlineTables(conf) ::
183183
ResolveTimeZone(conf) ::
184-
ResolvedUuidExpressions ::
184+
ResolveRandomSeed ::
185185
TypeCoercion.typeCoercionRules(conf) ++
186186
extendedResolutionRules : _*),
187187
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -2100,15 +2100,16 @@ class Analyzer(
21002100
}
21012101

21022102
/**
2103-
* Set the seed for random number generation in Uuid expressions.
2103+
* Set the seed for random number generation.
21042104
*/
2105-
object ResolvedUuidExpressions extends Rule[LogicalPlan] {
2105+
object ResolveRandomSeed extends Rule[LogicalPlan] {
21062106
private lazy val random = new Random()
21072107

21082108
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
21092109
case p if p.resolved => p
21102110
case p => p transformExpressionsUp {
21112111
case Uuid(None) => Uuid(Some(random.nextLong()))
2112+
case Shuffle(child, None) => Shuffle(child, Some(random.nextLong()))
21122113
}
21132114
}
21142115
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ object FunctionRegistry {
429429
expression[Size]("cardinality"),
430430
expression[ArraysZip]("arrays_zip"),
431431
expression[SortArray]("sort_array"),
432+
expression[Shuffle]("shuffle"),
432433
expression[ArrayMin]("array_min"),
433434
expression[ArrayMax]("array_max"),
434435
expression[Reverse]("reverse"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,112 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi
12031203
override def prettyName: String = "array_sort"
12041204
}
12051205

1206+
/**
1207+
* Returns a random permutation of the given array.
1208+
*/
1209+
@ExpressionDescription(
1210+
usage = "_FUNC_(array) - Returns a random permutation of the given array.",
1211+
examples = """
1212+
Examples:
1213+
> SELECT _FUNC_(array(1, 20, 3, 5));
1214+
[3, 1, 5, 20]
1215+
> SELECT _FUNC_(array(1, 20, null, 3));
1216+
[20, null, 3, 1]
1217+
""",
1218+
note = "The function is non-deterministic.",
1219+
since = "2.4.0")
1220+
case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
1221+
extends UnaryExpression with ExpectsInputTypes with Stateful {
1222+
1223+
def this(child: Expression) = this(child, None)
1224+
1225+
override lazy val resolved: Boolean =
1226+
childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined
1227+
1228+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
1229+
1230+
override def dataType: DataType = child.dataType
1231+
1232+
@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
1233+
1234+
@transient private[this] var random: RandomIndicesGenerator = _
1235+
1236+
override protected def initializeInternal(partitionIndex: Int): Unit = {
1237+
random = RandomIndicesGenerator(randomSeed.get + partitionIndex)
1238+
}
1239+
1240+
override protected def evalInternal(input: InternalRow): Any = {
1241+
val value = child.eval(input)
1242+
if (value == null) {
1243+
null
1244+
} else {
1245+
val source = value.asInstanceOf[ArrayData]
1246+
val numElements = source.numElements()
1247+
val indices = random.getNextIndices(numElements)
1248+
new GenericArrayData(indices.map(source.get(_, elementType)))
1249+
}
1250+
}
1251+
1252+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1253+
nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c))
1254+
}
1255+
1256+
private def shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
1257+
val randomClass = classOf[RandomIndicesGenerator].getName
1258+
1259+
val rand = ctx.addMutableState(randomClass, "rand", forceInline = true)
1260+
ctx.addPartitionInitializationStatement(
1261+
s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);")
1262+
1263+
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
1264+
1265+
val numElements = ctx.freshName("numElements")
1266+
val arrayData = ctx.freshName("arrayData")
1267+
1268+
val initialization = if (isPrimitiveType) {
1269+
ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.")
1270+
} else {
1271+
val arrayDataClass = classOf[GenericArrayData].getName()
1272+
s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);"
1273+
}
1274+
1275+
val indices = ctx.freshName("indices")
1276+
val i = ctx.freshName("i")
1277+
1278+
val getValue = CodeGenerator.getValue(childName, elementType, s"$indices[$i]")
1279+
1280+
val setFunc = if (isPrimitiveType) {
1281+
s"set${CodeGenerator.primitiveTypeName(elementType)}"
1282+
} else {
1283+
"update"
1284+
}
1285+
1286+
val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) {
1287+
s"""
1288+
|if ($childName.isNullAt($indices[$i])) {
1289+
| $arrayData.setNullAt($i);
1290+
|} else {
1291+
| $arrayData.$setFunc($i, $getValue);
1292+
|}
1293+
""".stripMargin
1294+
} else {
1295+
s"$arrayData.$setFunc($i, $getValue);"
1296+
}
1297+
1298+
s"""
1299+
|int $numElements = $childName.numElements();
1300+
|int[] $indices = $rand.getNextIndices($numElements);
1301+
|$initialization
1302+
|for (int $i = 0; $i < $numElements; $i++) {
1303+
| $assignment
1304+
|}
1305+
|${ev.value} = $arrayData;
1306+
""".stripMargin
1307+
}
1308+
1309+
override def freshCopy(): Shuffle = Shuffle(child, randomSeed)
1310+
}
1311+
12061312
/**
12071313
* Returns a reversed string or an array with reverse order of elements.
12081314
*/
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import org.apache.commons.math3.random.MersenneTwister
21+
22+
/**
23+
* This class is used to generate a random indices of given length.
24+
*
25+
* This implementation uses the "inside-out" version of Fisher-Yates algorithm.
26+
* Reference:
27+
* https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_%22inside-out%22_algorithm
28+
*/
29+
case class RandomIndicesGenerator(randomSeed: Long) {
30+
private val random = new MersenneTwister(randomSeed)
31+
32+
def getNextIndices(length: Int): Array[Int] = {
33+
val indices = new Array[Int](length)
34+
var i = 0
35+
while (i < length) {
36+
val j = random.nextInt(i + 1)
37+
if (j != i) {
38+
indices(i) = indices(j)
39+
}
40+
indices(j) = i
41+
i += 1
42+
}
43+
indices
44+
}
45+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.sql.{Date, Timestamp}
2121
import java.util.TimeZone
2222

23+
import scala.util.Random
24+
2325
import org.apache.spark.SparkFunSuite
2426
import org.apache.spark.sql.Row
2527
import org.apache.spark.sql.catalyst.InternalRow
@@ -1434,4 +1436,71 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
14341436
assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
14351437
assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
14361438
}
1439+
1440+
test("Shuffle") {
1441+
// Primitive-type elements
1442+
val ai0 = Literal.create(Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, containsNull = false))
1443+
val ai1 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
1444+
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType, containsNull = true))
1445+
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType, containsNull = true))
1446+
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType, containsNull = true))
1447+
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType, containsNull = false))
1448+
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType, containsNull = false))
1449+
val ai7 = Literal.create(null, ArrayType(IntegerType, containsNull = true))
1450+
1451+
checkEvaluation(Shuffle(ai0, Some(0)), Seq(4, 1, 2, 3, 5))
1452+
checkEvaluation(Shuffle(ai1, Some(0)), Seq(3, 1, 2))
1453+
checkEvaluation(Shuffle(ai2, Some(0)), Seq(3, null, 1, null))
1454+
checkEvaluation(Shuffle(ai3, Some(0)), Seq(null, 2, null, 4))
1455+
checkEvaluation(Shuffle(ai4, Some(0)), Seq(null, null, null))
1456+
checkEvaluation(Shuffle(ai5, Some(0)), Seq(1))
1457+
checkEvaluation(Shuffle(ai6, Some(0)), Seq.empty)
1458+
checkEvaluation(Shuffle(ai7, Some(0)), null)
1459+
1460+
// Non-primitive-type elements
1461+
val as0 = Literal.create(Seq("a", "b", "c", "d"), ArrayType(StringType, containsNull = false))
1462+
val as1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
1463+
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType, containsNull = true))
1464+
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType, containsNull = true))
1465+
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType, containsNull = true))
1466+
val as5 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = false))
1467+
val as6 = Literal.create(Seq.empty, ArrayType(StringType, containsNull = false))
1468+
val as7 = Literal.create(null, ArrayType(StringType, containsNull = true))
1469+
val aa = Literal.create(
1470+
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
1471+
ArrayType(ArrayType(StringType)))
1472+
1473+
checkEvaluation(Shuffle(as0, Some(0)), Seq("d", "a", "b", "c"))
1474+
checkEvaluation(Shuffle(as1, Some(0)), Seq("c", "a", "b"))
1475+
checkEvaluation(Shuffle(as2, Some(0)), Seq("c", null, "a", null))
1476+
checkEvaluation(Shuffle(as3, Some(0)), Seq(null, "b", null, "d"))
1477+
checkEvaluation(Shuffle(as4, Some(0)), Seq(null, null, null))
1478+
checkEvaluation(Shuffle(as5, Some(0)), Seq("a"))
1479+
checkEvaluation(Shuffle(as6, Some(0)), Seq.empty)
1480+
checkEvaluation(Shuffle(as7, Some(0)), null)
1481+
checkEvaluation(Shuffle(aa, Some(0)), Seq(Seq("e"), Seq("a", "b"), Seq("c", "d")))
1482+
1483+
val r = new Random()
1484+
val seed1 = Some(r.nextLong())
1485+
assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) ===
1486+
evaluateWithoutCodegen(Shuffle(ai0, seed1)))
1487+
assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) ===
1488+
evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)))
1489+
assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) ===
1490+
evaluateWithUnsafeProjection(Shuffle(ai0, seed1)))
1491+
1492+
val seed2 = Some(r.nextLong())
1493+
assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) !==
1494+
evaluateWithoutCodegen(Shuffle(ai0, seed2)))
1495+
assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) !==
1496+
evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed2)))
1497+
assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !==
1498+
evaluateWithUnsafeProjection(Shuffle(ai0, seed2)))
1499+
1500+
val shuffle = Shuffle(ai0, seed1)
1501+
assert(shuffle.fastEquals(shuffle))
1502+
assert(!shuffle.fastEquals(Shuffle(ai0, seed1)))
1503+
assert(!shuffle.fastEquals(shuffle.freshCopy()))
1504+
assert(!shuffle.fastEquals(Shuffle(ai0, seed2)))
1505+
}
14371506
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3545,6 +3545,16 @@ object functions {
35453545
*/
35463546
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
35473547

3548+
/**
3549+
* Returns a random permutation of the given array.
3550+
*
3551+
* @note The function is non-deterministic.
3552+
*
3553+
* @group collection_funcs
3554+
* @since 2.4.0
3555+
*/
3556+
def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) }
3557+
35483558
/**
35493559
* Returns a reversed string or an array with reverse order of elements.
35503560
* @group collection_funcs

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,71 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
15131513
)
15141514
}
15151515

1516+
// Shuffle expressions should produce same results at retries in the same DataFrame.
1517+
private def checkShuffleResult(df: DataFrame): Unit = {
1518+
checkAnswer(df, df.collect())
1519+
}
1520+
1521+
test("shuffle function - array for primitive type not containing null") {
1522+
val idfNotContainsNull = Seq(
1523+
Seq(1, 9, 8, 7),
1524+
Seq(5, 8, 9, 7, 2),
1525+
Seq.empty,
1526+
null
1527+
).toDF("i")
1528+
1529+
def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
1530+
checkShuffleResult(idfNotContainsNull.select(shuffle('i)))
1531+
checkShuffleResult(idfNotContainsNull.selectExpr("shuffle(i)"))
1532+
}
1533+
1534+
// Test with local relation, the Project will be evaluated without codegen
1535+
testArrayOfPrimitiveTypeNotContainsNull()
1536+
// Test with cached relation, the Project will be evaluated with codegen
1537+
idfNotContainsNull.cache()
1538+
testArrayOfPrimitiveTypeNotContainsNull()
1539+
}
1540+
1541+
test("shuffle function - array for primitive type containing null") {
1542+
val idfContainsNull = Seq[Seq[Integer]](
1543+
Seq(1, 9, 8, null, 7),
1544+
Seq(null, 5, 8, 9, 7, 2),
1545+
Seq.empty,
1546+
null
1547+
).toDF("i")
1548+
1549+
def testArrayOfPrimitiveTypeContainsNull(): Unit = {
1550+
checkShuffleResult(idfContainsNull.select(shuffle('i)))
1551+
checkShuffleResult(idfContainsNull.selectExpr("shuffle(i)"))
1552+
}
1553+
1554+
// Test with local relation, the Project will be evaluated without codegen
1555+
testArrayOfPrimitiveTypeContainsNull()
1556+
// Test with cached relation, the Project will be evaluated with codegen
1557+
idfContainsNull.cache()
1558+
testArrayOfPrimitiveTypeContainsNull()
1559+
}
1560+
1561+
test("shuffle function - array for non-primitive type") {
1562+
val sdf = Seq(
1563+
Seq("c", "a", "b"),
1564+
Seq("b", null, "c", null),
1565+
Seq.empty,
1566+
null
1567+
).toDF("s")
1568+
1569+
def testNonPrimitiveType(): Unit = {
1570+
checkShuffleResult(sdf.select(shuffle('s)))
1571+
checkShuffleResult(sdf.selectExpr("shuffle(s)"))
1572+
}
1573+
1574+
// Test with local relation, the Project will be evaluated without codegen
1575+
testNonPrimitiveType()
1576+
// Test with cached relation, the Project will be evaluated with codegen
1577+
sdf.cache()
1578+
testNonPrimitiveType()
1579+
}
1580+
15161581
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
15171582
import DataFrameFunctionsSuite.CodegenFallbackExpr
15181583
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)