@@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
39
39
import org .apache .spark .sql .catalyst .InternalRow
40
40
import org .apache .spark .sql .catalyst .expressions ._
41
41
import org .apache .spark .sql .catalyst .expressions .codegen .Block ._
42
- import org .apache .spark .sql .catalyst .util .{ArrayData , MapData }
42
+ import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData , MapData }
43
43
import org .apache .spark .sql .internal .SQLConf
44
44
import org .apache .spark .sql .types ._
45
45
import org .apache .spark .unsafe .Platform
@@ -746,73 +746,6 @@ class CodegenContext {
746
746
""" .stripMargin
747
747
}
748
748
749
- /**
750
- * Generates code creating a [[UnsafeArrayData ]].
751
- *
752
- * @param arrayName name of the array to create
753
- * @param numElements code representing the number of elements the array should contain
754
- * @param elementType data type of the elements in the array
755
- * @param additionalErrorMessage string to include in the error message
756
- */
757
- def createUnsafeArray (
758
- arrayName : String ,
759
- numElements : String ,
760
- elementType : DataType ,
761
- additionalErrorMessage : String ): String = {
762
- val arraySize = freshName(" size" )
763
- val arrayBytes = freshName(" arrayBytes" )
764
-
765
- s """
766
- |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
767
- | $numElements,
768
- | ${elementType.defaultSize});
769
- |if ( $arraySize > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
770
- | throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
771
- | " bytes of data due to exceeding the limit " +
772
- | " ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH } bytes for UnsafeArrayData." +
773
- | " $additionalErrorMessage");
774
- |}
775
- |byte[] $arrayBytes = new byte[(int) $arraySize];
776
- |UnsafeArrayData $arrayName = new UnsafeArrayData();
777
- |Platform.putLong( $arrayBytes, ${Platform .BYTE_ARRAY_OFFSET }, $numElements);
778
- | $arrayName.pointTo( $arrayBytes, ${Platform .BYTE_ARRAY_OFFSET }, (int) $arraySize);
779
- """ .stripMargin
780
- }
781
-
782
- /**
783
- * Generates code creating a [[UnsafeArrayData ]]. The generated code executes
784
- * a provided fallback when the size of backing array would exceed the array size limit.
785
- * @param arrayName a name of the array to create
786
- * @param numElements a piece of code representing the number of elements the array should contain
787
- * @param elementSize a size of an element in bytes
788
- * @param bodyCode a function generating code that fills up the [[UnsafeArrayData ]]
789
- * and getting the backing array as a parameter
790
- * @param fallbackCode a piece of code executed when the array size limit is exceeded
791
- */
792
- def createUnsafeArrayWithFallback (
793
- arrayName : String ,
794
- numElements : String ,
795
- elementSize : Int ,
796
- bodyCode : String => String ,
797
- fallbackCode : String ): String = {
798
- val arraySize = freshName(" size" )
799
- val arrayBytes = freshName(" arrayBytes" )
800
- s """
801
- |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
802
- | $numElements,
803
- | $elementSize);
804
- |if ( $arraySize > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
805
- | $fallbackCode
806
- |} else {
807
- | final byte[] $arrayBytes = new byte[(int) $arraySize];
808
- | UnsafeArrayData $arrayName = new UnsafeArrayData();
809
- | Platform.putLong( $arrayBytes, ${Platform .BYTE_ARRAY_OFFSET }, $numElements);
810
- | $arrayName.pointTo( $arrayBytes, ${Platform .BYTE_ARRAY_OFFSET }, (int) $arraySize);
811
- | ${bodyCode(arrayBytes)}
812
- |}
813
- """ .stripMargin
814
- }
815
-
816
749
/**
817
750
* Generates code to do null safe execution, i.e. only execute the code when the input is not
818
751
* null by adding null check if necessary.
@@ -1490,6 +1423,59 @@ object CodeGenerator extends Logging {
1490
1423
}
1491
1424
}
1492
1425
1426
+ /**
1427
+ * Generates code creating a [[UnsafeArrayData ]] or [[GenericArrayData ]] based on
1428
+ * given parameters.
1429
+ *
1430
+ * @param arrayName name of the array to create
1431
+ * @param elementType data type of the elements in source array
1432
+ * @param numElements code representing the number of elements the array should contain
1433
+ * @param additionalErrorMessage string to include in the error message
1434
+ *
1435
+ * @return code representing the allocation of [[ArrayData ]]
1436
+ */
1437
+ def createArrayData (
1438
+ arrayName : String ,
1439
+ elementType : DataType ,
1440
+ numElements : String ,
1441
+ additionalErrorMessage : String ): String = {
1442
+ val elementSize = if (CodeGenerator .isPrimitiveType(elementType)) {
1443
+ elementType.defaultSize
1444
+ } else {
1445
+ - 1
1446
+ }
1447
+ s """
1448
+ |ArrayData $arrayName = ArrayData.allocateArrayData(
1449
+ | $elementSize, $numElements, " $additionalErrorMessage");
1450
+ """ .stripMargin
1451
+ }
1452
+
1453
+ /**
1454
+ * Generates assignment code for an [[ArrayData ]]
1455
+ *
1456
+ * @param dstArray name of the array to be assigned
1457
+ * @param elementType data type of the elements in destination and source arrays
1458
+ * @param srcArray name of the array to be read
1459
+ * @param needNullCheck value which shows whether a nullcheck is required for the returning
1460
+ * assignment
1461
+ * @param dstArrayIndex an index variable to access each element of destination array
1462
+ * @param srcArrayIndex an index variable to access each element of source array
1463
+ *
1464
+ * @return code representing an assignment to each element of the [[ArrayData ]], which requires
1465
+ * a pair of destination and source loop index variables
1466
+ */
1467
+ def createArrayAssignment (
1468
+ dstArray : String ,
1469
+ elementType : DataType ,
1470
+ srcArray : String ,
1471
+ dstArrayIndex : String ,
1472
+ srcArrayIndex : String ,
1473
+ needNullCheck : Boolean ): String = {
1474
+ CodeGenerator .setArrayElement(dstArray, elementType, dstArrayIndex,
1475
+ CodeGenerator .getValue(srcArray, elementType, srcArrayIndex),
1476
+ if (needNullCheck) Some (s " $srcArray.isNullAt( $srcArrayIndex) " ) else None )
1477
+ }
1478
+
1493
1479
/**
1494
1480
* Returns the code to update a column in Row for a given DataType.
1495
1481
*/
@@ -1558,6 +1544,34 @@ object CodeGenerator extends Logging {
1558
1544
}
1559
1545
}
1560
1546
1547
+ /**
1548
+ * Generates code of setter for an [[ArrayData ]].
1549
+ */
1550
+ def setArrayElement (
1551
+ array : String ,
1552
+ elementType : DataType ,
1553
+ i : String ,
1554
+ value : String ,
1555
+ isNull : Option [String ] = None ): String = {
1556
+ val isPrimitiveType = CodeGenerator .isPrimitiveType(elementType)
1557
+ val setFunc = if (isPrimitiveType) {
1558
+ s " set ${CodeGenerator .primitiveTypeName(elementType)}"
1559
+ } else {
1560
+ " update"
1561
+ }
1562
+ if (isNull.isDefined && isPrimitiveType) {
1563
+ s """
1564
+ |if ( ${isNull.get}) {
1565
+ | $array.setNullAt( $i);
1566
+ |} else {
1567
+ | $array. $setFunc( $i, $value);
1568
+ |}
1569
+ """ .stripMargin
1570
+ } else {
1571
+ s " $array. $setFunc( $i, $value); "
1572
+ }
1573
+ }
1574
+
1561
1575
/**
1562
1576
* Returns the specialized code to set a given value in a column vector for a given `DataType`
1563
1577
* that could potentially be nullable.
0 commit comments