
Commit e6b4660

mn-mikke authored and ueshin committed
[SPARK-23736][SQL] Extending the concat function to support array columns
## What changes were proposed in this pull request?

The PR adds logic for easy concatenation of multiple array columns and covers:
- The `Concat` expression has been extended to support array columns
- A Python wrapper

## How was this patch tested?

New tests added into:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
- typeCoercion/native/concat.sql

## Codegen examples

### Primitive-type elements

```
val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```

Result:

```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */   null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */ ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */ if (!false) {
/* 044 */   project_args[0] = inputadapter_value;
/* 045 */ }
/* 046 */
/* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */   null : (inputadapter_row.getArray(1));
/* 050 */ if (!inputadapter_isNull1) {
/* 051 */   project_args[1] = inputadapter_value1;
/* 052 */ }
/* 053 */
/* 054 */ ArrayData project_value = new Object() {
/* 055 */   public ArrayData concat(ArrayData[] args) {
/* 056 */     for (int z = 0; z < 2; z++) {
/* 057 */       if (args[z] == null) return null;
/* 058 */     }
/* 059 */
/* 060 */     long project_numElements = 0L;
/* 061 */     for (int z = 0; z < 2; z++) {
/* 062 */       project_numElements += args[z].numElements();
/* 063 */     }
/* 064 */     if (project_numElements > 2147483632) {
/* 065 */       throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */         " elements due to exceeding the array size limit 2147483632.");
/* 067 */     }
/* 068 */
/* 069 */     long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 070 */       project_numElements,
/* 071 */       4);
/* 072 */     if (project_size > 2147483632) {
/* 073 */       throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size +
/* 074 */         " bytes of data due to exceeding the limit 2147483632 bytes" +
/* 075 */         " for UnsafeArrayData.");
/* 076 */     }
/* 077 */
/* 078 */     byte[] project_array = new byte[(int)project_size];
/* 079 */     UnsafeArrayData project_arrayData = new UnsafeArrayData();
/* 080 */     Platform.putLong(project_array, 16, project_numElements);
/* 081 */     project_arrayData.pointTo(project_array, 16, (int)project_size);
/* 082 */     int project_counter = 0;
/* 083 */     for (int y = 0; y < 2; y++) {
/* 084 */       for (int z = 0; z < args[y].numElements(); z++) {
/* 085 */         if (args[y].isNullAt(z)) {
/* 086 */           project_arrayData.setNullAt(project_counter);
/* 087 */         } else {
/* 088 */           project_arrayData.setInt(
/* 089 */             project_counter,
/* 090 */             args[y].getInt(z)
/* 091 */           );
/* 092 */         }
/* 093 */         project_counter++;
/* 094 */       }
/* 095 */     }
/* 096 */     return project_arrayData;
/* 097 */   }
/* 098 */ }.concat(project_args);
/* 099 */ boolean project_isNull = project_value == null;
```

### Non-primitive-type elements

```
val df = Seq(
  (Seq("aa", "bb"), Seq("ccc", "ddd")),
  (Seq("x", "y"), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```

Result:

```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */   null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */ ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */ if (!false) {
/* 044 */   project_args[0] = inputadapter_value;
/* 045 */ }
/* 046 */
/* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */   null : (inputadapter_row.getArray(1));
/* 050 */ if (!inputadapter_isNull1) {
/* 051 */   project_args[1] = inputadapter_value1;
/* 052 */ }
/* 053 */
/* 054 */ ArrayData project_value = new Object() {
/* 055 */   public ArrayData concat(ArrayData[] args) {
/* 056 */     for (int z = 0; z < 2; z++) {
/* 057 */       if (args[z] == null) return null;
/* 058 */     }
/* 059 */
/* 060 */     long project_numElements = 0L;
/* 061 */     for (int z = 0; z < 2; z++) {
/* 062 */       project_numElements += args[z].numElements();
/* 063 */     }
/* 064 */     if (project_numElements > 2147483632) {
/* 065 */       throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */         " elements due to exceeding the array size limit 2147483632.");
/* 067 */     }
/* 068 */
/* 069 */     Object[] project_arrayObjects = new Object[(int)project_numElements];
/* 070 */     int project_counter = 0;
/* 071 */     for (int y = 0; y < 2; y++) {
/* 072 */       for (int z = 0; z < args[y].numElements(); z++) {
/* 073 */         project_arrayObjects[project_counter] = args[y].getUTF8String(z);
/* 074 */         project_counter++;
/* 075 */       }
/* 076 */     }
/* 077 */     return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects);
/* 078 */   }
/* 079 */ }.concat(project_args);
/* 080 */ boolean project_isNull = project_value == null;
```

Author: mn-mikke <mrkAha12346github>

Closes apache#20858 from mn-mikke/feature/array-api-concat_arrays-to-master.
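For reference, a minimal spark-shell sketch (not part of the patch) of the user-facing behavior this enables; the output is shown approximately:

```
import org.apache.spark.sql.functions.{col, concat}

val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")

// Arrays are concatenated element-wise in argument order;
// a null array operand makes the whole result null.
df.select(concat(col("a"), col("b")).as("ab")).show()
// [1, 2, 3, 4]
// null
```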
1 parent b3fde5a commit e6b4660

File tree: 13 files changed, +529 −111 lines


common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java

Lines changed: 5 additions & 1 deletion
```
@@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) {
   }
 
   public static int roundNumberOfBytesToNearestWord(int numBytes) {
-    int remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
+    return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+  }
+
+  public static long roundNumberOfBytesToNearestWord(long numBytes) {
+    long remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
     if (remainder == 0) {
       return numBytes;
     } else {
```
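Note the bit trick that carries over to the new `long` overload: `numBytes & 0x07` is the remainder modulo 8, and widening to `long` lets sizes above `Integer.MAX_VALUE` be rounded before the caller applies its own limit check. A standalone Scala sketch of the same arithmetic (hypothetical helper name, not the Java API):

```
// Hypothetical stand-in for ByteArrayMethods.roundNumberOfBytesToNearestWord(long).
def roundToNearestWord(numBytes: Long): Long = {
  val remainder = numBytes & 0x07L // equivalent to numBytes % 8 for non-negative input
  if (remainder == 0) numBytes else numBytes + (8 - remainder)
}

assert(roundToNearestWord(13L) == 16L)
assert(roundToNearestWord(16L) == 16L)
// A value like this could not even be passed through the old int-only signature:
assert(roundToNearestWord(4L * Int.MaxValue) == 8589934592L)
```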

python/pyspark/sql/functions.py

Lines changed: 19 additions & 15 deletions
```
@@ -1425,21 +1425,6 @@ def hash(*cols):
     del _name, _doc
 
 
-@since(1.5)
-@ignore_unicode_prefix
-def concat(*cols):
-    """
-    Concatenates multiple input columns together into a single column.
-    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-
-    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
-    >>> df.select(concat(df.s, df.d).alias('s')).collect()
-    [Row(s=u'abcd123')]
-    """
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
 @since(1.5)
 @ignore_unicode_prefix
 def concat_ws(sep, *cols):
@@ -1845,6 +1830,25 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+    """
+    Concatenates multiple input columns together into a single column.
+    The function works with strings, binary and compatible array columns.
+
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df.select(concat(df.s, df.d).alias('s')).collect()
+    [Row(s=u'abcd123')]
+
+    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+    >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
 @since(2.4)
 def array_position(col, value):
     """
```

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 10 additions & 0 deletions
```
@@ -56,9 +56,19 @@
 public final class UnsafeArrayData extends ArrayData {
 
   public static int calculateHeaderPortionInBytes(int numFields) {
+    return (int)calculateHeaderPortionInBytes((long)numFields);
+  }
+
+  public static long calculateHeaderPortionInBytes(long numFields) {
     return 8 + ((numFields + 63)/ 64) * 8;
   }
 
+  public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
+    long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
+    return size;
+  }
+
   private Object baseObject;
   private long baseOffset;
```
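A worked example of the new size formula, assuming 5 int elements (elementSize = 4): the header is 8 bytes for the element count plus one 8-byte word of null-tracking bits, and the 20 data bytes round up to the next word, 24, for 40 bytes total. A Scala sketch of the same arithmetic (helper names are illustrative):

```
// Mirrors UnsafeArrayData.calculateHeaderPortionInBytes(long):
// 8 bytes for numElements + one 64-bit word of null bits per 64 fields.
def headerPortionInBytes(numFields: Long): Long =
  8 + ((numFields + 63) / 64) * 8

// Round a byte count up to the next 8-byte word boundary.
def roundToWord(numBytes: Long): Long = (numBytes + 7) & ~7L

def sizeOfUnderlyingByteArray(numFields: Long, elementSize: Int): Long =
  headerPortionInBytes(numFields) + roundToWord(numFields * elementSize)

assert(headerPortionInBytes(5L) == 16L)         // 8 + 8
assert(sizeOfUnderlyingByteArray(5L, 4) == 40L) // 16 + round(20) = 16 + 24
```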

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 1 deletion
```
@@ -308,7 +308,6 @@ object FunctionRegistry {
     expression[BitLength]("bit_length"),
     expression[Length]("char_length"),
     expression[Length]("character_length"),
-    expression[Concat]("concat"),
     expression[ConcatWs]("concat_ws"),
     expression[Decode]("decode"),
     expression[Elt]("elt"),
@@ -413,6 +412,7 @@ object FunctionRegistry {
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
     expression[Reverse]("reverse"),
+    expression[Concat]("concat"),
     CreateStruct.registryEntry,
 
     // misc functions
```
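Since `concat` keeps its registration name and only moves into the collection-function group, existing string SQL is untouched while the array form becomes available under the same name. The `ExpressionDescription` examples from this patch can be checked directly (assuming a `SparkSession` named `spark`):

```
spark.sql("SELECT concat('Spark', 'SQL')").show()
// SparkSQL
spark.sql("SELECT concat(array(1, 2, 3), array(4, 5), array(6))").show()
// [1, 2, 3, 4, 5, 6]
```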

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 8 additions & 0 deletions
```
@@ -520,6 +520,14 @@ object TypeCoercion {
           case None => a
         }
 
+      case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+        !haveSameType(children) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+          case None => c
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
           (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
```
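The rule above only fires when every child is an array and the element types differ; each child is then cast to the widest common array type found by `findWiderCommonType`. A hedged illustration (not a test from this patch; the printed schema is approximate):

```
import org.apache.spark.sql.functions.{array, concat, lit}

// array<int> ++ array<double>: the int array should be cast to
// array<double> by the new Concat case in TypeCoercion.
val coerced = spark.range(1)
  .select(concat(array(lit(1), lit(2)), array(lit(3.5))).as("xs"))
coerced.printSchema()
// root
//  |-- xs: array
//  |    |-- element: double
```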

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 219 additions & 1 deletion
```
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 /**
  * Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
 
   override def prettyName: String = "element_at"
 }
+
+/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark', 'SQL');
+       SparkSQL
+      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+       [1,2,3,4,5,6]
+  """)
+case class Concat(children: Seq[Expression]) extends Expression {
+
+  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val childTypes = children.map(_.dataType)
+      if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
+            s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+    }
+  }
+
+  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
+
+  lazy val javaType: String = CodeGenerator.javaType(dataType)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = dataType match {
+    case BinaryType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    case StringType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    case ArrayType(elementType, _) =>
+      val inputs = children.toStream.map(_.eval(input))
+      if (inputs.contains(null)) {
+        null
+      } else {
+        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+        if (numberOfElements > MAX_ARRAY_LENGTH) {
+          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+        }
+        val finalData = new Array[AnyRef](numberOfElements.toInt)
+        var position = 0
+        for(ad <- arrayData) {
+          val arr = ad.toObjectArray(elementType)
+          Array.copy(arr, 0, finalData, position, arr.length)
+          position += arr.length
+        }
+        new GenericArrayData(finalData)
+      }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val args = ctx.freshName("args")
+
+    val inputs = evals.zipWithIndex.map { case (eval, index) =>
+      s"""
+        ${eval.code}
+        if (!${eval.isNull}) {
+          $args[$index] = ${eval.value};
+        }
+      """
+    }
+
+    val (concatenator, initCode) = dataType match {
+      case BinaryType =>
+        (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+      case StringType =>
+        ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+      case ArrayType(elementType, _) =>
+        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
+          genCodeForPrimitiveArrays(ctx, elementType)
+        } else {
+          genCodeForNonPrimitiveArrays(ctx, elementType)
+        }
+        (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
+    }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = inputs,
+      funcName = "valueConcat",
+      extraArguments = (s"$javaType[]", args) :: Nil)
+    ev.copy(s"""
+      $initCode
+      $codes
+      $javaType ${ev.value} = $concatenator.concat($args);
+      boolean ${ev.isNull} = ${ev.value} == null;
+    """)
+  }
+
+  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
+    val numElements = ctx.freshName("numElements")
+    val code = s"""
+      |long $numElements = 0L;
+      |for (int z = 0; z < ${children.length}; z++) {
+      |  $numElements += args[z].numElements();
+      |}
+      |if ($numElements > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
+      |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+      |}
+      """.stripMargin
+
+    (code, numElements)
+  }
+
+  private def nullArgumentProtection() : String = {
+    if (nullable) {
+      s"""
+        |for (int z = 0; z < ${children.length}; z++) {
+        |  if (args[z] == null) return null;
+        |}
+      """.stripMargin
+    } else {
+      ""
+    }
+  }
+
+  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val arrayName = ctx.freshName("array")
+    val arraySizeName = ctx.freshName("size")
+    val counter = ctx.freshName("counter")
+    val arrayData = ctx.freshName("arrayData")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    val unsafeArraySizeInBytes = s"""
+      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+      |  $numElemName,
+      |  ${elementType.defaultSize});
+      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
+      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
+      |    " for UnsafeArrayData.");
+      |}
+      """.stripMargin
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat($javaType[] args) {
+      |    ${nullArgumentProtection()}
+      |    $numElemCode
+      |    $unsafeArraySizeInBytes
+      |    byte[] $arrayName = new byte[(int)$arraySizeName];
+      |    UnsafeArrayData $arrayData = new UnsafeArrayData();
+      |    Platform.putLong($arrayName, $baseOffset, $numElemName);
+      |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < args[y].numElements(); z++) {
+      |        if (args[y].isNullAt(z)) {
+      |          $arrayData.setNullAt($counter);
+      |        } else {
+      |          $arrayData.set$primitiveValueTypeName(
+      |            $counter,
+      |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+      |          );
+      |        }
+      |        $counter++;
+      |      }
+      |    }
+      |    return $arrayData;
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayData = ctx.freshName("arrayObjects")
+    val counter = ctx.freshName("counter")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat($javaType[] args) {
+      |    ${nullArgumentProtection()}
+      |    $numElemCode
+      |    Object[] $arrayData = new Object[(int)$numElemName];
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < args[y].numElements(); z++) {
+      |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
+      |        $counter++;
+      |      }
+      |    }
+      |    return new $genericArrayClass($arrayData);
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  override def toString: String = s"concat(${children.mkString(", ")})"
+
+  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+}
```
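To summarize the interpreted path that `eval` takes for arrays: return null if any operand is null, sum the element counts against `MAX_ROUNDED_ARRAY_LENGTH`, then copy each input into one result array. A plain-Scala sketch of that control flow, with `Array[AnyRef]` standing in for Spark's `ArrayData`:

```
val MAX_ROUNDED_ARRAY_LENGTH: Int = Int.MaxValue - 15 // 2147483632

def concatArrays(inputs: Seq[Array[AnyRef]]): Array[AnyRef] = {
  if (inputs.contains(null)) return null // null operand => null result
  val numElements = inputs.foldLeft(0L)(_ + _.length) // long sum avoids int overflow
  if (numElements > MAX_ROUNDED_ARRAY_LENGTH) {
    throw new RuntimeException(
      s"Unsuccessful try to concat arrays with $numElements elements " +
        s"due to exceeding the array size limit $MAX_ROUNDED_ARRAY_LENGTH.")
  }
  val result = new Array[AnyRef](numElements.toInt)
  var position = 0
  for (arr <- inputs) {
    Array.copy(arr, 0, result, position, arr.length)
    position += arr.length
  }
  result
}
```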
