@@ -3980,7 +3980,8 @@ object ArrayUnion {
3980
3980
""" ,
3981
3981
since = " 2.4.0" )
3982
3982
case class ArrayExcept (left : Expression , right : Expression ) extends ArraySetLike
3983
- with ComplexTypeMergingExpression {
3983
+ with ComplexTypeMergingExpression {
3984
+
3984
3985
override def dataType : DataType = {
3985
3986
dataTypeCheck
3986
3987
left.dataType
@@ -4077,81 +4078,80 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
4077
4078
override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
4078
4079
val arrayData = classOf [ArrayData ].getName
4079
4080
val i = ctx.freshName(" i" )
4080
- val pos = ctx.freshName(" pos" )
4081
4081
val value = ctx.freshName(" value" )
4082
- val hsValue = ctx.freshName(" hsValue" )
4083
4082
val size = ctx.freshName(" size" )
4084
- if (elementTypeSupportEquals) {
4085
- val ptName = CodeGenerator .primitiveTypeName(elementType)
4086
- val unsafeArray = ctx.freshName(" unsafeArray" )
4087
- val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
4088
- getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
4089
- elementType match {
4090
- case ByteType | ShortType | IntegerType =>
4091
- (" $mcI$sp" , " Int" , " int" , s " (int) $value" ,
4092
- s " get $ptName( $i) " , s " set $ptName( $pos, $value) " ,
4093
- CodeGenerator .javaType(elementType), ptName,
4094
- s """
4095
- | ${ctx.createUnsafeArray(unsafeArray, size, elementType, s " $prettyName failed. " )}
4096
- | ${ev.value} = $unsafeArray;
4097
- """ .stripMargin)
4098
- case LongType | FloatType | DoubleType =>
4099
- val signature = elementType match {
4100
- case LongType => " $mcJ$sp"
4101
- case FloatType => " $mcF$sp"
4102
- case DoubleType => " $mcD$sp"
4103
- }
4104
- (signature, CodeGenerator .boxedType(elementType),
4105
- CodeGenerator .javaType(elementType), value,
4106
- s " get $ptName( $i) " , s " set $ptName( $pos, $value) " ,
4107
- CodeGenerator .javaType(elementType), ptName,
4108
- s """
4109
- | ${ctx.createUnsafeArray(unsafeArray, size, elementType, s " $prettyName failed. " )}
4110
- | ${ev.value} = $unsafeArray;
4111
- """ .stripMargin)
4112
- case _ =>
4113
- val genericArrayData = classOf [GenericArrayData ].getName
4114
- val et = ctx.addReferenceObj(" elementType" , elementType)
4115
- (" " , " Object" , " Object" , value,
4116
- s " get( $i, $et) " , s " update( $pos, $value) " , " Object" , " Ref" ,
4117
- s " ${ev.value} = new $genericArrayData(new Object[ $size]); " )
4118
- }
4083
+ val canUseSpecializedHashSet = elementType match {
4084
+ case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
4085
+ case _ => false
4086
+ }
4087
+ if (canUseSpecializedHashSet) {
4088
+ val jt = CodeGenerator .javaType(elementType)
4089
+ val ptName = CodeGenerator .primitiveTypeName(jt)
4090
+
4091
+ def genGetValue (array : String ): String =
4092
+ CodeGenerator .getValue(array, elementType, i)
4093
+
4094
+ val (hsPostFix, hsTypeName) = elementType match {
4095
+ // we cast byte/short to int when writing to the hash set.
4096
+ case ByteType | ShortType | IntegerType => (" $mcI$sp" , " Int" )
4097
+ case LongType => (" $mcJ$sp" , ptName)
4098
+ case FloatType => (" $mcF$sp" , ptName)
4099
+ case DoubleType => (" $mcD$sp" , ptName)
4100
+ }
4101
+
4102
+ // we cast byte/short to int when writing to the hash set.
4103
+ val hsValueCast = elementType match {
4104
+ case ByteType | ShortType => " (int) "
4105
+ case _ => " "
4106
+ }
4119
4107
4120
4108
nullSafeCodeGen(ctx, ev, (array1, array2) => {
4121
4109
val notFoundNullElement = ctx.freshName(" notFoundNullElement" )
4122
4110
val nullElementIndex = ctx.freshName(" nullElementIndex" )
4123
4111
val builder = ctx.freshName(" builder" )
4124
- val array = ctx.freshName(" array" )
4125
4112
val openHashSet = classOf [OpenHashSet [_]].getName
4126
- val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $openHashElementType () "
4127
- val hs = ctx.freshName(" hs " )
4113
+ val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $hsTypeName () "
4114
+ val hashSet = ctx.freshName(" hashSet " )
4128
4115
val genericArrayData = classOf [GenericArrayData ].getName
4129
4116
val arrayBuilder = " scala.collection.mutable.ArrayBuilder"
4130
- val arrayBuilderClass = s " $arrayBuilder$$ of $primitiveTypeName"
4131
- val arrayBuilderClassTag = if (primitiveTypeName != " Ref" ) {
4132
- s " scala.reflect.ClassTag $$ .MODULE $$ . $primitiveTypeName() "
4133
- } else {
4134
- s " scala.reflect.ClassTag $$ .MODULE $$ .AnyRef() "
4135
- }
4117
+ val arrayBuilderClass = s " $arrayBuilder$$ of $ptName"
4118
+ val arrayBuilderClassTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $ptName() "
4136
4119
4137
- def withArray2NullCheck (body : String ) =
4120
+ def withArray2NullCheck (body : String ): String =
4138
4121
if (right.dataType.asInstanceOf [ArrayType ].containsNull) {
4139
- s """
4140
- |if ( $array2.isNullAt( $i)) {
4141
- | $notFoundNullElement = false;
4142
- |} else {
4143
- | $body
4144
- |}
4122
+ if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
4123
+ s """
4124
+ |if ( $array2.isNullAt( $i)) {
4125
+ | $notFoundNullElement = false;
4126
+ |} else {
4127
+ | $body
4128
+ |}
4145
4129
""" .stripMargin
4130
+ } else {
4131
+ // if array1's element is not nullable, we don't need to track the null element index.
4132
+ s """
4133
+ |if (! $array2.isNullAt( $i)) {
4134
+ | $body
4135
+ |}
4136
+ """ .stripMargin
4137
+ }
4146
4138
} else {
4147
4139
body
4148
4140
}
4149
- val array2Body =
4141
+
4142
+ val writeArray2ToHashSet = withArray2NullCheck(
4150
4143
s """
4151
- | $javaTypeName $value = $array2. $getter;
4152
- | $hsJavaTypeName $hsValue = $genHsValue;
4153
- | $hs.add $postFix( $hsValue);
4154
- """ .stripMargin
4144
+ | $jt $value = ${genGetValue(array2)};
4145
+ | $hashSet.add $hsPostFix( $hsValueCast$value);
4146
+ """ .stripMargin)
4147
+
4148
+ // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will
4149
+ // convert ArrayBuilder to ArrayData and setNull on the slot with null holder.
4150
+ val nullValueHolder = elementType match {
4151
+ case ByteType => " (byte) 0"
4152
+ case ShortType => " (short) 0"
4153
+ case _ => " 0"
4154
+ }
4155
4155
4156
4156
def withArray1NullAssignment (body : String ) =
4157
4157
if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
@@ -4161,6 +4161,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
4161
4161
| $nullElementIndex = $size;
4162
4162
| $notFoundNullElement = false;
4163
4163
| $size++;
4164
+ | $builder. $$ plus $$ eq( $nullValueHolder);
4164
4165
| }
4165
4166
|} else {
4166
4167
| $body
@@ -4169,81 +4170,71 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
4169
4170
} else {
4170
4171
body
4171
4172
}
4172
- val array1Body =
4173
+
4174
+ val processArray1 = withArray1NullAssignment(
4173
4175
s """
4174
- | $javaTypeName $value = $array1. $getter;
4175
- | $hsJavaTypeName $hsValue = $genHsValue;
4176
- |if (! $hs.contains( $hsValue)) {
4176
+ | $jt $value = ${genGetValue(array1)};
4177
+ |if (! $hashSet.contains( $hsValueCast$value)) {
4177
4178
| if (++ $size > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
4178
4179
| break;
4179
4180
| }
4180
- | $hs .add $postFix ( $hsValue );
4181
+ | $hashSet .add $hsPostFix ( $hsValueCast$value );
4181
4182
| $builder. $$ plus $$ eq( $value);
4182
4183
|}
4183
- """ .stripMargin
4184
+ """ .stripMargin)
4184
4185
4185
- val nonNullArrayDataBuild = {
4186
- val build = if (postFix != " " ) {
4187
- val defaultSize = elementType.defaultSize
4186
+ def withResultArrayNullCheck (body : String ): String = {
4187
+ if (dataType.asInstanceOf [ArrayType ].containsNull) {
4188
4188
s """
4189
- |if (!UnsafeArrayData.shouldUseGenericArrayData( $defaultSize , $size )) {
4190
- | ${ev.value} = UnsafeArrayData.fromPrimitiveArray( $builder .result());
4191
- |} else {
4192
- | ${ev.value} = new $genericArrayData ( $builder .result() );
4189
+ | $body
4190
+ |if ( $nullElementIndex >= 0) {
4191
+ | // result has null element
4192
+ | ${ev.value}.setNullAt( $nullElementIndex );
4193
4193
|}
4194
4194
""" .stripMargin
4195
4195
} else {
4196
- s " ${ev.value} = new $genericArrayData ( $builder .result()); "
4196
+ body
4197
4197
}
4198
+ }
4199
+
4200
+ val buildResultArray = withResultArrayNullCheck(
4198
4201
s """
4199
4202
|if ( $size > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
4200
- | throw new RuntimeException("Unsuccessful try create array with " + $size +
4203
+ | throw new RuntimeException("Cannot create array with " + $size +
4201
4204
| " bytes of data due to exceeding the limit " +
4202
- | " ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH } elements for GenericArrayData." +
4203
- | " $prettyName failed.");
4205
+ | " ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH } elements for ArrayData.");
4204
4206
|}
4205
- | $build
4207
+ |
4208
+ |if (!UnsafeArrayData.shouldUseGenericArrayData( ${elementType.defaultSize}, $size)) {
4209
+ | ${ev.value} = UnsafeArrayData.fromPrimitiveArray( $builder.result());
4210
+ |} else {
4211
+ | ${ev.value} = new $genericArrayData( $builder.result());
4212
+ |}
4213
+ """ .stripMargin)
4214
+
4215
+ // Only need to track null element index when array1's element is nullable.
4216
+ val declareNullTrackVariables = if (left.dataType.asInstanceOf [ArrayType ].containsNull) {
4217
+ s """
4218
+ |boolean $notFoundNullElement = true;
4219
+ |int $nullElementIndex = -1;
4206
4220
""" .stripMargin
4221
+ } else {
4222
+ " "
4207
4223
}
4208
4224
4209
- def buildResultArrayData (nonNullArrayDataBuild : String ) =
4210
- if (dataType.asInstanceOf [ArrayType ].containsNull) {
4211
- s """
4212
- |if ( $nullElementIndex < 0) {
4213
- | // result has no null element
4214
- | $nonNullArrayDataBuild
4215
- |} else {
4216
- | // result has null element
4217
- | $arrayDataBuilder
4218
- | $javaTypeName[] $array = $builder.result();
4219
- | for (int $i = 0, $pos = 0; $pos < $size; $pos++) {
4220
- | if ( $pos == $nullElementIndex) {
4221
- | ${ev.value}.setNullAt( $pos);
4222
- | } else {
4223
- | $javaTypeName $value = $array[ $i++];
4224
- | ${ev.value}. $setter;
4225
- | }
4226
- | }
4227
- |}
4228
- """ .stripMargin
4229
- } else {
4230
- nonNullArrayDataBuild
4231
- }
4232
-
4233
4225
s """
4234
- | $openHashSet $hs = new $openHashSet$postFix ( $classTag);
4235
- |boolean $notFoundNullElement = true;
4226
+ | $openHashSet $hashSet = new $openHashSet$hsPostFix ( $classTag);
4227
+ | $declareNullTrackVariables
4236
4228
|for (int $i = 0; $i < $array2.numElements(); $i++) {
4237
- | ${withArray2NullCheck(array2Body)}
4229
+ | $writeArray2ToHashSet
4238
4230
|}
4239
4231
| $arrayBuilderClass $builder =
4240
4232
| ( $arrayBuilderClass) $arrayBuilder.make( $arrayBuilderClassTag);
4241
- |int $nullElementIndex = -1;
4242
4233
|int $size = 0;
4243
4234
|for (int $i = 0; $i < $array1.numElements(); $i++) {
4244
- | ${withArray1NullAssignment(array1Body)}
4235
+ | $processArray1
4245
4236
|}
4246
- | ${buildResultArrayData(nonNullArrayDataBuild)}
4237
+ | $buildResultArray
4247
4238
""" .stripMargin
4248
4239
})
4249
4240
} else {
0 commit comments