@@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp}
21
21
22
22
import scala .collection .JavaConverters ._
23
23
import scala .reflect .ClassTag
24
+ import scala .reflect .runtime .universe .TypeTag
24
25
import scala .util .Random
25
26
26
27
import org .apache .spark .{SparkConf , SparkFunSuite }
27
28
import org .apache .spark .serializer .{JavaSerializer , KryoSerializer }
28
29
import org .apache .spark .sql .{RandomDataGenerator , Row }
29
- import org .apache .spark .sql .catalyst .InternalRow
30
+ import org .apache .spark .sql .catalyst .{ CatalystTypeConverters , InternalRow , JavaTypeInference , ScalaReflection }
30
31
import org .apache .spark .sql .catalyst .analysis .{ResolveTimeZone , SimpleAnalyzer , UnresolvedDeserializer }
31
32
import org .apache .spark .sql .catalyst .dsl .expressions ._
32
33
import org .apache .spark .sql .catalyst .encoders ._
@@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
501
502
InternalRow .fromSeq(Seq (Row (1 ))),
502
503
" java.lang.Integer is not a valid external type for schema of double" )
503
504
}
505
+
506
+ private def javaMapSerializerFor (
507
+ keyClazz : Class [_],
508
+ valueClazz : Class [_])(inputObject : Expression ): Expression = {
509
+
510
+ def kvSerializerFor (inputObject : Expression , clazz : Class [_]): Expression = clazz match {
511
+ case c if c == classOf [java.lang.Integer ] =>
512
+ Invoke (inputObject, " intValue" , IntegerType )
513
+ case c if c == classOf [java.lang.String ] =>
514
+ StaticInvoke (
515
+ classOf [UTF8String ],
516
+ StringType ,
517
+ " fromString" ,
518
+ inputObject :: Nil ,
519
+ returnNullable = false )
520
+ }
521
+
522
+ ExternalMapToCatalyst (
523
+ inputObject,
524
+ ObjectType (keyClazz),
525
+ kvSerializerFor(_, keyClazz),
526
+ keyNullable = true ,
527
+ ObjectType (valueClazz),
528
+ kvSerializerFor(_, valueClazz),
529
+ valueNullable = true
530
+ )
531
+ }
532
+
533
+ private def scalaMapSerializerFor [T : TypeTag , U : TypeTag ](inputObject : Expression ): Expression = {
534
+ import org .apache .spark .sql .catalyst .ScalaReflection ._
535
+
536
+ val curId = new java.util.concurrent.atomic.AtomicInteger ()
537
+
538
+ def kvSerializerFor [V : TypeTag ](inputObject : Expression ): Expression =
539
+ localTypeOf[V ].dealias match {
540
+ case t if t <:< localTypeOf[java.lang.Integer ] =>
541
+ Invoke (inputObject, " intValue" , IntegerType )
542
+ case t if t <:< localTypeOf[String ] =>
543
+ StaticInvoke (
544
+ classOf [UTF8String ],
545
+ StringType ,
546
+ " fromString" ,
547
+ inputObject :: Nil ,
548
+ returnNullable = false )
549
+ case _ =>
550
+ inputObject
551
+ }
552
+
553
+ ExternalMapToCatalyst (
554
+ inputObject,
555
+ dataTypeFor[T ],
556
+ kvSerializerFor[T ],
557
+ keyNullable = ! localTypeOf[T ].typeSymbol.asClass.isPrimitive,
558
+ dataTypeFor[U ],
559
+ kvSerializerFor[U ],
560
+ valueNullable = ! localTypeOf[U ].typeSymbol.asClass.isPrimitive
561
+ )
562
+ }
563
+
564
+ test(" SPARK-23589 ExternalMapToCatalyst should support interpreted execution" ) {
565
+ // Simple test
566
+ val scalaMap = scala.collection.Map [Int , String ](0 -> " v0" , 1 -> " v1" , 2 -> null , 3 -> " v3" )
567
+ val javaMap = new java.util.HashMap [java.lang.Integer , java.lang.String ]() {
568
+ {
569
+ put(0 , " v0" )
570
+ put(1 , " v1" )
571
+ put(2 , null )
572
+ put(3 , " v3" )
573
+ }
574
+ }
575
+ val expected = CatalystTypeConverters .convertToCatalyst(scalaMap)
576
+
577
+ // Java Map
578
+ val serializer1 = javaMapSerializerFor(classOf [java.lang.Integer ], classOf [java.lang.String ])(
579
+ Literal .fromObject(javaMap))
580
+ checkEvaluation(serializer1, expected)
581
+
582
+ // Scala Map
583
+ val serializer2 = scalaMapSerializerFor[Int , String ](Literal .fromObject(scalaMap))
584
+ checkEvaluation(serializer2, expected)
585
+
586
+ // NULL key test
587
+ val scalaMapHasNullKey = scala.collection.Map [java.lang.Integer , String ](
588
+ null .asInstanceOf [java.lang.Integer ] -> " v0" , new java.lang.Integer (1 ) -> " v1" )
589
+ val javaMapHasNullKey = new java.util.HashMap [java.lang.Integer , java.lang.String ]() {
590
+ {
591
+ put(null , " v0" )
592
+ put(1 , " v1" )
593
+ }
594
+ }
595
+
596
+ // Java Map
597
+ val serializer3 =
598
+ javaMapSerializerFor(classOf [java.lang.Integer ], classOf [java.lang.String ])(
599
+ Literal .fromObject(javaMapHasNullKey))
600
+ checkExceptionInExpression[RuntimeException ](
601
+ serializer3, EmptyRow , " Cannot use null as map key!" )
602
+
603
+ // Scala Map
604
+ val serializer4 = scalaMapSerializerFor[java.lang.Integer , String ](
605
+ Literal .fromObject(scalaMapHasNullKey))
606
+
607
+ checkExceptionInExpression[RuntimeException ](
608
+ serializer4, EmptyRow , " Cannot use null as map key!" )
609
+ }
504
610
}
505
611
506
612
class TestBean extends Serializable {
0 commit comments