@@ -39,6 +39,24 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
3939 """ [{"id":4,"legs":[]}]""" ::
4040 """ [{"id":5,"legs":null}]""" :: Nil
4141
test(" IsPrimitive should work as expected") {
  // Verifies SparkUtils.isPrimitive against a fixed list of Spark data types:
  // every type in `expectedPrimitive` must be reported as primitive, and every
  // type in `expectedNonPrimitive` (array, struct, map) must not be.
  // NOTE(review): string literals are kept exactly as found in the source,
  // including their leading space - confirm this is intended.
  val expectedPrimitive = Seq(
    BooleanType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    DecimalType(10, 2),
    StringType,
    BinaryType,
    DateType,
    TimestampType
  )

  val expectedNonPrimitive = Seq(
    ArrayType(StringType),
    StructType(Seq(StructField(" a", StringType))),
    MapType(StringType, StringType)
  )

  // The clue names the offending type, since the asserted expression alone
  // would no longer identify it.
  expectedPrimitive.foreach { dataType =>
    assert(SparkUtils.isPrimitive(dataType), s"$dataType should be reported as primitive")
  }

  expectedNonPrimitive.foreach { dataType =>
    assert(!SparkUtils.isPrimitive(dataType), s"$dataType should not be reported as primitive")
  }
}
59+
4260 test(" Test schema flattening of multiple nested structure" ) {
4361 val expectedOrigSchema =
4462 """ root
@@ -626,6 +644,73 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
626644 assert(newDf.schema.fields.head.metadata.getLong(" maxLength" ) == 120 )
627645 }
628646
test(" copyMetadata should copy primitive data types when it is enabled") {
  // Calls SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true) and checks
  // the data types and metadata of the resulting fields, including fields nested in a
  // struct and in an array of structs.
  // NOTE(review): string literals are kept exactly as found in the source,
  // including their leading space - confirm this is intended.

  // Local builders for the two kinds of metadata used by the source schema.
  def commentMeta(text: String) = new MetadataBuilder().putString(" comment", text).build()
  def maxLengthMeta(length: Long) = new MetadataBuilder().putLong(" maxLength", length).build()

  // Source schema: every leaf field carries metadata.
  val schemaFrom = StructType(Seq(
    StructField(" int_field1", IntegerType, nullable = true, metadata = commentMeta(" Test1")),
    StructField(" string_field", StringType, nullable = true, metadata = maxLengthMeta(120)),
    StructField(" int_field2", StructType(Seq(
      StructField(" int_field20", IntegerType, nullable = true, metadata = commentMeta(" Test20"))
    )), nullable = true),
    StructField(" struct_field2", StructType(Seq(
      StructField(" int_field3", IntegerType, nullable = true, metadata = commentMeta(" Test3"))
    )), nullable = true),
    StructField(" array_string", ArrayType(StringType), nullable = true, metadata = maxLengthMeta(60)),
    StructField(" array_struct", ArrayType(StructType(Seq(
      StructField(" int_field4", IntegerType, nullable = true, metadata = commentMeta(" Test4"))
    ))), nullable = true)
  ))

  // Target schema: same field names, no metadata, and different leaf data types.
  // " int_field2" is a struct in the source but a plain integer here.
  val schemaTo = StructType(Seq(
    StructField(" int_field1", BooleanType, nullable = true),
    StructField(" string_field", IntegerType, nullable = true),
    StructField(" int_field2", IntegerType, nullable = true),
    StructField(" struct_field2", StructType(Seq(
      StructField(" int_field3", BooleanType, nullable = true)
    )), nullable = true),
    StructField(" array_string", ArrayType(IntegerType), nullable = true),
    StructField(" array_struct", ArrayType(StructType(Seq(
      StructField(" int_field4", StringType, nullable = true)
    ))), nullable = true)
  ))

  val resultFields = SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true).fields

  // Top-level fields: data types.
  assert(resultFields.head.dataType == IntegerType)
  assert(resultFields(1).dataType == StringType)
  // The source side of this field is a struct; the result is expected to stay IntegerType.
  assert(resultFields(2).dataType == IntegerType)

  // Top-level fields: metadata.
  assert(resultFields.head.metadata.getString(" comment") == " Test1")
  assert(resultFields(1).metadata.getLong(" maxLength") == 120)

  // Struct field: the nested leaf must have both its data type and its metadata copied.
  resultFields(3).dataType match {
    case nested: StructType =>
      val leaf = nested.fields.head
      assert(leaf.dataType == IntegerType)
      assert(leaf.metadata.getString(" comment") == " Test3")
    case other =>
      fail(s"Expected a StructType but got $other")
  }

  // Array of primitives: the element type is copied; metadata sits on the array field itself.
  resultFields(4).dataType match {
    case array: ArrayType =>
      assert(array.elementType == StringType)
    case other =>
      fail(s"Expected an ArrayType but got $other")
  }
  assert(resultFields(4).metadata.getLong(" maxLength") == 60)

  // Array of structs: the leaf inside the element struct must have type and metadata copied.
  resultFields(5).dataType match {
    case ArrayType(element: StructType, _) =>
      val leaf = element.fields.head
      assert(leaf.dataType == IntegerType)
      assert(leaf.metadata.getString(" comment") == " Test4")
    case other =>
      fail(s"Expected an ArrayType of StructType but got $other")
  }
}
713+
629714 test(" copyMetadata should retain metadata on conflicts by default" ) {
630715 val df1 = List (1 , 2 , 3 ).toDF(" col1" )
631716 val df2 = List (1 , 2 , 3 ).toDF(" col1" )
0 commit comments