@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
25
25
import org .apache .spark .sql .catalyst .expressions .ArraySortLike .NullOrder
26
26
import org .apache .spark .sql .catalyst .expressions .codegen ._
27
27
import org .apache .spark .sql .catalyst .expressions .codegen .Block ._
28
- import org .apache .spark .sql .catalyst .util .{ ArrayData , GenericArrayData , MapData , TypeUtils }
28
+ import org .apache .spark .sql .catalyst .util ._
29
29
import org .apache .spark .sql .internal .SQLConf
30
30
import org .apache .spark .sql .types ._
31
31
import org .apache .spark .unsafe .Platform
@@ -475,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
475
475
override def prettyName : String = " map_entries"
476
476
}
477
477
478
+ /**
479
+ * Returns a map created from the given array of entries.
480
+ */
481
+ @ ExpressionDescription (
482
+ usage = " _FUNC_(arrayOfEntries) - Returns a map created from the given array of entries." ,
483
+ examples = """
484
+ Examples:
485
+ > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
486
+ {1:"a",2:"b"}
487
+ """ ,
488
+ since = " 2.4.0" )
489
+ case class MapFromEntries (child : Expression ) extends UnaryExpression {
490
+
491
+ @ transient
492
+ private lazy val dataTypeDetails : Option [(MapType , Boolean , Boolean )] = child.dataType match {
493
+ case ArrayType (
494
+ StructType (Array (
495
+ StructField (_, keyType, keyNullable, _),
496
+ StructField (_, valueType, valueNullable, _))),
497
+ containsNull) => Some ((MapType (keyType, valueType, valueNullable), keyNullable, containsNull))
498
+ case _ => None
499
+ }
500
+
501
+ private def nullEntries : Boolean = dataTypeDetails.get._3
502
+
503
+ override def nullable : Boolean = child.nullable || nullEntries
504
+
505
+ override def dataType : MapType = dataTypeDetails.get._1
506
+
507
+ override def checkInputDataTypes (): TypeCheckResult = dataTypeDetails match {
508
+ case Some (_) => TypeCheckResult .TypeCheckSuccess
509
+ case None => TypeCheckResult .TypeCheckFailure (s " ' ${child.sql}' is of " +
510
+ s " ${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs. " )
511
+ }
512
+
513
+ override protected def nullSafeEval (input : Any ): Any = {
514
+ val arrayData = input.asInstanceOf [ArrayData ]
515
+ val numEntries = arrayData.numElements()
516
+ var i = 0
517
+ if (nullEntries) {
518
+ while (i < numEntries) {
519
+ if (arrayData.isNullAt(i)) return null
520
+ i += 1
521
+ }
522
+ }
523
+ val keyArray = new Array [AnyRef ](numEntries)
524
+ val valueArray = new Array [AnyRef ](numEntries)
525
+ i = 0
526
+ while (i < numEntries) {
527
+ val entry = arrayData.getStruct(i, 2 )
528
+ val key = entry.get(0 , dataType.keyType)
529
+ if (key == null ) {
530
+ throw new RuntimeException (" The first field from a struct (key) can't be null." )
531
+ }
532
+ keyArray.update(i, key)
533
+ val value = entry.get(1 , dataType.valueType)
534
+ valueArray.update(i, value)
535
+ i += 1
536
+ }
537
+ ArrayBasedMapData (keyArray, valueArray)
538
+ }
539
+
540
+ override protected def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
541
+ nullSafeCodeGen(ctx, ev, c => {
542
+ val numEntries = ctx.freshName(" numEntries" )
543
+ val isKeyPrimitive = CodeGenerator .isPrimitiveType(dataType.keyType)
544
+ val isValuePrimitive = CodeGenerator .isPrimitiveType(dataType.valueType)
545
+ val code = if (isKeyPrimitive && isValuePrimitive) {
546
+ genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
547
+ } else {
548
+ genCodeForAnyElements(ctx, c, ev.value, numEntries)
549
+ }
550
+ ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
551
+ s """
552
+ |final int $numEntries = $c.numElements();
553
+ | $code
554
+ """ .stripMargin
555
+ }
556
+ })
557
+ }
558
+
559
+ private def genCodeForAssignmentLoop (
560
+ ctx : CodegenContext ,
561
+ childVariable : String ,
562
+ mapData : String ,
563
+ numEntries : String ,
564
+ keyAssignment : (String , String ) => String ,
565
+ valueAssignment : (String , String ) => String ): String = {
566
+ val entry = ctx.freshName(" entry" )
567
+ val i = ctx.freshName(" idx" )
568
+
569
+ val nullKeyCheck = if (dataTypeDetails.get._2) {
570
+ s """
571
+ |if ( $entry.isNullAt(0)) {
572
+ | throw new RuntimeException("The first field from a struct (key) can't be null.");
573
+ |}
574
+ """ .stripMargin
575
+ } else {
576
+ " "
577
+ }
578
+
579
+ s """
580
+ |for (int $i = 0; $i < $numEntries; $i++) {
581
+ | InternalRow $entry = $childVariable.getStruct( $i, 2);
582
+ | $nullKeyCheck
583
+ | ${keyAssignment(CodeGenerator .getValue(entry, dataType.keyType, " 0" ), i)}
584
+ | ${valueAssignment(entry, i)}
585
+ |}
586
+ """ .stripMargin
587
+ }
588
+
589
+ private def genCodeForPrimitiveElements (
590
+ ctx : CodegenContext ,
591
+ childVariable : String ,
592
+ mapData : String ,
593
+ numEntries : String ): String = {
594
+ val byteArraySize = ctx.freshName(" byteArraySize" )
595
+ val keySectionSize = ctx.freshName(" keySectionSize" )
596
+ val valueSectionSize = ctx.freshName(" valueSectionSize" )
597
+ val data = ctx.freshName(" byteArray" )
598
+ val unsafeMapData = ctx.freshName(" unsafeMapData" )
599
+ val keyArrayData = ctx.freshName(" keyArrayData" )
600
+ val valueArrayData = ctx.freshName(" valueArrayData" )
601
+
602
+ val baseOffset = Platform .BYTE_ARRAY_OFFSET
603
+ val keySize = dataType.keyType.defaultSize
604
+ val valueSize = dataType.valueType.defaultSize
605
+ val kByteSize = s " UnsafeArrayData.calculateSizeOfUnderlyingByteArray( $numEntries, $keySize) "
606
+ val vByteSize = s " UnsafeArrayData.calculateSizeOfUnderlyingByteArray( $numEntries, $valueSize) "
607
+ val keyTypeName = CodeGenerator .primitiveTypeName(dataType.keyType)
608
+ val valueTypeName = CodeGenerator .primitiveTypeName(dataType.valueType)
609
+
610
+ val keyAssignment = (key : String , idx : String ) => s " $keyArrayData.set $keyTypeName( $idx, $key); "
611
+ val valueAssignment = (entry : String , idx : String ) => {
612
+ val value = CodeGenerator .getValue(entry, dataType.valueType, " 1" )
613
+ val valueNullUnsafeAssignment = s " $valueArrayData.set $valueTypeName( $idx, $value); "
614
+ if (dataType.valueContainsNull) {
615
+ s """
616
+ |if ( $entry.isNullAt(1)) {
617
+ | $valueArrayData.setNullAt( $idx);
618
+ |} else {
619
+ | $valueNullUnsafeAssignment
620
+ |}
621
+ """ .stripMargin
622
+ } else {
623
+ valueNullUnsafeAssignment
624
+ }
625
+ }
626
+ val assignmentLoop = genCodeForAssignmentLoop(
627
+ ctx,
628
+ childVariable,
629
+ mapData,
630
+ numEntries,
631
+ keyAssignment,
632
+ valueAssignment
633
+ )
634
+
635
+ s """
636
+ |final long $keySectionSize = $kByteSize;
637
+ |final long $valueSectionSize = $vByteSize;
638
+ |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
639
+ |if ( $byteArraySize > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
640
+ | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
641
+ |} else {
642
+ | final byte[] $data = new byte[(int) $byteArraySize];
643
+ | UnsafeMapData $unsafeMapData = new UnsafeMapData();
644
+ | Platform.putLong( $data, $baseOffset, $keySectionSize);
645
+ | Platform.putLong( $data, ${baseOffset + 8 }, $numEntries);
646
+ | Platform.putLong( $data, ${baseOffset + 8 } + $keySectionSize, $numEntries);
647
+ | $unsafeMapData.pointTo( $data, $baseOffset, (int) $byteArraySize);
648
+ | ArrayData $keyArrayData = $unsafeMapData.keyArray();
649
+ | ArrayData $valueArrayData = $unsafeMapData.valueArray();
650
+ | $assignmentLoop
651
+ | $mapData = $unsafeMapData;
652
+ |}
653
+ """ .stripMargin
654
+ }
655
+
656
+ private def genCodeForAnyElements (
657
+ ctx : CodegenContext ,
658
+ childVariable : String ,
659
+ mapData : String ,
660
+ numEntries : String ): String = {
661
+ val keys = ctx.freshName(" keys" )
662
+ val values = ctx.freshName(" values" )
663
+ val mapDataClass = classOf [ArrayBasedMapData ].getName()
664
+
665
+ val isValuePrimitive = CodeGenerator .isPrimitiveType(dataType.valueType)
666
+ val valueAssignment = (entry : String , idx : String ) => {
667
+ val value = CodeGenerator .getValue(entry, dataType.valueType, " 1" )
668
+ if (dataType.valueContainsNull && isValuePrimitive) {
669
+ s " $values[ $idx] = $entry.isNullAt(1) ? null : (Object) $value; "
670
+ } else {
671
+ s " $values[ $idx] = $value; "
672
+ }
673
+ }
674
+ val keyAssignment = (key : String , idx : String ) => s " $keys[ $idx] = $key; "
675
+ val assignmentLoop = genCodeForAssignmentLoop(
676
+ ctx,
677
+ childVariable,
678
+ mapData,
679
+ numEntries,
680
+ keyAssignment,
681
+ valueAssignment)
682
+
683
+ s """
684
+ |final Object[] $keys = new Object[ $numEntries];
685
+ |final Object[] $values = new Object[ $numEntries];
686
+ | $assignmentLoop
687
+ | $mapData = $mapDataClass.apply( $keys, $values);
688
+ """ .stripMargin
689
+ }
690
+
691
+ override def prettyName : String = " map_from_entries"
692
+ }
693
+
694
+
478
695
/**
479
696
* Common base class for [[SortArray ]] and [[ArraySort ]].
480
697
*/
@@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
1990
2207
} else {
1991
2208
genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
1992
2209
}
1993
- if (childDataType.containsNull) nullElementsProtection(ev , c, code) else code
2210
+ ctx.nullArrayElementsSaveExec (childDataType.containsNull, ev.isNull , c)( code)
1994
2211
})
1995
2212
}
1996
2213
1997
- private def nullElementsProtection (
1998
- ev : ExprCode ,
1999
- childVariableName : String ,
2000
- coreLogic : String ): String = {
2001
- s """
2002
- |for (int z = 0; ! ${ev.isNull} && z < $childVariableName.numElements(); z++) {
2003
- | ${ev.isNull} |= $childVariableName.isNullAt(z);
2004
- |}
2005
- |if (! ${ev.isNull}) {
2006
- | $coreLogic
2007
- |}
2008
- """ .stripMargin
2009
- }
2010
-
2011
2214
private def genCodeForNumberOfElements (
2012
2215
ctx : CodegenContext ,
2013
2216
childVariableName : String ) : (String , String ) = {
0 commit comments