@@ -3767,230 +3767,159 @@ object ArraySetLike {
3767
3767
""" ,
3768
3768
since = " 2.4.0" )
3769
3769
case class ArrayUnion (left : Expression , right : Expression ) extends ArraySetLike
3770
- with ComplexTypeMergingExpression {
3771
- var hsInt : OpenHashSet [Int ] = _
3772
- var hsLong : OpenHashSet [Long ] = _
3773
-
3774
- def assignInt (array : ArrayData , idx : Int , resultArray : ArrayData , pos : Int ): Boolean = {
3775
- val elem = array.getInt(idx)
3776
- if (! hsInt.contains(elem)) {
3777
- if (resultArray != null ) {
3778
- resultArray.setInt(pos, elem)
3779
- }
3780
- hsInt.add(elem)
3781
- true
3782
- } else {
3783
- false
3784
- }
3785
- }
3786
-
3787
- def assignLong (array : ArrayData , idx : Int , resultArray : ArrayData , pos : Int ): Boolean = {
3788
- val elem = array.getLong(idx)
3789
- if (! hsLong.contains(elem)) {
3790
- if (resultArray != null ) {
3791
- resultArray.setLong(pos, elem)
3792
- }
3793
- hsLong.add(elem)
3794
- true
3795
- } else {
3796
- false
3797
- }
3798
- }
3770
+ with ComplexTypeMergingExpression {
3799
3771
3800
- def evalIntLongPrimitiveType (
3801
- array1 : ArrayData ,
3802
- array2 : ArrayData ,
3803
- resultArray : ArrayData ,
3804
- isLongType : Boolean ): Int = {
3805
- // store elements into resultArray
3806
- var nullElementSize = 0
3807
- var pos = 0
3808
- Seq (array1, array2).foreach { array =>
3809
- var i = 0
3810
- while (i < array.numElements()) {
3811
- val size = if (! isLongType) hsInt.size else hsLong.size
3812
- if (size + nullElementSize > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3813
- ArraySetLike .throwUnionLengthOverflowException(size)
3814
- }
3815
- if (array.isNullAt(i)) {
3816
- if (nullElementSize == 0 ) {
3817
- if (resultArray != null ) {
3818
- resultArray.setNullAt(pos)
3772
+ @ transient lazy val evalUnion : (ArrayData , ArrayData ) => ArrayData = {
3773
+ if (elementTypeSupportEquals) {
3774
+ (array1, array2) =>
3775
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
3776
+ val hs = new OpenHashSet [Any ]
3777
+ var foundNullElement = false
3778
+ Seq (array1, array2).foreach { array =>
3779
+ var i = 0
3780
+ while (i < array.numElements()) {
3781
+ if (array.isNullAt(i)) {
3782
+ if (! foundNullElement) {
3783
+ arrayBuffer += null
3784
+ foundNullElement = true
3785
+ }
3786
+ } else {
3787
+ val elem = array.get(i, elementType)
3788
+ if (! hs.contains(elem)) {
3789
+ if (arrayBuffer.size > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3790
+ ArraySetLike .throwUnionLengthOverflowException(arrayBuffer.size)
3791
+ }
3792
+ arrayBuffer += elem
3793
+ hs.add(elem)
3794
+ }
3819
3795
}
3820
- pos += 1
3821
- nullElementSize = 1
3796
+ i += 1
3822
3797
}
3823
- } else {
3824
- val assigned = if (! isLongType) {
3825
- assignInt(array, i, resultArray, pos)
3798
+ }
3799
+ new GenericArrayData (arrayBuffer)
3800
+ } else {
3801
+ (array1, array2) =>
3802
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
3803
+ var alreadyIncludeNull = false
3804
+ Seq (array1, array2).foreach(_.foreach(elementType, (_, elem) => {
3805
+ var found = false
3806
+ if (elem == null ) {
3807
+ if (alreadyIncludeNull) {
3808
+ found = true
3809
+ } else {
3810
+ alreadyIncludeNull = true
3811
+ }
3826
3812
} else {
3827
- assignLong(array, i, resultArray, pos)
3813
+ // check elem is already stored in arrayBuffer or not?
3814
+ var j = 0
3815
+ while (! found && j < arrayBuffer.size) {
3816
+ val va = arrayBuffer(j)
3817
+ if (va != null && ordering.equiv(va, elem)) {
3818
+ found = true
3819
+ }
3820
+ j = j + 1
3821
+ }
3828
3822
}
3829
- if (assigned) {
3830
- pos += 1
3823
+ if (! found) {
3824
+ if (arrayBuffer.length > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3825
+ ArraySetLike .throwUnionLengthOverflowException(arrayBuffer.length)
3826
+ }
3827
+ arrayBuffer += elem
3831
3828
}
3832
- }
3833
- i += 1
3834
- }
3829
+ }))
3830
+ new GenericArrayData (arrayBuffer)
3835
3831
}
3836
- pos
3837
3832
}
3838
3833
3839
3834
override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
3840
3835
val array1 = input1.asInstanceOf [ArrayData ]
3841
3836
val array2 = input2.asInstanceOf [ArrayData ]
3842
3837
3843
- if (elementTypeSupportEquals) {
3844
- elementType match {
3845
- case IntegerType =>
3846
- // avoid boxing of primitive int array elements
3847
- // calculate result array size
3848
- hsInt = new OpenHashSet [Int ]
3849
- val elements = evalIntLongPrimitiveType(array1, array2, null , false )
3850
- hsInt = new OpenHashSet [Int ]
3851
- val resultArray = if (UnsafeArrayData .shouldUseGenericArrayData(
3852
- IntegerType .defaultSize, elements)) {
3853
- new GenericArrayData (new Array [Any ](elements))
3854
- } else {
3855
- UnsafeArrayData .forPrimitiveArray(
3856
- Platform .INT_ARRAY_OFFSET , elements, IntegerType .defaultSize)
3857
- }
3858
- evalIntLongPrimitiveType(array1, array2, resultArray, false )
3859
- resultArray
3860
- case LongType =>
3861
- // avoid boxing of primitive long array elements
3862
- // calculate result array size
3863
- hsLong = new OpenHashSet [Long ]
3864
- val elements = evalIntLongPrimitiveType(array1, array2, null , true )
3865
- hsLong = new OpenHashSet [Long ]
3866
- val resultArray = if (UnsafeArrayData .shouldUseGenericArrayData(
3867
- LongType .defaultSize, elements)) {
3868
- new GenericArrayData (new Array [Any ](elements))
3869
- } else {
3870
- UnsafeArrayData .forPrimitiveArray(
3871
- Platform .LONG_ARRAY_OFFSET , elements, LongType .defaultSize)
3872
- }
3873
- evalIntLongPrimitiveType(array1, array2, resultArray, true )
3874
- resultArray
3875
- case _ =>
3876
- val arrayBuffer = new scala.collection.mutable.ArrayBuffer [Any ]
3877
- val hs = new OpenHashSet [Any ]
3878
- var foundNullElement = false
3879
- Seq (array1, array2).foreach { array =>
3880
- var i = 0
3881
- while (i < array.numElements()) {
3882
- if (array.isNullAt(i)) {
3883
- if (! foundNullElement) {
3884
- arrayBuffer += null
3885
- foundNullElement = true
3886
- }
3887
- } else {
3888
- val elem = array.get(i, elementType)
3889
- if (! hs.contains(elem)) {
3890
- if (arrayBuffer.size > ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH ) {
3891
- ArraySetLike .throwUnionLengthOverflowException(arrayBuffer.size)
3892
- }
3893
- arrayBuffer += elem
3894
- hs.add(elem)
3895
- }
3896
- }
3897
- i += 1
3898
- }
3899
- }
3900
- new GenericArrayData (arrayBuffer)
3901
- }
3902
- } else {
3903
- ArrayUnion .unionOrdering(array1, array2, elementType, ordering)
3904
- }
3838
+ evalUnion(array1, array2)
3905
3839
}
3906
3840
3907
3841
override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
3908
3842
val i = ctx.freshName(" i" )
3909
- val pos = ctx.freshName(" pos" )
3910
3843
val value = ctx.freshName(" value" )
3911
3844
val size = ctx.freshName(" size" )
3912
- val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
3913
- if (elementTypeSupportEquals) {
3914
- elementType match {
3915
- case ByteType | ShortType | IntegerType | LongType =>
3916
- val ptName = CodeGenerator .primitiveTypeName(elementType)
3917
- val unsafeArray = ctx.freshName(" unsafeArray" )
3918
- (if (elementType == LongType ) s " $$ mcJ $$ sp " else s " $$ mcI $$ sp " ,
3919
- if (elementType == LongType ) " Long" else " Int" ,
3920
- s " get $ptName( $i) " , s " set $ptName( $pos, $value) " , CodeGenerator .javaType(elementType),
3921
- if (elementType == LongType ) " (long)" else " (int)" ,
3922
- s """
3923
- | ${ctx.createUnsafeArray(unsafeArray, size, elementType, s " $prettyName failed. " )}
3924
- | ${ev.value} = $unsafeArray;
3925
- """ .stripMargin)
3926
- case _ =>
3927
- val genericArrayData = classOf [GenericArrayData ].getName
3928
- val et = ctx.addReferenceObj(" elementType" , elementType)
3929
- (" " , " Object" ,
3930
- s " get( $i, $et) " , s " update( $pos, $value) " , " Object" , " " ,
3931
- s " ${ev.value} = new $genericArrayData(new Object[ $size]); " )
3932
- }
3933
- } else {
3934
- (" " , " " , " " , " " , " " , " " , " " )
3935
- }
3845
+ if (canUseSpecializedHashSet) {
3846
+ val jt = CodeGenerator .javaType(elementType)
3847
+ val ptName = CodeGenerator .primitiveTypeName(jt)
3936
3848
3937
- nullSafeCodeGen(ctx, ev, (array1, array2) => {
3938
- if (openHashElementType != " " ) {
3939
- // Here, we ensure elementTypeSupportEquals is true
3849
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
3940
3850
val foundNullElement = ctx.freshName(" foundNullElement" )
3941
- val openHashSet = classOf [OpenHashSet [_]].getName
3942
- val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $openHashElementType() "
3943
- val hs = ctx.freshName(" hs" )
3944
- val arrayData = classOf [ArrayData ].getName
3945
- val arrays = ctx.freshName(" arrays" )
3851
+ val nullElementIndex = ctx.freshName(" nullElementIndex" )
3852
+ val builder = ctx.freshName(" builder" )
3946
3853
val array = ctx.freshName(" array" )
3854
+ val arrays = ctx.freshName(" arrays" )
3947
3855
val arrayDataIdx = ctx.freshName(" arrayDataIdx" )
3856
+ val openHashSet = classOf [OpenHashSet [_]].getName
3857
+ val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ . $hsTypeName() "
3858
+ val hashSet = ctx.freshName(" hashSet" )
3859
+ val arrayBuilder = classOf [mutable.ArrayBuilder [_]].getName
3860
+ val arrayBuilderClass = s " $arrayBuilder$$ of $ptName"
3861
+
3862
+ def withArrayNullAssignment (body : String ) =
3863
+ if (dataType.asInstanceOf [ArrayType ].containsNull) {
3864
+ s """
3865
+ |if ( $array.isNullAt( $i)) {
3866
+ | if (! $foundNullElement) {
3867
+ | $nullElementIndex = $size;
3868
+ | $foundNullElement = true;
3869
+ | $size++;
3870
+ | $builder. $$ plus $$ eq( $nullValueHolder);
3871
+ | }
3872
+ |} else {
3873
+ | $body
3874
+ |}
3875
+ """ .stripMargin
3876
+ } else {
3877
+ body
3878
+ }
3879
+
3880
+ val processArray = withArrayNullAssignment(
3881
+ s """
3882
+ | $jt $value = ${genGetValue(array, i)};
3883
+ |if (! $hashSet.contains( $hsValueCast$value)) {
3884
+ | if (++ $size > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
3885
+ | break;
3886
+ | }
3887
+ | $hashSet.add $hsPostFix( $hsValueCast$value);
3888
+ | $builder. $$ plus $$ eq( $value);
3889
+ |}
3890
+ """ .stripMargin)
3891
+
3892
+ // Only need to track null element index when result array's element is nullable.
3893
+ val declareNullTrackVariables = if (dataType.asInstanceOf [ArrayType ].containsNull) {
3894
+ s """
3895
+ |boolean $foundNullElement = false;
3896
+ |int $nullElementIndex = -1;
3897
+ """ .stripMargin
3898
+ } else {
3899
+ " "
3900
+ }
3901
+
3948
3902
s """
3949
- | $openHashSet $hs = new $openHashSet$postFix( $classTag);
3950
- |boolean $foundNullElement = false;
3951
- | $arrayData[] $arrays = new $arrayData[]{ $array1, $array2};
3952
- |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
3953
- | $arrayData $array = $arrays[ $arrayDataIdx];
3954
- | for (int $i = 0; $i < $array.numElements(); $i++) {
3955
- | if ( $array.isNullAt( $i)) {
3956
- | $foundNullElement = true;
3957
- | } else {
3958
- | $hs.add $postFix( $array. $getter);
3959
- | }
3960
- | }
3961
- |}
3962
- |int $size = $hs.size() + ( $foundNullElement ? 1 : 0);
3963
- | $arrayBuilder
3964
- | $hs = new $openHashSet$postFix( $classTag);
3965
- | $foundNullElement = false;
3966
- |int $pos = 0;
3903
+ | $openHashSet $hashSet = new $openHashSet$hsPostFix( $classTag);
3904
+ | $declareNullTrackVariables
3905
+ |int $size = 0;
3906
+ | $arrayBuilderClass $builder = new $arrayBuilderClass();
3907
+ |ArrayData[] $arrays = new ArrayData[]{ $array1, $array2};
3967
3908
|for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
3968
- | $arrayData $array = $arrays[ $arrayDataIdx];
3909
+ | ArrayData $array = $arrays[ $arrayDataIdx];
3969
3910
| for (int $i = 0; $i < $array.numElements(); $i++) {
3970
- | if ( $array.isNullAt( $i)) {
3971
- | if (! $foundNullElement) {
3972
- | ${ev.value}.setNullAt( $pos++);
3973
- | $foundNullElement = true;
3974
- | }
3975
- | } else {
3976
- | $javaTypeName $value = $array. $getter;
3977
- | if (! $hs.contains( $castOp $value)) {
3978
- | $hs.add $postFix( $value);
3979
- | ${ev.value}. $setter;
3980
- | $pos++;
3981
- | }
3982
- | }
3911
+ | $processArray
3983
3912
| }
3984
3913
|}
3914
+ | ${buildResultArray(builder, ev.value, size, nullElementIndex)}
3985
3915
""" .stripMargin
3986
- } else {
3987
- val arrayUnion = classOf [ArrayUnion ].getName
3988
- val et = ctx.addReferenceObj(" elementTypeUnion" , elementType)
3989
- val order = ctx.addReferenceObj(" orderingUnion" , ordering)
3990
- val method = " unionOrdering"
3991
- s " ${ev.value} = $arrayUnion$$ .MODULE $$ . $method( $array1, $array2, $et, $order); "
3992
- }
3993
- })
3916
+ })
3917
+ } else {
3918
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
3919
+ val expr = ctx.addReferenceObj(" arrayUnionExpr" , this )
3920
+ s " ${ev.value} = (ArrayData) $expr.nullSafeEval( $array1, $array2); "
3921
+ })
3922
+ }
3994
3923
}
3995
3924
3996
3925
override def prettyName : String = " array_union"
@@ -4154,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
4154
4083
}
4155
4084
4156
4085
override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
4157
- val arrayData = classOf [ArrayData ].getName
4158
4086
val i = ctx.freshName(" i" )
4159
4087
val value = ctx.freshName(" value" )
4160
4088
val size = ctx.freshName(" size" )
@@ -4268,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
4268
4196
} else {
4269
4197
nullSafeCodeGen(ctx, ev, (array1, array2) => {
4270
4198
val expr = ctx.addReferenceObj(" arrayIntersectExpr" , this )
4271
- s " ${ev.value} = ( $arrayData ) $expr.nullSafeEval( $array1, $array2); "
4199
+ s " ${ev.value} = (ArrayData ) $expr.nullSafeEval( $array1, $array2); "
4272
4200
})
4273
4201
}
4274
4202
}
@@ -4387,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
4387
4315
}
4388
4316
4389
4317
override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
4390
- val arrayData = classOf [ArrayData ].getName
4391
4318
val i = ctx.freshName(" i" )
4392
4319
val value = ctx.freshName(" value" )
4393
4320
val size = ctx.freshName(" size" )
@@ -4490,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
4490
4417
} else {
4491
4418
nullSafeCodeGen(ctx, ev, (array1, array2) => {
4492
4419
val expr = ctx.addReferenceObj(" arrayExceptExpr" , this )
4493
- s " ${ev.value} = ( $arrayData ) $expr.nullSafeEval( $array1, $array2); "
4420
+ s " ${ev.value} = (ArrayData ) $expr.nullSafeEval( $array1, $array2); "
4494
4421
})
4495
4422
}
4496
4423
}
0 commit comments