Skip to content

Commit 684c719

Browse files
cloud-fanueshin
authored andcommitted
[SPARK-23915][SQL][FOLLOWUP] Add array_except function
## What changes were proposed in this pull request? simplify the codegen: 1. only do real codegen if the type can be specialized by the hash set 2. change the null handling. Before: track the nullElementIndex, and create a new ArrayData to insert the null in the middle. After: track the nullElementIndex, put a null placeholder in the ArrayBuilder, at the end create ArrayData from ArrayBuilder directly. ## How was this patch tested? existing tests. Author: Wenchen Fan <[email protected]> Closes apache#21966 from cloud-fan/minor2.
1 parent 0ecc132 commit 684c719

File tree

1 file changed

+98
-107
lines changed

1 file changed

+98
-107
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 98 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3980,7 +3980,8 @@ object ArrayUnion {
39803980
""",
39813981
since = "2.4.0")
39823982
case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
3983-
with ComplexTypeMergingExpression {
3983+
with ComplexTypeMergingExpression {
3984+
39843985
override def dataType: DataType = {
39853986
dataTypeCheck
39863987
left.dataType
@@ -4077,81 +4078,80 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
40774078
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
40784079
val arrayData = classOf[ArrayData].getName
40794080
val i = ctx.freshName("i")
4080-
val pos = ctx.freshName("pos")
40814081
val value = ctx.freshName("value")
4082-
val hsValue = ctx.freshName("hsValue")
40834082
val size = ctx.freshName("size")
4084-
if (elementTypeSupportEquals) {
4085-
val ptName = CodeGenerator.primitiveTypeName(elementType)
4086-
val unsafeArray = ctx.freshName("unsafeArray")
4087-
val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
4088-
getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
4089-
elementType match {
4090-
case ByteType | ShortType | IntegerType =>
4091-
("$mcI$sp", "Int", "int", s"(int) $value",
4092-
s"get$ptName($i)", s"set$ptName($pos, $value)",
4093-
CodeGenerator.javaType(elementType), ptName,
4094-
s"""
4095-
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
4096-
|${ev.value} = $unsafeArray;
4097-
""".stripMargin)
4098-
case LongType | FloatType | DoubleType =>
4099-
val signature = elementType match {
4100-
case LongType => "$mcJ$sp"
4101-
case FloatType => "$mcF$sp"
4102-
case DoubleType => "$mcD$sp"
4103-
}
4104-
(signature, CodeGenerator.boxedType(elementType),
4105-
CodeGenerator.javaType(elementType), value,
4106-
s"get$ptName($i)", s"set$ptName($pos, $value)",
4107-
CodeGenerator.javaType(elementType), ptName,
4108-
s"""
4109-
|${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
4110-
|${ev.value} = $unsafeArray;
4111-
""".stripMargin)
4112-
case _ =>
4113-
val genericArrayData = classOf[GenericArrayData].getName
4114-
val et = ctx.addReferenceObj("elementType", elementType)
4115-
("", "Object", "Object", value,
4116-
s"get($i, $et)", s"update($pos, $value)", "Object", "Ref",
4117-
s"${ev.value} = new $genericArrayData(new Object[$size]);")
4118-
}
4083+
val canUseSpecializedHashSet = elementType match {
4084+
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
4085+
case _ => false
4086+
}
4087+
if (canUseSpecializedHashSet) {
4088+
val jt = CodeGenerator.javaType(elementType)
4089+
val ptName = CodeGenerator.primitiveTypeName(jt)
4090+
4091+
def genGetValue(array: String): String =
4092+
CodeGenerator.getValue(array, elementType, i)
4093+
4094+
val (hsPostFix, hsTypeName) = elementType match {
4095+
// we cast byte/short to int when writing to the hash set.
4096+
case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
4097+
case LongType => ("$mcJ$sp", ptName)
4098+
case FloatType => ("$mcF$sp", ptName)
4099+
case DoubleType => ("$mcD$sp", ptName)
4100+
}
4101+
4102+
// we cast byte/short to int when writing to the hash set.
4103+
val hsValueCast = elementType match {
4104+
case ByteType | ShortType => "(int) "
4105+
case _ => ""
4106+
}
41194107

41204108
nullSafeCodeGen(ctx, ev, (array1, array2) => {
41214109
val notFoundNullElement = ctx.freshName("notFoundNullElement")
41224110
val nullElementIndex = ctx.freshName("nullElementIndex")
41234111
val builder = ctx.freshName("builder")
4124-
val array = ctx.freshName("array")
41254112
val openHashSet = classOf[OpenHashSet[_]].getName
4126-
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
4127-
val hs = ctx.freshName("hs")
4113+
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
4114+
val hashSet = ctx.freshName("hashSet")
41284115
val genericArrayData = classOf[GenericArrayData].getName
41294116
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
4130-
val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName"
4131-
val arrayBuilderClassTag = if (primitiveTypeName != "Ref") {
4132-
s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()"
4133-
} else {
4134-
s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()"
4135-
}
4117+
val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
4118+
val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
41364119

4137-
def withArray2NullCheck(body: String) =
4120+
def withArray2NullCheck(body: String): String =
41384121
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
4139-
s"""
4140-
|if ($array2.isNullAt($i)) {
4141-
| $notFoundNullElement = false;
4142-
|} else {
4143-
| $body
4144-
|}
4122+
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
4123+
s"""
4124+
|if ($array2.isNullAt($i)) {
4125+
| $notFoundNullElement = false;
4126+
|} else {
4127+
| $body
4128+
|}
41454129
""".stripMargin
4130+
} else {
4131+
// if array1's element is not nullable, we don't need to track the null element index.
4132+
s"""
4133+
|if (!$array2.isNullAt($i)) {
4134+
| $body
4135+
|}
4136+
""".stripMargin
4137+
}
41464138
} else {
41474139
body
41484140
}
4149-
val array2Body =
4141+
4142+
val writeArray2ToHashSet = withArray2NullCheck(
41504143
s"""
4151-
|$javaTypeName $value = $array2.$getter;
4152-
|$hsJavaTypeName $hsValue = $genHsValue;
4153-
|$hs.add$postFix($hsValue);
4154-
""".stripMargin
4144+
|$jt $value = ${genGetValue(array2)};
4145+
|$hashSet.add$hsPostFix($hsValueCast$value);
4146+
""".stripMargin)
4147+
4148+
// When hitting a null value, put a null holder in the ArrayBuilder. Finally we will
4149+
// convert ArrayBuilder to ArrayData and setNull on the slot with null holder.
4150+
val nullValueHolder = elementType match {
4151+
case ByteType => "(byte) 0"
4152+
case ShortType => "(short) 0"
4153+
case _ => "0"
4154+
}
41554155

41564156
def withArray1NullAssignment(body: String) =
41574157
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
@@ -4161,6 +4161,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
41614161
| $nullElementIndex = $size;
41624162
| $notFoundNullElement = false;
41634163
| $size++;
4164+
| $builder.$$plus$$eq($nullValueHolder);
41644165
| }
41654166
|} else {
41664167
| $body
@@ -4169,81 +4170,71 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
41694170
} else {
41704171
body
41714172
}
4172-
val array1Body =
4173+
4174+
val processArray1 = withArray1NullAssignment(
41734175
s"""
4174-
|$javaTypeName $value = $array1.$getter;
4175-
|$hsJavaTypeName $hsValue = $genHsValue;
4176-
|if (!$hs.contains($hsValue)) {
4176+
|$jt $value = ${genGetValue(array1)};
4177+
|if (!$hashSet.contains($hsValueCast$value)) {
41774178
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
41784179
| break;
41794180
| }
4180-
| $hs.add$postFix($hsValue);
4181+
| $hashSet.add$hsPostFix($hsValueCast$value);
41814182
| $builder.$$plus$$eq($value);
41824183
|}
4183-
""".stripMargin
4184+
""".stripMargin)
41844185

