Skip to content

Commit e319ac9

Browse files
kiszkueshin
authored and committed
[SPARK-24962][SQL] Refactor CodeGenerator.createUnsafeArray, ArraySetLike, and ArrayDistinct
## What changes were proposed in this pull request? This PR integrates handling of `UnsafeArrayData` and `GenericArrayData` into one. The current `CodeGenerator.createUnsafeArray` handles only allocation of `UnsafeArrayData`. This PR introduces new methods `createArrayData` and `createArrayAssignment` that return code to allocate `UnsafeArrayData` or `GenericArrayData` and to assign a value into the allocated array. This PR also reduces the size of generated code by calling a runtime helper. This PR replaces `createUnsafeArray` with `createArrayData`. This PR also refactors `ArraySetLike` so that it can be used for `ArrayDistinct`, too. This PR also refactors `ArrayDistinct` to use `ArrayBuilder`. ## How was this patch tested? Existing tests Closes apache#21912 from kiszk/SPARK-24962. Lead-authored-by: Kazuaki Ishizaki <[email protected]> Co-authored-by: Takuya UESHIN <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 4cb2ff9 commit e319ac9

File tree

4 files changed

+464
-661
lines changed

4 files changed

+464
-661
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,27 @@ public static UnsafeArrayData fromPrimitiveArray(
473473
return result;
474474
}
475475

476-
public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
477-
return fromPrimitiveArray(null, offset, length, elementSize);
476+
/**
 * Allocates a new {@code long[]}-backed {@link UnsafeArrayData} sized for {@code length}
 * elements of {@code elementSize} bytes each.
 *
 * The size is the header portion plus the value region, rounded up to whole 8-byte words.
 * Only the element count is written (as the first long of the backing array); the rest of
 * the array keeps Java's default zero initialisation.
 *
 * @param length number of elements the array should hold
 * @param elementSize size of one element in bytes
 * @return an {@code UnsafeArrayData} pointing at the freshly allocated backing array
 * @throws UnsupportedOperationException if the required size in 8-byte words exceeds
 *         {@code Integer.MAX_VALUE / 8}
 */
public static UnsafeArrayData createFreshArray(int length, int elementSize) {
477+
final long headerInBytes = calculateHeaderPortionInBytes(length);
478+
final long valueRegionInBytes = (long)elementSize * length;
479+
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
480+
if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
481+
throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
482+
"it's too big.");
483+
}
484+
485+
final long[] data = new long[(int)totalSizeInLongs];
486+
487+
Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
488+
489+
UnsafeArrayData result = new UnsafeArrayData();
490+
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
491+
return result;
478492
}
479493

