Skip to content

Commit a13187c

Browse files
zhengruifeng authored and dongjoon-hyun committed
[SPARK-53676][PYTHON][TESTS] Skip UDF type check with numpy 1.x
### What changes were proposed in this pull request?

Skip UDF type check in minimum dependency envs

### Why are the changes needed?

The two scheduled jobs are still failing after fix #52247 due to different versions of numpy/pandas/pyarrow/etc. Actually, we don't need to run this test in every env, because the result depends on the combination of versions of numpy/pandas/pyarrow/etc.

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

PR builder with
```
default: '{"PYSPARK_IMAGE_TO_TEST": "python-minimum", "PYTHON_TO_TEST": "python3.10", "ENV_NAME": "PYTHON_MINIMUM"}'
```
https://github.com/zhengruifeng/spark/actions/runs/17940562561/job/51016142810

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52419 from zhengruifeng/restore_old_dep_test.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent b6993cb commit a13187c

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,29 @@
3939
StructType,
4040
TimestampType,
4141
)
42+
from pyspark.loose_version import LooseVersion
43+
from pyspark.testing.utils import (
44+
have_pyarrow,
45+
have_pandas,
46+
have_numpy,
47+
pyarrow_requirement_message,
48+
pandas_requirement_message,
49+
numpy_requirement_message,
50+
)
4251
from pyspark.testing.sqlutils import ReusedSQLTestCase
4352
from .type_table_utils import generate_table_diff, format_type_table
4453

54+
if have_numpy:
55+
import numpy as np
56+
4557

58+
@unittest.skipIf(
59+
not have_pandas
60+
or not have_pyarrow
61+
or not have_numpy
62+
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
63+
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
64+
)
4665
class UDFInputTypeTests(ReusedSQLTestCase):
4766
@classmethod
4867
def setUpClass(cls):
@@ -115,38 +134,7 @@ def value_udf(x):
115134
return x
116135

117136
def value_str(x):
118-
class NpPrintable:
119-
def __init__(self, x):
120-
self.x = x
121-
122-
def __repr__(self):
123-
return f"np.{self.x.dtype}({self.x.item()})"
124-
125-
# Numpy 1.x __repr__ returns a different format, see
126-
# https://numpy.org/doc/stable/release/2.0.0-notes.html#representation-of-numpy-scalars-changed # noqa: E501
127-
# We only care about types and values of the elements,
128-
# so we accept this difference and implement our own repr to make
129-
# tests with numpy 1 return the same format as numpy 2.
130-
def convert_to_numpy_printable(x):
131-
import numpy as np
132-
133-
if isinstance(x, Row):
134-
converted_values = tuple(convert_to_numpy_printable(v) for v in x)
135-
new_row = Row(*converted_values)
136-
new_row.__fields__ = x.__fields__
137-
return new_row
138-
elif isinstance(x, (list)):
139-
return [convert_to_numpy_printable(elem) for elem in x]
140-
elif isinstance(x, tuple):
141-
return tuple(convert_to_numpy_printable(elem) for elem in x)
142-
elif isinstance(x, dict):
143-
return {k: convert_to_numpy_printable(v) for k, v in x.items()}
144-
elif isinstance(x, np.generic):
145-
return NpPrintable(x)
146-
else:
147-
return x
148-
149-
return str(convert_to_numpy_printable(x))
137+
return str(x)
150138

151139
type_test_udf = udf(type_udf, returnType=StringType(), useArrow=use_arrow)
152140
value_test_udf = udf(value_udf, returnType=spark_type, useArrow=use_arrow)

python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,29 @@
4343
StructType,
4444
TimestampType,
4545
)
46+
from pyspark.loose_version import LooseVersion
47+
from pyspark.testing.utils import (
48+
have_pyarrow,
49+
have_pandas,
50+
have_numpy,
51+
pyarrow_requirement_message,
52+
pandas_requirement_message,
53+
numpy_requirement_message,
54+
)
4655
from pyspark.testing.sqlutils import ReusedSQLTestCase
4756
from .type_table_utils import generate_table_diff, format_type_table
4857

58+
if have_numpy:
59+
import numpy as np
60+
4961

62+
@unittest.skipIf(
63+
not have_pandas
64+
or not have_pyarrow
65+
or not have_numpy
66+
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
67+
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
68+
)
5069
class UDFReturnTypeTests(ReusedSQLTestCase):
5170
@classmethod
5271
def setUpClass(cls):

0 commit comments

Comments (0)