4185-
val nonNullArrayDataBuild = {
4186-
val build = if (postFix != "") {
4187-
val defaultSize = elementType.defaultSize
4186+
def withResultArrayNullCheck(body: String): String = {
4187+
if (dataType.asInstanceOf[ArrayType].containsNull) {
41884188
s"""
4189-
|if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) {
4190-
| ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result());
4191-
|} else {
4192-
| ${ev.value} = new $genericArrayData($builder.result());
4189+
|$body
4190+
|if ($nullElementIndex >= 0) {
4191+
| // result has null element
4192+
| ${ev.value}.setNullAt($nullElementIndex);
41934193
|}
41944194
""".stripMargin
41954195
} else {
4196-
s"${ev.value} = new $genericArrayData($builder.result());"
4196+
body
41974197
}
4198+
}
4199+
4200+
val buildResultArray = withResultArrayNullCheck(
41984201
s"""
41994202
|if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
4200-
| throw new RuntimeException("Unsuccessful try create array with " + $size +
4203+
| throw new RuntimeException("Cannot create array with " + $size +
42014204
| " bytes of data due to exceeding the limit " +
4202-
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." +
4203-
| " $prettyName failed.");
4205+
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData.");
42044206
|}
4205-
|$build
4207+
|
4208+
|if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) {
4209+
| ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result());
4210+
|} else {
4211+
| ${ev.value} = new $genericArrayData($builder.result());
4212+
|}
4213+
""".stripMargin)
4214+
4215+
// Only need to track null element index when array1's element is nullable.
4216+
val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
4217+
s"""
4218+
|boolean $notFoundNullElement = true;
4219+
|int $nullElementIndex = -1;
42064220
""".stripMargin
4221+
} else {
4222+
""
42074223
}
42084224

4209-
def buildResultArrayData(nonNullArrayDataBuild: String) =
4210-
if (dataType.asInstanceOf[ArrayType].containsNull) {
4211-
s"""
4212-
|if ($nullElementIndex < 0) {
4213-
| // result has no null element
4214-
| $nonNullArrayDataBuild
4215-
|} else {
4216-
| // result has null element
4217-
| $arrayDataBuilder
4218-
| $javaTypeName[] $array = $builder.result();
4219-
| for (int $i = 0, $pos = 0; $pos < $size; $pos++) {
4220-
| if ($pos == $nullElementIndex) {
4221-
| ${ev.value}.setNullAt($pos);
4222-
| } else {
4223-
| $javaTypeName $value = $array[$i++];
4224-
| ${ev.value}.$setter;
4225-
| }
4226-
| }
4227-
|}
4228-
""".stripMargin
4229-
} else {
4230-
nonNullArrayDataBuild
4231-
}
4232-
42334225
s"""
4234-
|$openHashSet $hs = new $openHashSet$postFix($classTag);
4235-
|boolean $notFoundNullElement = true;
4226+
|$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
4227+
|$declareNullTrackVariables
42364228
|for (int $i = 0; $i < $array2.numElements(); $i++) {
4237-
| ${withArray2NullCheck(array2Body)}
4229+
| $writeArray2ToHashSet
42384230
|}
42394231
|$arrayBuilderClass $builder =
42404232
| ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
4241-
|int $nullElementIndex = -1;
42424233
|int $size = 0;
42434234
|for (int $i = 0; $i < $array1.numElements(); $i++) {
4244-
| ${withArray1NullAssignment(array1Body)}
4235+
| $processArray1
42454236
|}
4246-
|${buildResultArrayData(nonNullArrayDataBuild)}
4237+
|$buildResultArray
42474238
""".stripMargin
42484239
})
42494240
} else {

0 commit comments

Comments
 (0)