Skip to content

Commit 4446a0b

Browse files
kiszk authored and ueshin committed
[SPARK-23914][SQL][FOLLOW-UP] refactor ArrayUnion
## What changes were proposed in this pull request?

This PR refactors `ArrayUnion` based on [this suggestion](apache#21103 (comment)).

1. Generate optimized code for all of the primitive types except `boolean`
2. Generate code using `ArrayBuilder` or `ArrayBuffer`
3. Leave only a generic path in the interpreted path

## How was this patch tested?

Existing tests

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#21937 from kiszk/SPARK-23914-follow.
1 parent 51bee7a commit 4446a0b

File tree

3 files changed

+153
-217
lines changed

3 files changed

+153
-217
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 126 additions & 199 deletions
Original file line number | Diff line number | Diff line change
@@ -3767,230 +3767,159 @@ object ArraySetLike {
37673767
""",
37683768
since = "2.4.0")
37693769
case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
3770-
with ComplexTypeMergingExpression {
3771-
var hsInt: OpenHashSet[Int] = _
3772-
var hsLong: OpenHashSet[Long] = _
3773-
3774-
def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
3775-
val elem = array.getInt(idx)
3776-
if (!hsInt.contains(elem)) {
3777-
if (resultArray != null) {
3778-
resultArray.setInt(pos, elem)
3779-
}
3780-
hsInt.add(elem)
3781-
true
3782-
} else {
3783-
false
3784-
}
3785-
}
3786-
3787-
def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
3788-
val elem = array.getLong(idx)
3789-
if (!hsLong.contains(elem)) {
3790-
if (resultArray != null) {
3791-
resultArray.setLong(pos, elem)
3792-
}
3793-
hsLong.add(elem)
3794-
true
3795-
} else {
3796-
false
3797-
}
3798-
}
3770+
with ComplexTypeMergingExpression {
37993771

3800-
def evalIntLongPrimitiveType(
3801-
array1: ArrayData,
3802-
array2: ArrayData,
3803-
resultArray: ArrayData,
3804-
isLongType: Boolean): Int = {
3805-
// store elements into resultArray
3806-
var nullElementSize = 0
3807-
var pos = 0
3808-
Seq(array1, array2).foreach { array =>
3809-
var i = 0
3810-
while (i < array.numElements()) {
3811-
val size = if (!isLongType) hsInt.size else hsLong.size
3812-
if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3813-
ArraySetLike.throwUnionLengthOverflowException(size)
3814-
}
3815-
if (array.isNullAt(i)) {
3816-
if (nullElementSize == 0) {
3817-
if (resultArray != null) {
3818-
resultArray.setNullAt(pos)
3772+
@transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
3773+
if (elementTypeSupportEquals) {
3774+
(array1, array2) =>
3775+
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
3776+
val hs = new OpenHashSet[Any]
3777+
var foundNullElement = false
3778+
Seq(array1, array2).foreach { array =>
3779+
var i = 0
3780+
while (i < array.numElements()) {
3781+
if (array.isNullAt(i)) {
3782+
if (!foundNullElement) {
3783+
arrayBuffer += null
3784+
foundNullElement = true
3785+
}
3786+
} else {
3787+
val elem = array.get(i, elementType)
3788+
if (!hs.contains(elem)) {
3789+
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3790+
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
3791+
}
3792+
arrayBuffer += elem
3793+
hs.add(elem)
3794+
}
38193795
}
3820-
pos += 1
3821-
nullElementSize = 1
3796+
i += 1
38223797
}
3823-
} else {
3824-
val assigned = if (!isLongType) {
3825-
assignInt(array, i, resultArray, pos)
3798+
}
3799+
new GenericArrayData(arrayBuffer)
3800+
} else {
3801+
(array1, array2) =>
3802+
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
3803+
var alreadyIncludeNull = false
3804+
Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
3805+
var found = false
3806+
if (elem == null) {
3807+
if (alreadyIncludeNull) {
3808+
found = true
3809+
} else {
3810+
alreadyIncludeNull = true
3811+
}
38263812
} else {
3827-
assignLong(array, i, resultArray, pos)
3813+
// check elem is already stored in arrayBuffer or not?
3814+
var j = 0
3815+
while (!found && j < arrayBuffer.size) {
3816+
val va = arrayBuffer(j)
3817+
if (va != null && ordering.equiv(va, elem)) {
3818+
found = true
3819+
}
3820+
j = j + 1
3821+
}
38283822
}
3829-
if (assigned) {
3830-
pos += 1
3823+
if (!found) {
3824+
if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3825+
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
3826+
}
3827+
arrayBuffer += elem
38313828
}
3832-
}
3833-
i += 1
3834-
}
3829+
}))
3830+
new GenericArrayData(arrayBuffer)
38353831
}
3836-
pos
38373832
}
38383833

38393834
override def nullSafeEval(input1: Any, input2: Any): Any = {
38403835
val array1 = input1.asInstanceOf[ArrayData]
38413836
val array2 = input2.asInstanceOf[ArrayData]
38423837

3843-
if (elementTypeSupportEquals) {
3844-
elementType match {
3845-
case IntegerType =>
3846-
// avoid boxing of primitive int array elements
3847-
// calculate result array size
3848-
hsInt = new OpenHashSet[Int]
3849-
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
3850-
hsInt = new OpenHashSet[Int]
3851-
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
3852-
IntegerType.defaultSize, elements)) {
3853-
new GenericArrayData(new Array[Any](elements))
3854-
} else {
3855-
UnsafeArrayData.forPrimitiveArray(
3856-
Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
3857-
}
3858-
evalIntLongPrimitiveType(array1, array2, resultArray, false)
3859-
resultArray
3860-
case LongType =>
3861-
// avoid boxing of primitive long array elements
3862-
// calculate result array size
3863-
hsLong = new OpenHashSet[Long]
3864-
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
3865-
hsLong = new OpenHashSet[Long]
3866-
val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
3867-
LongType.defaultSize, elements)) {
3868-
new GenericArrayData(new Array[Any](elements))
3869-
} else {
3870-
UnsafeArrayData.forPrimitiveArray(
3871-
Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
3872-
}
3873-
evalIntLongPrimitiveType(array1, array2, resultArray, true)
3874-
resultArray
3875-
case _ =>
3876-
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
3877-
val hs = new OpenHashSet[Any]
3878-
var foundNullElement = false
3879-
Seq(array1, array2).foreach { array =>
3880-
var i = 0
3881-
while (i < array.numElements()) {
3882-
if (array.isNullAt(i)) {
3883-
if (!foundNullElement) {
3884-
arrayBuffer += null
3885-
foundNullElement = true
3886-
}
3887-
} else {
3888-
val elem = array.get(i, elementType)
3889-
if (!hs.contains(elem)) {
3890-
if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
3891-
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
3892-
}
3893-
arrayBuffer += elem
3894-
hs.add(elem)
3895-
}
3896-
}
3897-
i += 1
3898-
}
3899-
}
3900-
new GenericArrayData(arrayBuffer)
3901-
}
3902-
} else {
3903-
ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
3904-
}
3838+
evalUnion(array1, array2)
39053839
}
39063840

39073841
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
39083842
val i = ctx.freshName("i")
3909-
val pos = ctx.freshName("pos")
39103843
val value = ctx.freshName("value")
39113844
val size = ctx.freshName("size")
3912-
val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
3913-
if (elementTypeSupportEquals) {
3914-
elementType match {
3915-
case ByteType | ShortType | IntegerType | LongType =>
3916-
val ptName = CodeGenerator.primitiveTypeName(elementType)
3917-
val unsafeArray = ctx.freshName("unsafeArray")
3918-
(if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
3919-
if (elementType == LongType) "Long" else "Int",
3920-
s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
3921-
if (elementType == LongType) "(long)" else "(int)",
3922-
s"""
3923-
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
3924-
|${ev.value} = $unsafeArray;
3925-
""".stripMargin)
3926-
case _ =>
3927-
val genericArrayData = classOf[GenericArrayData].getName
3928-
val et = ctx.addReferenceObj("elementType", elementType)
3929-
("", "Object",
3930-
s"get($i, $et)", s"update($pos, $value)", "Object", "",
3931-
s"${ev.value} = new $genericArrayData(new Object[$size]);")
3932-
}
3933-
} else {
3934-
("", "", "", "", "", "", "")
3935-
}
3845+
if (canUseSpecializedHashSet) {
3846+
val jt = CodeGenerator.javaType(elementType)
3847+
val ptName = CodeGenerator.primitiveTypeName(jt)
39363848

3937-
nullSafeCodeGen(ctx, ev, (array1, array2) => {
3938-
if (openHashElementType != "") {
3939-
// Here, we ensure elementTypeSupportEquals is true
3849+
nullSafeCodeGen(ctx, ev, (array1, array2) => {
39403850
val foundNullElement = ctx.freshName("foundNullElement")
3941-
val openHashSet = classOf[OpenHashSet[_]].getName
3942-
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
3943-
val hs = ctx.freshName("hs")
3944-
val arrayData = classOf[ArrayData].getName
3945-
val arrays = ctx.freshName("arrays")
3851+
val nullElementIndex = ctx.freshName("nullElementIndex")
3852+
val builder = ctx.freshName("builder")
39463853
val array = ctx.freshName("array")
3854+
val arrays = ctx.freshName("arrays")
39473855
val arrayDataIdx = ctx.freshName("arrayDataIdx")
3856+
val openHashSet = classOf[OpenHashSet[_]].getName
3857+
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
3858+
val hashSet = ctx.freshName("hashSet")
3859+
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
3860+
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
3861+
3862+
def withArrayNullAssignment(body: String) =
3863+
if (dataType.asInstanceOf[ArrayType].containsNull) {
3864+
s"""
3865+
|if ($array.isNullAt($i)) {
3866+
| if (!$foundNullElement) {
3867+
| $nullElementIndex = $size;
3868+
| $foundNullElement = true;
3869+
| $size++;
3870+
| $builder.$$plus$$eq($nullValueHolder);
3871+
| }
3872+
|} else {
3873+
| $body
3874+
|}
3875+
""".stripMargin
3876+
} else {
3877+
body
3878+
}
3879+
3880+
val processArray = withArrayNullAssignment(
3881+
s"""
3882+
|$jt $value = ${genGetValue(array, i)};
3883+
|if (!$hashSet.contains($hsValueCast$value)) {
3884+
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
3885+
| break;
3886+
| }
3887+
| $hashSet.add$hsPostFix($hsValueCast$value);
3888+
| $builder.$$plus$$eq($value);
3889+
|}
3890+
""".stripMargin)
3891+
3892+
// Only need to track null element index when result array's element is nullable.
3893+
val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
3894+
s"""
3895+
|boolean $foundNullElement = false;
3896+
|int $nullElementIndex = -1;
3897+
""".stripMargin
3898+
} else {
3899+
""
3900+
}
3901+
39483902
s"""
3949-
|$openHashSet $hs = new $openHashSet$postFix($classTag);
3950-
|boolean $foundNullElement = false;
3951-
|$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
3952-
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
3953-
| $arrayData $array = $arrays[$arrayDataIdx];
3954-
| for (int $i = 0; $i < $array.numElements(); $i++) {
3955-
| if ($array.isNullAt($i)) {
3956-
| $foundNullElement = true;
3957-
| } else {
3958-
| $hs.add$postFix($array.$getter);
3959-
| }
3960-
| }
3961-
|}
3962-
|int $size = $hs.size() + ($foundNullElement ? 1 : 0);
3963-
|$arrayBuilder
3964-
|$hs = new $openHashSet$postFix($classTag);
3965-
|$foundNullElement = false;
3966-
|int $pos = 0;
3903+
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
3904+
|$declareNullTrackVariables
3905+
|int $size = 0;
3906+
|$arrayBuilderClass $builder = new $arrayBuilderClass();
3907+
|ArrayData[] $arrays = new ArrayData[]{$array1, $array2};
39673908
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
3968-
| $arrayData $array = $arrays[$arrayDataIdx];
3909+
| ArrayData $array = $arrays[$arrayDataIdx];
39693910
| for (int $i = 0; $i < $array.numElements(); $i++) {
3970-
| if ($array.isNullAt($i)) {
3971-
| if (!$foundNullElement) {
3972-
| ${ev.value}.setNullAt($pos++);
3973-
| $foundNullElement = true;
3974-
| }
3975-
| } else {
3976-
| $javaTypeName $value = $array.$getter;
3977-
| if (!$hs.contains($castOp $value)) {
3978-
| $hs.add$postFix($value);
3979-
| ${ev.value}.$setter;
3980-
| $pos++;
3981-
| }
3982-
| }
3911+
| $processArray
39833912
| }
39843913
|}
3914+
|${buildResultArray(builder, ev.value, size, nullElementIndex)}
39853915
""".stripMargin
3986-
} else {
3987-
val arrayUnion = classOf[ArrayUnion].getName
3988-
val et = ctx.addReferenceObj("elementTypeUnion", elementType)
3989-
val order = ctx.addReferenceObj("orderingUnion", ordering)
3990-
val method = "unionOrdering"
3991-
s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
3992-
}
3993-
})
3916+
})
3917+
} else {
3918+
nullSafeCodeGen(ctx, ev, (array1, array2) => {
3919+
val expr = ctx.addReferenceObj("arrayUnionExpr", this)
3920+
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
3921+
})
3922+
}
39943923
}
39953924

39963925
override def prettyName: String = "array_union"
@@ -4154,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
41544083
}
41554084

41564085
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
4157-
val arrayData = classOf[ArrayData].getName
41584086
val i = ctx.freshName("i")
41594087
val value = ctx.freshName("value")
41604088
val size = ctx.freshName("size")
@@ -4268,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
42684196
} else {
42694197
nullSafeCodeGen(ctx, ev, (array1, array2) => {
42704198
val expr = ctx.addReferenceObj("arrayIntersectExpr", this)
4271-
s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
4199+
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
42724200
})
42734201
}
42744202
}
@@ -4387,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
43874315
}
43884316

43894317
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
4390-
val arrayData = classOf[ArrayData].getName
43914318
val i = ctx.freshName("i")
43924319
val value = ctx.freshName("value")
43934320
val size = ctx.freshName("size")
@@ -4490,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
44904417
} else {
44914418
nullSafeCodeGen(ctx, ev, (array1, array2) => {
44924419
val expr = ctx.addReferenceObj("arrayExceptExpr", this)
4493-
s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
4420+
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
44944421
})
44954422
}
44964423
}

0 commit comments

Comments
 (0)