Skip to content

Commit 9de11d3

Browse files
huaxingaoueshin
authored andcommitted
[SPARK-23912][SQL] add array_distinct
## What changes were proposed in this pull request? Add array_distinct to remove duplicate value from the array. ## How was this patch tested? Add unit tests Author: Huaxin Gao <[email protected]> Closes apache#21050 from huaxingao/spark-23912.
1 parent 15747cf commit 9de11d3

File tree

6 files changed

+368
-0
lines changed

6 files changed

+368
-0
lines changed

python/pyspark/sql/functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,20 @@ def array_remove(col, element):
19991999
return Column(sc._jvm.functions.array_remove(_to_java_column(col), element))
20002000

20012001

2002+
@since(2.4)
2003+
def array_distinct(col):
2004+
"""
2005+
Collection function: removes duplicate values from the array.
2006+
:param col: name of column or expression
2007+
2008+
>>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
2009+
>>> df.select(array_distinct(df.data)).collect()
2010+
[Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
2011+
"""
2012+
sc = SparkContext._active_spark_context
2013+
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
2014+
2015+
20022016
@since(1.4)
20032017
def explode(col):
20042018
"""Returns a new row for each element in the given array or map.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ object FunctionRegistry {
433433
expression[Flatten]("flatten"),
434434
expression[ArrayRepeat]("array_repeat"),
435435
expression[ArrayRemove]("array_remove"),
436+
expression[ArrayDistinct]("array_distinct"),
436437
CreateStruct.registryEntry,
437438

438439
// mask functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types._
3131
import org.apache.spark.unsafe.Platform
3232
import org.apache.spark.unsafe.array.ByteArrayMethods
3333
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
34+
import org.apache.spark.util.collection.OpenHashSet
3435

3536
/**
3637
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
@@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression)
23552356

23562357
override def prettyName: String = "array_remove"
23572358
}
2359+
2360+
/**
2361+
* Removes duplicate values from the array.
2362+
*/
2363+
@ExpressionDescription(
2364+
usage = "_FUNC_(array) - Removes duplicate values from the array.",
2365+
examples = """
2366+
Examples:
2367+
> SELECT _FUNC_(array(1, 2, 3, null, 3));
2368+
[1,2,3,null]
2369+
""", since = "2.4.0")
2370+
case class ArrayDistinct(child: Expression)
2371+
extends UnaryExpression with ExpectsInputTypes {
2372+
2373+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
2374+
2375+
override def dataType: DataType = child.dataType
2376+
2377+
@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
2378+
2379+
@transient private lazy val ordering: Ordering[Any] =
2380+
TypeUtils.getInterpretedOrdering(elementType)
2381+
2382+
override def checkInputDataTypes(): TypeCheckResult = {
2383+
super.checkInputDataTypes() match {
2384+
case f: TypeCheckResult.TypeCheckFailure => f
2385+
case TypeCheckResult.TypeCheckSuccess =>
2386+
TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
2387+
}
2388+
}
2389+
2390+
@transient private lazy val elementTypeSupportEquals = elementType match {
2391+
case BinaryType => false
2392+
case _: AtomicType => true
2393+
case _ => false
2394+
}
2395+
2396+
override def nullSafeEval(array: Any): Any = {
2397+
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
2398+
if (elementTypeSupportEquals) {
2399+
new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
2400+
} else {
2401+
var foundNullElement = false
2402+
var pos = 0
2403+
for (i <- 0 until data.length) {
2404+
if (data(i) == null) {
2405+
if (!foundNullElement) {
2406+
foundNullElement = true
2407+
pos = pos + 1
2408+
}
2409+
} else {
2410+
var j = 0
2411+
var done = false
2412+
while (j <= i && !done) {
2413+
if (data(j) != null && ordering.equiv(data(j), data(i))) {
2414+
done = true
2415+
}
2416+
j = j + 1
2417+
}
2418+
if (i == j - 1) {
2419+
pos = pos + 1
2420+
}
2421+
}
2422+
}
2423+
new GenericArrayData(data.slice(0, pos))
2424+
}
2425+
}
2426+
2427+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
2428+
nullSafeCodeGen(ctx, ev, (array) => {
2429+
val i = ctx.freshName("i")
2430+
val j = ctx.freshName("j")
2431+
val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
2432+
val getValue1 = CodeGenerator.getValue(array, elementType, i)
2433+
val getValue2 = CodeGenerator.getValue(array, elementType, j)
2434+
val foundNullElement = ctx.freshName("foundNullElement")
2435+
val openHashSet = classOf[OpenHashSet[_]].getName
2436+
val hs = ctx.freshName("hs")
2437+
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
2438+
if (elementTypeSupportEquals) {
2439+
s"""
2440+
|int $sizeOfDistinctArray = 0;
2441+
|boolean $foundNullElement = false;
2442+
|$openHashSet $hs = new $openHashSet($classTag);
2443+
|for (int $i = 0; $i < $array.numElements(); $i ++) {
2444+
| if ($array.isNullAt($i)) {
2445+
| $foundNullElement = true;
2446+
| } else {
2447+
| $hs.add($getValue1);
2448+
| }
2449+
|}
2450+
|$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
2451+
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
2452+
""".stripMargin
2453+
} else {
2454+
s"""
2455+
|int $sizeOfDistinctArray = 0;
2456+
|boolean $foundNullElement = false;
2457+
|for (int $i = 0; $i < $array.numElements(); $i ++) {
2458+
| if ($array.isNullAt($i)) {
2459+
| if (!($foundNullElement)) {
2460+
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
2461+
| $foundNullElement = true;
2462+
| }
2463+
| } else {
2464+
| int $j;
2465+
| for ($j = 0; $j < $i; $j ++) {
2466+
| if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) {
2467+
| break;
2468+
| }
2469+
| }
2470+
| if ($i == $j) {
2471+
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
2472+
| }
2473+
| }
2474+
|}
2475+
|
2476+
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
2477+
""".stripMargin
2478+
}
2479+
})
2480+
}
2481+
2482+
private def setNull(
2483+
isPrimitive: Boolean,
2484+
foundNullElement: String,
2485+
distinctArray: String,
2486+
pos: String): String = {
2487+
val setNullValue =
2488+
if (!isPrimitive) {
2489+
s"$distinctArray[$pos] = null";
2490+
} else {
2491+
s"$distinctArray.setNullAt($pos)";
2492+
}
2493+
2494+
s"""
2495+
|if (!($foundNullElement)) {
2496+
| $setNullValue;
2497+
| $pos = $pos + 1;
2498+
| $foundNullElement = true;
2499+
|}
2500+
""".stripMargin
2501+
}
2502+
2503+
private def setNotNullValue(isPrimitive: Boolean,
2504+
distinctArray: String,
2505+
pos: String,
2506+
getValue1: String,
2507+
primitiveValueTypeName: String): String = {
2508+
if (!isPrimitive) {
2509+
s"$distinctArray[$pos] = $getValue1";
2510+
} else {
2511+
s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)";
2512+
}
2513+
}
2514+
2515+
private def setValueForFastEval(
2516+
isPrimitive: Boolean,
2517+
hs: String,
2518+
distinctArray: String,
2519+
pos: String,
2520+
getValue1: String,
2521+
primitiveValueTypeName: String): String = {
2522+
val setValue = setNotNullValue(isPrimitive,
2523+
distinctArray, pos, getValue1, primitiveValueTypeName)
2524+
s"""
2525+
|if (!($hs.contains($getValue1))) {
2526+
| $hs.add($getValue1);
2527+
| $setValue;
2528+
| $pos = $pos + 1;
2529+
|}
2530+
""".stripMargin
2531+
}
2532+
2533+
private def setValueForBruteForceEval(
2534+
isPrimitive: Boolean,
2535+
i: String,
2536+
j: String,
2537+
inputArray: String,
2538+
distinctArray: String,
2539+
pos: String,
2540+
getValue1: String,
2541+
isEqual: String,
2542+
primitiveValueTypeName: String): String = {
2543+
val setValue = setNotNullValue(isPrimitive,
2544+
distinctArray, pos, getValue1, primitiveValueTypeName)
2545+
s"""
2546+
|int $j;
2547+
|for ($j = 0; $j < $i; $j ++) {
2548+
| if (!$inputArray.isNullAt($j) && $isEqual) {
2549+
| break;
2550+
| }
2551+
|}
2552+
|if ($i == $j) {
2553+
| $setValue;
2554+
| $pos = $pos + 1;
2555+
|}
2556+
""".stripMargin
2557+
}
2558+
2559+
def genCodeForResult(
2560+
ctx: CodegenContext,
2561+
ev: ExprCode,
2562+
inputArray: String,
2563+
size: String): String = {
2564+
val distinctArray = ctx.freshName("distinctArray")
2565+
val i = ctx.freshName("i")
2566+
val j = ctx.freshName("j")
2567+
val pos = ctx.freshName("pos")
2568+
val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
2569+
val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
2570+
val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
2571+
val foundNullElement = ctx.freshName("foundNullElement")
2572+
val hs = ctx.freshName("hs")
2573+
val openHashSet = classOf[OpenHashSet[_]].getName
2574+
if (!CodeGenerator.isPrimitiveType(elementType)) {
2575+
val arrayClass = classOf[GenericArrayData].getName
2576+
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
2577+
val setNullForNonPrimitive =
2578+
setNull(false, foundNullElement, distinctArray, pos)
2579+
if (elementTypeSupportEquals) {
2580+
val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "")
2581+
s"""
2582+
|int $pos = 0;
2583+
|Object[] $distinctArray = new Object[$size];
2584+
|boolean $foundNullElement = false;
2585+
|$openHashSet $hs = new $openHashSet($classTag);
2586+
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
2587+
| if ($inputArray.isNullAt($i)) {
2588+
| $setNullForNonPrimitive;
2589+
| } else {
2590+
| $setValueForFast;
2591+
| }
2592+
|}
2593+
|${ev.value} = new $arrayClass($distinctArray);
2594+
""".stripMargin
2595+
} else {
2596+
val setValueForBruteForce = setValueForBruteForceEval(
2597+
false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
2598+
s"""
2599+
|int $pos = 0;
2600+
|Object[] $distinctArray = new Object[$size];
2601+
|boolean $foundNullElement = false;
2602+
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
2603+
| if ($inputArray.isNullAt($i)) {
2604+
| $setNullForNonPrimitive;
2605+
| } else {
2606+
| $setValueForBruteForce;
2607+
| }
2608+
|}
2609+
|${ev.value} = new $arrayClass($distinctArray);
2610+
""".stripMargin
2611+
}
2612+
} else {
2613+
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
2614+
val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos)
2615+
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
2616+
val setValueForFast =
2617+
setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName)
2618+
s"""
2619+
|${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
2620+
|int $pos = 0;
2621+
|boolean $foundNullElement = false;
2622+
|$openHashSet $hs = new $openHashSet($classTag);
2623+
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
2624+
| if ($inputArray.isNullAt($i)) {
2625+
| $setNullForPrimitive;
2626+
| } else {
2627+
| $setValueForFast;
2628+
| }
2629+
|}
2630+
|${ev.value} = $distinctArray;
2631+
""".stripMargin
2632+
}
2633+
}
2634+
2635+
override def prettyName: String = "array_distinct"
2636+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
766766
checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
767767
checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
768768
}
769+
770+
test("Array Distinct") {
771+
val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType))
772+
val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
773+
val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
774+
val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType))
775+
val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType))
776+
val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
777+
val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234),
778+
ArrayType(DoubleType))
779+
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
780+
ArrayType(FloatType))
781+
782+
checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
783+
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
784+
checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c"))
785+
checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a"))
786+
checkEvaluation(new ArrayDistinct(a4), Seq(null))
787+
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
788+
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
789+
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
790+
791+
// complex data types
792+
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
793+
Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
794+
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
795+
ArrayType(BinaryType))
796+
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2),
797+
null, Array[Byte](5, 6), null), ArrayType(BinaryType))
798+
799+
checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)))
800+
checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null))
801+
checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null,
802+
Array[Byte](1, 2)))
803+
804+
val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2),
805+
Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType)))
806+
val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
807+
ArrayType(ArrayType(IntegerType)))
808+
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
809+
ArrayType(ArrayType(IntegerType)))
810+
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
811+
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
812+
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
813+
}
769814
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3189,6 +3189,13 @@ object functions {
31893189
ArrayRemove(column.expr, Literal(element))
31903190
}
31913191

3192+
/**
3193+
* Removes duplicate values from the array.
3194+
* @group collection_funcs
3195+
* @since 2.4.0
3196+
*/
3197+
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
3198+
31923199
/**
31933200
* Creates a new row for each element in the given array or map column.
31943201
*

0 commit comments

Comments
 (0)