@@ -5525,32 +5525,81 @@ def data(self):
            .withColumn("v", explode(col('vs'))).drop('vs')

    def test_supported_types(self):
-        from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
-        df = self.data.withColumn("arr", array(col("id")))
+        from decimal import Decimal
+        from distutils.version import LooseVersion
+        import pyarrow as pa
+        from pyspark.sql.functions import pandas_udf, PandasUDFType

-        # Different forms of group map pandas UDF, results of these are the same
+        values = [
+            1, 2, 3,
+            4, 5, 1.1,
+            2.2, Decimal(1.123),
+            [1, 2, 2], True, 'hello'
+        ]
+        output_fields = [
+            ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
+            ('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
+            ('double', DoubleType()), ('decim', DecimalType(10, 3)),
+            ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
+        ]

-        output_schema = StructType(
-            [StructField('id', LongType()),
-             StructField('v', IntegerType()),
-             StructField('arr', ArrayType(LongType())),
-             StructField('v1', DoubleType()),
-             StructField('v2', LongType())])
+        # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
+            values.append(bytearray([0x01, 0x02]))
+            output_fields.append(('bin', BinaryType()))

+        output_schema = StructType([StructField(*x) for x in output_fields])
+        df = self.spark.createDataFrame([values], schema=output_schema)
+
+        # Different forms of group map pandas UDF, results of these are the same
        udf1 = pandas_udf(
-            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda pdf: pdf.assign(
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        udf2 = pandas_udf(
-            lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda _, pdf: pdf.assign(
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        udf3 = pandas_udf(
-            lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            lambda key, pdf: pdf.assign(
+                id=key[0],
+                byte=pdf.byte * 2,
+                short=pdf.short * 2,
+                int=pdf.int * 2,
+                long=pdf.long * 2,
+                float=pdf.float * 2,
+                double=pdf.double * 2,
+                decim=pdf.decim * 2,
+                bool=False if pdf.bool else True,
+                str=pdf.str + 'there',
+                array=pdf.array,
+            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )
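Note: this hunk stops at the UDF definitions; the assertions that exercise udf1, udf2, and udf3 are outside the shown context. A minimal sketch of the usual apply-and-compare pattern for grouped map UDFs in this test file, assuming a pandas-equality helper such as assertPandasEqual (not part of this hunk):

    # Apply the grouped map UDF per group, then collect the Spark result back to pandas.
    result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
    # Compute the same transformation purely in pandas as the expected output.
    expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
    self.assertPandasEqual(expected1, result1)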
@@ -5714,24 +5763,26 @@ def test_wrong_args(self):
                    pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))

    def test_unsupported_types(self):
+        from distutils.version import LooseVersion
+        import pyarrow as pa
        from pyspark.sql.functions import pandas_udf, PandasUDFType
-        schema = StructType(
-            [StructField("id", LongType(), True),
-             StructField("map", MapType(StringType(), IntegerType()), True)])
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    NotImplementedError,
-                    'Invalid returnType.*grouped map Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

-        schema = StructType(
-            [StructField("id", LongType(), True),
-             StructField("arr_ts", ArrayType(TimestampType()), True)])
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    NotImplementedError,
-                    'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
-                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+        common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
+        unsupported_types = [
+            StructField('map', MapType(StringType(), IntegerType())),
+            StructField('arr_ts', ArrayType(TimestampType())),
+            StructField('null', NullType()),
+        ]
+
+        # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            unsupported_types.append(StructField('bin', BinaryType()))
+
+        for unsupported_type in unsupported_types:
+            schema = StructType([StructField('id', LongType(), True), unsupported_type])
+            with QuietTest(self.sc):
+                with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
+                    pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

    # Regression test for SPARK-23314
    def test_timestamp_dst(self):
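For context, a caller-side sketch of what the loop above asserts: declaring a grouped map pandas UDF whose return schema contains an unsupported field type (MapType, for example) raises NotImplementedError at definition time, with a message matched by common_err_msg. This is illustrative only and not part of the diff:

    schema = StructType([
        StructField('id', LongType()),
        StructField('map', MapType(StringType(), IntegerType())),
    ])
    try:
        pandas_udf(lambda pdf: pdf, schema, PandasUDFType.GROUPED_MAP)
    except NotImplementedError as e:
        # Expected to match 'Invalid returnType.*grouped map Pandas UDF.*'
        print(e)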