Commit 9d23f2f

benhurdelhey authored and zhengruifeng committed
[SPARK-53355][PYTHON][SQL] fix numpy 1.x repr in type tests
### What changes were proposed in this pull request?

- This is a minor follow-up to #52105. We noticed that the test breaks in two Spark master runs that use a different environment.
- The root cause is that numpy 1.x implements `__repr__` differently.

### Why are the changes needed?

- Fix `Build / Python-only (master, Minimum dependencies of PySpark)`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Ran tests locally with numpy 1.22.4.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52247 from benrobby/SPARK-53355-fix-numpy-repr.

Authored-by: Ben Hurdelhey <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
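To illustrate the root cause, here is a minimal sketch of a wrapper whose `__repr__` always uses the numpy 2 style (`np.<dtype>(<value>)`), whichever numpy version is installed. `NpPrintable` mirrors the class added in this commit. `FakeInt64` is a hypothetical stand-in, not part of the patch, used only so the sketch also runs where numpy is not installed.

```python
class NpPrintable:
    """Wrap a numpy scalar so that repr() is the same under numpy 1.x and 2.x."""

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        # Build the numpy 2 style text from the dtype name and the plain Python value.
        return f"np.{self.x.dtype}({self.x.item()})"


class FakeInt64:
    """Hypothetical stand-in exposing the two attributes the wrapper relies on."""

    dtype = "int64"

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


try:
    import numpy as np

    value = np.int64(7)
    # repr(value) is "7" on numpy 1.x but "np.int64(7)" on numpy 2.x,
    # which is the difference that broke the golden-file comparison.
except ImportError:
    value = FakeInt64(7)

wrapped = repr(NpPrintable(value))
rendered = str([NpPrintable(value)])  # str() of a list uses repr() of its elements
print(wrapped)
print(rendered)
```

Because `str()` on a container calls `repr()` on each element, wrapping the scalars is enough to make the string form of a whole nested value stable.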
1 parent 551e7f2 commit 9d23f2f

1 file changed: +32 -1 lines changed

python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py

Lines changed: 32 additions & 1 deletion
@@ -115,7 +115,38 @@ def value_udf(x):
     return x
 
 def value_str(x):
-    return str(x)
+    class NpPrintable:
+        def __init__(self, x):
+            self.x = x
+
+        def __repr__(self):
+            return f"np.{self.x.dtype}({self.x.item()})"
+
+    # Numpy 1.x __repr__ returns a different format, see
+    # https://numpy.org/doc/stable/release/2.0.0-notes.html#representation-of-numpy-scalars-changed # noqa: E501
+    # We only care about types and values of the elements,
+    # so we accept this difference and implement our own repr to make
+    # tests with numpy 1 return the same format as numpy 2.
+    def convert_to_numpy_printable(x):
+        import numpy as np
+
+        if isinstance(x, Row):
+            converted_values = tuple(convert_to_numpy_printable(v) for v in x)
+            new_row = Row(*converted_values)
+            new_row.__fields__ = x.__fields__
+            return new_row
+        elif isinstance(x, (list)):
+            return [convert_to_numpy_printable(elem) for elem in x]
+        elif isinstance(x, tuple):
+            return tuple(convert_to_numpy_printable(elem) for elem in x)
+        elif isinstance(x, dict):
+            return {k: convert_to_numpy_printable(v) for k, v in x.items()}
+        elif isinstance(x, np.generic):
+            return NpPrintable(x)
+        else:
+            return x
+
+    return str(convert_to_numpy_printable(x))
 
 type_test_udf = udf(type_udf, returnType=StringType(), useArrow=use_arrow)
 value_test_udf = udf(value_udf, returnType=spark_type, useArrow=use_arrow)
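
One detail of the change above is the order of the `isinstance` checks: `Row` is a `tuple` subclass, so the `Row` branch has to come before the plain `tuple` branch, otherwise the field names would be lost. The following is a minimal sketch of the same recursive pattern without Spark or numpy. It assumes `collections.namedtuple` as a stand-in for `Row`, and `map_leaves` is a hypothetical helper name, not something from the patch.

```python
from collections import namedtuple


def map_leaves(x, leaf):
    """Recursively apply `leaf` to non-container values, keeping container types."""
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # Named tuples are tuples too, so this branch must come first
        # to rebuild the value with its field names intact.
        return type(x)(*(map_leaves(v, leaf) for v in x))
    elif isinstance(x, list):
        return [map_leaves(v, leaf) for v in x]
    elif isinstance(x, tuple):
        return tuple(map_leaves(v, leaf) for v in x)
    elif isinstance(x, dict):
        return {k: map_leaves(v, leaf) for k, v in x.items()}
    else:
        return leaf(x)


Point = namedtuple("Point", ["a", "b"])  # stand-in for a Row with fields a and b
nested = {"p": Point(1, [2, 3]), "t": (4, 5)}

result = map_leaves(nested, lambda v: v * 10)
print(result)
```

In the commit, the leaf step wraps `np.generic` values in `NpPrintable` and returns everything else unchanged, and the field names are restored by assigning `__fields__` on the new `Row` instead of calling a named-tuple constructor.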