Skip to content

Commit 30f5d0f

Browse files
alex7c4, HyukjinKwon, Alexander Koryagin
committed
[SPARK-23401][PYTHON][TESTS] Add more data types for PandasUDFTests
## What changes were proposed in this pull request?

Add more data types for Pandas UDF tests for PySpark SQL.

## How was this patch tested?

Manual tests.

Closes apache#22568 from AlexanderKoryagin/new_types_for_pandas_udf_tests.

Lead-authored-by: Aleksandr Koriagin <[email protected]>
Co-authored-by: hyukjinkwon <[email protected]>
Co-authored-by: Alexander Koryagin <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
1 parent 21f0b73 commit 30f5d0f

File tree

1 file changed

+79
-28
lines changed

1 file changed

+79
-28
lines changed

python/pyspark/sql/tests.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5525,32 +5525,81 @@ def data(self):
55255525
.withColumn("v", explode(col('vs'))).drop('vs')
55265526

55275527
def test_supported_types(self):
5528-
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
5529-
df = self.data.withColumn("arr", array(col("id")))
5528+
from decimal import Decimal
5529+
from distutils.version import LooseVersion
5530+
import pyarrow as pa
5531+
from pyspark.sql.functions import pandas_udf, PandasUDFType
55305532

5531-
# Different forms of group map pandas UDF, results of these are the same
5533+
values = [
5534+
1, 2, 3,
5535+
4, 5, 1.1,
5536+
2.2, Decimal(1.123),
5537+
[1, 2, 2], True, 'hello'
5538+
]
5539+
output_fields = [
5540+
('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
5541+
('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
5542+
('double', DoubleType()), ('decim', DecimalType(10, 3)),
5543+
('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
5544+
]
55325545

5533-
output_schema = StructType(
5534-
[StructField('id', LongType()),
5535-
StructField('v', IntegerType()),
5536-
StructField('arr', ArrayType(LongType())),
5537-
StructField('v1', DoubleType()),
5538-
StructField('v2', LongType())])
5546+
# TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
5547+
if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
5548+
values.append(bytearray([0x01, 0x02]))
5549+
output_fields.append(('bin', BinaryType()))
55395550

5551+
output_schema = StructType([StructField(*x) for x in output_fields])
5552+
df = self.spark.createDataFrame([values], schema=output_schema)
5553+
5554+
# Different forms of group map pandas UDF, results of these are the same
55405555
udf1 = pandas_udf(
5541-
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
5556+
lambda pdf: pdf.assign(
5557+
byte=pdf.byte * 2,
5558+
short=pdf.short * 2,
5559+
int=pdf.int * 2,
5560+
long=pdf.long * 2,
5561+
float=pdf.float * 2,
5562+
double=pdf.double * 2,
5563+
decim=pdf.decim * 2,
5564+
bool=False if pdf.bool else True,
5565+
str=pdf.str + 'there',
5566+
array=pdf.array,
5567+
),
55425568
output_schema,
55435569
PandasUDFType.GROUPED_MAP
55445570
)
55455571

55465572
udf2 = pandas_udf(
5547-
lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
5573+
lambda _, pdf: pdf.assign(
5574+
byte=pdf.byte * 2,
5575+
short=pdf.short * 2,
5576+
int=pdf.int * 2,
5577+
long=pdf.long * 2,
5578+
float=pdf.float * 2,
5579+
double=pdf.double * 2,
5580+
decim=pdf.decim * 2,
5581+
bool=False if pdf.bool else True,
5582+
str=pdf.str + 'there',
5583+
array=pdf.array,
5584+
),
55485585
output_schema,
55495586
PandasUDFType.GROUPED_MAP
55505587
)
55515588

55525589
udf3 = pandas_udf(
5553-
lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
5590+
lambda key, pdf: pdf.assign(
5591+
id=key[0],
5592+
byte=pdf.byte * 2,
5593+
short=pdf.short * 2,
5594+
int=pdf.int * 2,
5595+
long=pdf.long * 2,
5596+
float=pdf.float * 2,
5597+
double=pdf.double * 2,
5598+
decim=pdf.decim * 2,
5599+
bool=False if pdf.bool else True,
5600+
str=pdf.str + 'there',
5601+
array=pdf.array,
5602+
),
55545603
output_schema,
55555604
PandasUDFType.GROUPED_MAP
55565605
)
@@ -5714,24 +5763,26 @@ def test_wrong_args(self):
57145763
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
57155764

57165765
def test_unsupported_types(self):
    """Grouped-map pandas UDFs must reject return schemas containing unsupported types.

    For each unsupported field type, declaring a GROUPED_MAP pandas_udf with a
    schema containing that field must raise NotImplementedError.
    """
    from distutils.version import LooseVersion
    import pyarrow as pa
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
    # Pair each unsupported field with the type-specific fragment expected in the
    # error message. The pre-refactor tests asserted 'MapType' and
    # 'ArrayType.*TimestampType' explicitly; keep those so a regression in the
    # message's type reporting is still caught. NullType/BinaryType fragments are
    # left empty (only the common message is asserted) since their exact wording
    # is not established here.
    unsupported_types = [
        (StructField('map', MapType(StringType(), IntegerType())), 'MapType'),
        (StructField('arr_ts', ArrayType(TimestampType())), 'ArrayType.*TimestampType'),
        (StructField('null', NullType()), ''),
    ]

    # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
    if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
        unsupported_types.append((StructField('bin', BinaryType()), ''))

    for unsupported_type, type_err_msg in unsupported_types:
        schema = StructType([StructField('id', LongType(), True), unsupported_type])
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(NotImplementedError,
                                         common_err_msg + type_err_msg):
                pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
57355786

57365787
# Regression test for SPARK-23314
57375788
def test_timestamp_dst(self):

0 commit comments

Comments
 (0)