@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
23
23
import org .apache .spark .sql .catalyst .expressions .codegen ._
24
24
import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData , MapData , TypeUtils }
25
25
import org .apache .spark .sql .types ._
26
- import org .apache .spark .unsafe .types .UTF8String
26
+ import org .apache .spark .unsafe .Platform
27
+ import org .apache .spark .unsafe .array .ByteArrayMethods
28
+ import org .apache .spark .unsafe .types .{ByteArray , UTF8String }
27
29
28
30
/**
29
31
* Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
665
667
666
668
override def prettyName : String = " element_at"
667
669
}
670
+
671
+ /**
672
+ * Concatenates multiple input columns together into a single column.
673
+ * The function works with strings, binary and compatible array columns.
674
+ */
675
+ @ ExpressionDescription (
676
+ usage = " _FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN." ,
677
+ examples = """
678
+ Examples:
679
+ > SELECT _FUNC_('Spark', 'SQL');
680
+ SparkSQL
681
+ > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
682
+ | [1,2,3,4,5,6]
683
+ """ )
684
+ case class Concat (children : Seq [Expression ]) extends Expression {
685
+
686
+ private val MAX_ARRAY_LENGTH : Int = ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH
687
+
688
+ val allowedTypes = Seq (StringType , BinaryType , ArrayType )
689
+
690
+ override def checkInputDataTypes (): TypeCheckResult = {
691
+ if (children.isEmpty) {
692
+ TypeCheckResult .TypeCheckSuccess
693
+ } else {
694
+ val childTypes = children.map(_.dataType)
695
+ if (childTypes.exists(tpe => ! allowedTypes.exists(_.acceptsType(tpe)))) {
696
+ return TypeCheckResult .TypeCheckFailure (
697
+ s " input to function $prettyName should have been StringType, BinaryType or ArrayType, " +
698
+ s " but it's " + childTypes.map(_.simpleString).mkString(" [" , " , " , " ]" ))
699
+ }
700
+ TypeUtils .checkForSameTypeInputExpr(childTypes, s " function $prettyName" )
701
+ }
702
+ }
703
+
704
+ override def dataType : DataType = children.map(_.dataType).headOption.getOrElse(StringType )
705
+
706
+ lazy val javaType : String = CodeGenerator .javaType(dataType)
707
+
708
+ override def nullable : Boolean = children.exists(_.nullable)
709
+
710
+ override def foldable : Boolean = children.forall(_.foldable)
711
+
712
+ override def eval (input : InternalRow ): Any = dataType match {
713
+ case BinaryType =>
714
+ val inputs = children.map(_.eval(input).asInstanceOf [Array [Byte ]])
715
+ ByteArray .concat(inputs : _* )
716
+ case StringType =>
717
+ val inputs = children.map(_.eval(input).asInstanceOf [UTF8String ])
718
+ UTF8String .concat(inputs : _* )
719
+ case ArrayType (elementType, _) =>
720
+ val inputs = children.toStream.map(_.eval(input))
721
+ if (inputs.contains(null )) {
722
+ null
723
+ } else {
724
+ val arrayData = inputs.map(_.asInstanceOf [ArrayData ])
725
+ val numberOfElements = arrayData.foldLeft(0L )((sum, ad) => sum + ad.numElements())
726
+ if (numberOfElements > MAX_ARRAY_LENGTH ) {
727
+ throw new RuntimeException (s " Unsuccessful try to concat arrays with $numberOfElements" +
728
+ s " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH. " )
729
+ }
730
+ val finalData = new Array [AnyRef ](numberOfElements.toInt)
731
+ var position = 0
732
+ for (ad <- arrayData) {
733
+ val arr = ad.toObjectArray(elementType)
734
+ Array .copy(arr, 0 , finalData, position, arr.length)
735
+ position += arr.length
736
+ }
737
+ new GenericArrayData (finalData)
738
+ }
739
+ }
740
+
741
+ override protected def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
742
+ val evals = children.map(_.genCode(ctx))
743
+ val args = ctx.freshName(" args" )
744
+
745
+ val inputs = evals.zipWithIndex.map { case (eval, index) =>
746
+ s """
747
+ ${eval.code}
748
+ if (! ${eval.isNull}) {
749
+ $args[ $index] = ${eval.value};
750
+ }
751
+ """
752
+ }
753
+
754
+ val (concatenator, initCode) = dataType match {
755
+ case BinaryType =>
756
+ (classOf [ByteArray ].getName, s " byte[][] $args = new byte[ ${evals.length}][]; " )
757
+ case StringType =>
758
+ (" UTF8String" , s " UTF8String[] $args = new UTF8String[ ${evals.length}]; " )
759
+ case ArrayType (elementType, _) =>
760
+ val arrayConcatClass = if (CodeGenerator .isPrimitiveType(elementType)) {
761
+ genCodeForPrimitiveArrays(ctx, elementType)
762
+ } else {
763
+ genCodeForNonPrimitiveArrays(ctx, elementType)
764
+ }
765
+ (arrayConcatClass, s " ArrayData[] $args = new ArrayData[ ${evals.length}]; " )
766
+ }
767
+ val codes = ctx.splitExpressionsWithCurrentInputs(
768
+ expressions = inputs,
769
+ funcName = " valueConcat" ,
770
+ extraArguments = (s " $javaType[] " , args) :: Nil )
771
+ ev.copy(s """
772
+ $initCode
773
+ $codes
774
+ $javaType ${ev.value} = $concatenator.concat( $args);
775
+ boolean ${ev.isNull} = ${ev.value} == null;
776
+ """ )
777
+ }
778
+
779
+ private def genCodeForNumberOfElements (ctx : CodegenContext ) : (String , String ) = {
780
+ val numElements = ctx.freshName(" numElements" )
781
+ val code = s """
782
+ |long $numElements = 0L;
783
+ |for (int z = 0; z < ${children.length}; z++) {
784
+ | $numElements += args[z].numElements();
785
+ |}
786
+ |if ( $numElements > $MAX_ARRAY_LENGTH) {
787
+ | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
788
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
789
+ |}
790
+ """ .stripMargin
791
+
792
+ (code, numElements)
793
+ }
794
+
795
+ private def nullArgumentProtection () : String = {
796
+ if (nullable) {
797
+ s """
798
+ |for (int z = 0; z < ${children.length}; z++) {
799
+ | if (args[z] == null) return null;
800
+ |}
801
+ """ .stripMargin
802
+ } else {
803
+ " "
804
+ }
805
+ }
806
+
807
+ private def genCodeForPrimitiveArrays (ctx : CodegenContext , elementType : DataType ): String = {
808
+ val arrayName = ctx.freshName(" array" )
809
+ val arraySizeName = ctx.freshName(" size" )
810
+ val counter = ctx.freshName(" counter" )
811
+ val arrayData = ctx.freshName(" arrayData" )
812
+
813
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
814
+
815
+ val unsafeArraySizeInBytes = s """
816
+ |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
817
+ | $numElemName,
818
+ | ${elementType.defaultSize});
819
+ |if ( $arraySizeName > $MAX_ARRAY_LENGTH) {
820
+ | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
821
+ | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
822
+ | " for UnsafeArrayData.");
823
+ |}
824
+ """ .stripMargin
825
+ val baseOffset = Platform .BYTE_ARRAY_OFFSET
826
+ val primitiveValueTypeName = CodeGenerator .primitiveTypeName(elementType)
827
+
828
+ s """
829
+ |new Object() {
830
+ | public ArrayData concat( $javaType[] args) {
831
+ | ${nullArgumentProtection()}
832
+ | $numElemCode
833
+ | $unsafeArraySizeInBytes
834
+ | byte[] $arrayName = new byte[(int) $arraySizeName];
835
+ | UnsafeArrayData $arrayData = new UnsafeArrayData();
836
+ | Platform.putLong( $arrayName, $baseOffset, $numElemName);
837
+ | $arrayData.pointTo( $arrayName, $baseOffset, (int) $arraySizeName);
838
+ | int $counter = 0;
839
+ | for (int y = 0; y < ${children.length}; y++) {
840
+ | for (int z = 0; z < args[y].numElements(); z++) {
841
+ | if (args[y].isNullAt(z)) {
842
+ | $arrayData.setNullAt( $counter);
843
+ | } else {
844
+ | $arrayData.set $primitiveValueTypeName(
845
+ | $counter,
846
+ | ${CodeGenerator .getValue(s " args[y] " , elementType, " z" )}
847
+ | );
848
+ | }
849
+ | $counter++;
850
+ | }
851
+ | }
852
+ | return $arrayData;
853
+ | }
854
+ |} """ .stripMargin.stripPrefix(" \n " )
855
+ }
856
+
857
+ private def genCodeForNonPrimitiveArrays (ctx : CodegenContext , elementType : DataType ): String = {
858
+ val genericArrayClass = classOf [GenericArrayData ].getName
859
+ val arrayData = ctx.freshName(" arrayObjects" )
860
+ val counter = ctx.freshName(" counter" )
861
+
862
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
863
+
864
+ s """
865
+ |new Object() {
866
+ | public ArrayData concat( $javaType[] args) {
867
+ | ${nullArgumentProtection()}
868
+ | $numElemCode
869
+ | Object[] $arrayData = new Object[(int) $numElemName];
870
+ | int $counter = 0;
871
+ | for (int y = 0; y < ${children.length}; y++) {
872
+ | for (int z = 0; z < args[y].numElements(); z++) {
873
+ | $arrayData[ $counter] = ${CodeGenerator .getValue(s " args[y] " , elementType, " z" )};
874
+ | $counter++;
875
+ | }
876
+ | }
877
+ | return new $genericArrayClass( $arrayData);
878
+ | }
879
+ |} """ .stripMargin.stripPrefix(" \n " )
880
+ }
881
+
882
+ override def toString : String = s " concat( ${children.mkString(" , " )}) "
883
+
884
+ override def sql : String = s " concat( ${children.map(_.sql).mkString(" , " )}) "
885
+ }
0 commit comments