480-
public static boolean shouldUseGenericArrayData(int elementSize, int length) {
494+
/**
 * Reports whether an array of {@code length} elements of {@code elementSize} bytes each is
 * too large for the unsafe format.
 *
 * The check computes the header portion plus the value region, rounds up to whole 8-byte
 * words, and compares the result against {@code Integer.MAX_VALUE / 8}. Because
 * {@code length} is a {@code long}, the product {@code elementSize * length} is evaluated
 * in 64-bit arithmetic without an explicit cast.
 *
 * @param elementSize size of one element in bytes
 * @param length number of elements
 * @return {@code true} if the required number of 8-byte words exceeds
 *         {@code Integer.MAX_VALUE / 8}
 */
public static boolean shouldUseGenericArrayData(int elementSize, long length) {
481495
final long headerInBytes = calculateHeaderPortionInBytes(length);
482-
final long valueRegionInBytes = (long)elementSize * length;
496+
final long valueRegionInBytes = elementSize * length;
483497
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
484498
return totalSizeInLongs > Integer.MAX_VALUE / 8;
485499
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
3939
import org.apache.spark.sql.catalyst.InternalRow
4040
import org.apache.spark.sql.catalyst.expressions._
4141
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
42-
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
42+
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
4343
import org.apache.spark.sql.internal.SQLConf
4444
import org.apache.spark.sql.types._
4545
import org.apache.spark.unsafe.Platform
@@ -746,73 +746,6 @@ class CodegenContext {
746746
""".stripMargin
747747
}
748748

749-
/**
750-
* Generates code creating a [[UnsafeArrayData]].
751-
*
752-
* @param arrayName name of the array to create
753-
* @param numElements code representing the number of elements the array should contain
754-
* @param elementType data type of the elements in the array
755-
* @param additionalErrorMessage string to include in the error message
756-
*/
757-
def createUnsafeArray(
758-
arrayName: String,
759-
numElements: String,
760-
elementType: DataType,
761-
additionalErrorMessage: String): String = {
762-
val arraySize = freshName("size")
763-
val arrayBytes = freshName("arrayBytes")
764-
765-
s"""
766-
|long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
767-
| $numElements,
768-
| ${elementType.defaultSize});
769-
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
770-
| throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
771-
| " bytes of data due to exceeding the limit " +
772-
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
773-
| "$additionalErrorMessage");
774-
|}
775-
|byte[] $arrayBytes = new byte[(int)$arraySize];
776-
|UnsafeArrayData $arrayName = new UnsafeArrayData();
777-
|Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
778-
|$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
779-
""".stripMargin
780-
}
781-
782-
/**
783-
* Generates code creating a [[UnsafeArrayData]]. The generated code executes
784-
* a provided fallback when the size of backing array would exceed the array size limit.
785-
* @param arrayName a name of the array to create
786-
* @param numElements a piece of code representing the number of elements the array should contain
787-
* @param elementSize a size of an element in bytes
788-
* @param bodyCode a function generating code that fills up the [[UnsafeArrayData]]
789-
* and getting the backing array as a parameter
790-
* @param fallbackCode a piece of code executed when the array size limit is exceeded
791-
*/
792-
def createUnsafeArrayWithFallback(
793-
arrayName: String,
794-
numElements: String,
795-
elementSize: Int,
796-
bodyCode: String => String,
797-
fallbackCode: String): String = {
798-
val arraySize = freshName("size")
799-
val arrayBytes = freshName("arrayBytes")
800-
s"""
801-
|final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
802-
| $numElements,
803-
| $elementSize);
804-
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
805-
| $fallbackCode
806-
|} else {
807-
| final byte[] $arrayBytes = new byte[(int)$arraySize];
808-
| UnsafeArrayData $arrayName = new UnsafeArrayData();
809-
| Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
810-
| $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
811-
| ${bodyCode(arrayBytes)}
812-
|}
813-
""".stripMargin
814-
}
815-
816749
/**
817750
* Generates code to do null safe execution, i.e. only execute the code when the input is not
818751
* null by adding null check if necessary.
@@ -1490,6 +1423,59 @@ object CodeGenerator extends Logging {
14901423
}
14911424
}
14921425

1426+
/**
1427+
* Generates Java code that declares an [[ArrayData]] variable and initializes it by calling
1428+
* `ArrayData.allocateArrayData`, yielding a [[UnsafeArrayData]] or [[GenericArrayData]].
1429+
*
1430+
* @param arrayName name of the `ArrayData` variable declared by the generated code
1431+
* @param elementType data type of the elements; a primitive type passes its `defaultSize` as the element size, any other type passes -1
1432+
* @param numElements code representing the number of elements the array should contain
1433+
* @param additionalErrorMessage string embedded as a Java string literal in the generated call, to include in the error message
1434+
*
1435+
* @return code representing the declaration and allocation of the [[ArrayData]]
1436+
*/
1437+
def createArrayData(
1438+
arrayName: String,
1439+
elementType: DataType,
1440+
numElements: String,
1441+
additionalErrorMessage: String): String = {
1442+
val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) {
1443+
elementType.defaultSize
1444+
} else {
1445+
-1
1446+
}
1447+
s"""
1448+
|ArrayData $arrayName = ArrayData.allocateArrayData(
1449+
| $elementSize, $numElements, "$additionalErrorMessage");
1450+
""".stripMargin
1451+
}
1452+
1453+
/**
1454+
* Generates code assigning one element of a source [[ArrayData]] to a destination [[ArrayData]].
1455+
*
1456+
* @param dstArray name of the array to be assigned
1457+
* @param elementType data type of the elements in destination and source arrays
1458+
* @param srcArray name of the array to be read
1459+
* @param dstArrayIndex an index variable to access each element of destination array
1460+
* @param srcArrayIndex an index variable to access each element of source array
1461+
* @param needNullCheck whether the generated assignment is given the null check
1462+
* `srcArray.isNullAt(srcArrayIndex)` as its null condition
1463+
*
1464+
* @return code assigning the element at `srcArrayIndex` of the source array to `dstArrayIndex`
1465+
* of the destination array; the caller supplies the loop and both index variables
1466+
*/
1467+
def createArrayAssignment(
1468+
dstArray: String,
1469+
elementType: DataType,
1470+
srcArray: String,
1471+
dstArrayIndex: String,
1472+
srcArrayIndex: String,
1473+
needNullCheck: Boolean): String = {
1474+
CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex,
1475+
CodeGenerator.getValue(srcArray, elementType, srcArrayIndex),
1476+
if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None)
1477+
}
1478+
14931479
/**
14941480
* Returns the code to update a column in Row for a given DataType.
14951481
*/
@@ -1558,6 +1544,34 @@ object CodeGenerator extends Logging {
15581544
}
15591545
}
15601546

1547+
/**
1548+
* Generates code that stores a value into one element of an [[ArrayData]].
*
* Primitive element types use the typed setter `set<PrimitiveTypeName>`; every other type uses
* `update`. When `isNull` is given and the element type is primitive, the generated code calls
* `setNullAt` if the condition holds and the typed setter otherwise. For non-primitive element
* types `isNull` is not used and the value is always passed to `update`.
* NOTE(review): presumably `value` already evaluates to null in that case - confirm.
*
* @param array name of the array to write to
* @param elementType data type of the array elements
* @param i code representing the index of the element to set
* @param value code representing the value to store
* @param isNull optional code of a boolean expression telling whether the value is null
*
* @return code of the generated setter statement(s)
1549+
*/
1550+
def setArrayElement(
1551+
array: String,
1552+
elementType: DataType,
1553+
i: String,
1554+
value: String,
1555+
isNull: Option[String] = None): String = {
1556+
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
1557+
val setFunc = if (isPrimitiveType) {
1558+
s"set${CodeGenerator.primitiveTypeName(elementType)}"
1559+
} else {
1560+
"update"
1561+
}
1562+
if (isNull.isDefined && isPrimitiveType) {
1563+
s"""
1564+
|if (${isNull.get}) {
1565+
| $array.setNullAt($i);
1566+
|} else {
1567+
| $array.$setFunc($i, $value);
1568+
|}
1569+
""".stripMargin
1570+
} else {
1571+
s"$array.$setFunc($i, $value);"
1572+
}
1573+
}
1574+
15611575
/**
15621576
* Returns the specialized code to set a given value in a column vector for a given `DataType`
15631577
* that could potentially be nullable.

0 commit comments

Comments
 (0)