@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
19
19
20
20
import java .lang .reflect .Modifier
21
21
22
+ import scala .collection .JavaConverters ._
22
23
import scala .collection .mutable .Builder
23
24
import scala .language .existentials
24
25
import scala .reflect .ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
501
502
value : String ,
502
503
isNull : String ,
503
504
dataType : DataType ,
504
- nullable : Boolean = true ) extends LeafExpression
505
- with Unevaluable with NonSQLExpression {
505
+ nullable : Boolean = true ) extends LeafExpression with NonSQLExpression {
506
+
507
/**
 * Interpreted evaluation of a `LambdaVariable`.
 *
 * The value is always read from ordinal 0 of `input`, using this variable's `dataType`.
 * The input row is asserted to contain exactly one field.
 */
override def eval(input: InternalRow): Any = {
  val fieldCount = input.numFields
  assert(fieldCount == 1,
    " The input row of interpreted LambdaVariable should have only 1 field.")
  input.get(0, dataType)
}
506
513
507
514
/**
 * Builds the `ExprCode` for this variable directly from its own fields: a fixed `code`
 * string, `value` passed through as-is, and a null term that is `isNull` when the
 * variable is nullable and a constant literal otherwise.
 */
override def genCode(ctx: CodegenContext): ExprCode = {
  val nullTerm = if (nullable) isNull else " false"
  ExprCode(code = " ", value = value, isNull = nullTerm)
}
517
+
518
// `genCode` is overridden in this class, so this method is not expected to be called.
// It is defined only to make `LambdaVariable` non-abstract, and hands `ev` back unchanged.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  ev
}
510
521
}
511
522
512
523
/**
@@ -599,8 +610,92 @@ case class MapObjects private(
599
610
600
611
// Child expressions, in order: the lambda function first, then the input collection.
override def children: Seq[Expression] = List(lambdaFunction, inputData)
601
612
602
- override def eval (input : InternalRow ): Any =
603
- throw new UnsupportedOperationException (" Only code-generated evaluation is supported" )
613
// Values of a UserDefinedType are actually stored using the data type of its `sqlType`,
// so that is the type MapObjects has to operate on. Any other data type is used as-is.
private lazy val inputDataType = inputData.dataType match {
  case udt: UserDefinedType[_] => udt.sqlType
  case other => other
}
619
+
620
/**
 * Applies `lambdaFunction` to every element of `inputCollection`, returning the results
 * as an iterator.
 *
 * A single one-field row is allocated up front and reused for all elements: each element
 * is written to ordinal 0 right before `lambdaFunction` is evaluated against that row.
 */
private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
  val holder = new GenericInternalRow(1)
  for (element <- inputCollection.toIterator) yield {
    holder.update(0, element)
    lambdaFunction.eval(holder)
  }
}
627
+
628
/**
 * Converter from the evaluated input collection to a Scala `Seq`, selected by
 * `inputDataType`:
 *  - `ObjectType` of a Scala `Seq` class: cast to `Seq`;
 *  - `ObjectType` of an array class: cast to `Array` and convert with `toSeq`;
 *  - `ObjectType` of a `java.util.List` class: cast and wrap with `asScala`;
 *  - `ObjectType` of exactly `Object`: decided per value at runtime (array or `Seq`);
 *  - `ArrayType`: cast to `ArrayData` and expose its backing `array`.
 */
private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
  case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
    (collection: Any) => collection.asInstanceOf[Seq[_]]
  case ObjectType(cls) if cls.isArray =>
    (collection: Any) => collection.asInstanceOf[Array[_]].toSeq
  case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
    (collection: Any) => collection.asInstanceOf[java.util.List[_]].asScala
  case ObjectType(cls) if cls == classOf[Object] =>
    (collection: Any) => {
      // The static type says nothing here, so inspect the runtime class of the value.
      val isArrayValue = collection.getClass.isArray
      if (isArrayValue) {
        collection.asInstanceOf[Array[_]].toSeq
      } else {
        collection.asInstanceOf[Seq[_]]
      }
    }
  case ArrayType(_, _) =>
    (collection: Any) => collection.asInstanceOf[ArrayData].array
}
646
+
647
/**
 * Function that maps `lambdaFunction` over a converted input sequence and packs the
 * results into the collection requested by `customCollectionCls`:
 *  - a Scala `Seq` class: results are collected with `toSeq`;
 *  - a Scala `Set` class: results are collected with `toSet`;
 *  - a `java.util.List` class: either a Java view of the results (for the non-concrete
 *    list types) or a new instance of the requested class, created reflectively;
 *  - `None`: a `GenericArrayData` over the results.
 *
 * Throws a `RuntimeException` for any other class, and for a concrete `java.util.List`
 * class that exposes neither a public `(Int)` constructor nor a public no-arg constructor.
 */
private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
  case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
    // Scala sequence
    executeFuncOnCollection(_).toSeq
  case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
    // Scala set
    executeFuncOnCollection(_).toSet
  case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
    // Java list
    if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
        cls == classOf[java.util.AbstractSequentialList[_]]) {
      // Specifying non concrete implementations of `java.util.List`
      executeFuncOnCollection(_).toSeq.asJava
    } else {
      val constructors = cls.getConstructors()
      val intParamConstructor = constructors.find { constructor =>
        constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
      }
      val noParamConstructor = constructors.find { constructor =>
        constructor.getParameterCount == 0
      }

      // Prefer the constructor taking an initial size, fall back to the no-arg one, and
      // fail with a descriptive error instead of calling `.get` on an empty Option.
      val constructor: Int => Any = (intParamConstructor, noParamConstructor) match {
        case (Some(intConstructor), _) =>
          (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
        case (None, Some(emptyConstructor)) =>
          (_: Int) => emptyConstructor.newInstance()
        case (None, None) =>
          throw new RuntimeException(s"class `${cls.getName}` has neither a public " +
            "constructor taking a single Int nor a public no-arg constructor, so it is not " +
            "supported by `MapObjects` as resulting collection.")
      }

      // Specifying concrete implementations of `java.util.List`
      (inputs) => {
        val results = executeFuncOnCollection(inputs)
        val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
        results.foreach(builder.add(_))
        builder
      }
    }
  case None =>
    // array
    x => new GenericArrayData(executeFuncOnCollection(x).toArray)
  case Some(cls) =>
    throw new RuntimeException(s" class ` ${cls.getName}` is not supported by `MapObjects` as " +
      " resulting collection.")
}
690
+
691
/**
 * Interpreted evaluation of `MapObjects`.
 *
 * Evaluates `inputData` against `input`. A null collection yields null; otherwise the
 * collection is converted with `convertToSeq` and passed to `mapElements`.
 */
override def eval(input: InternalRow): Any = {
  val evaluated = inputData.eval(input)
  if (evaluated == null) {
    null
  } else {
    mapElements(convertToSeq(evaluated))
  }
}
604
699
605
700
override def dataType : DataType =
606
701
customCollectionCls.map(ObjectType .apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
647
742
case _ => " "
648
743
}
649
744
650
- // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
651
- // When we want to apply MapObjects on it, we have to use it.
652
- val inputDataType = inputData.dataType match {
653
- case p : PythonUserDefinedType => p.sqlType
654
- case _ => inputData.dataType
655
- }
656
-
657
745
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
658
746
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
659
747
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
0 commit comments