|
39 | 39 | StructType,
|
40 | 40 | TimestampType,
|
41 | 41 | )
|
| 42 | +from pyspark.loose_version import LooseVersion |
| 43 | +from pyspark.testing.utils import ( |
| 44 | + have_pyarrow, |
| 45 | + have_pandas, |
| 46 | + have_numpy, |
| 47 | + pyarrow_requirement_message, |
| 48 | + pandas_requirement_message, |
| 49 | + numpy_requirement_message, |
| 50 | +) |
42 | 51 | from pyspark.testing.sqlutils import ReusedSQLTestCase
|
43 | 52 | from .type_table_utils import generate_table_diff, format_type_table
|
44 | 53 |
|
| 54 | +if have_numpy: |
| 55 | + import numpy as np |
| 56 | + |
45 | 57 |
|
| 58 | +@unittest.skipIf( |
| 59 | + not have_pandas |
| 60 | + or not have_pyarrow |
| 61 | + or not have_numpy |
| 62 | + or LooseVersion(np.__version__) < LooseVersion("2.0.0"), |
| 63 | + pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message, |
| 64 | +) |
46 | 65 | class UDFInputTypeTests(ReusedSQLTestCase):
|
47 | 66 | @classmethod
|
48 | 67 | def setUpClass(cls):
|
@@ -115,38 +134,7 @@ def value_udf(x):
|
115 | 134 | return x
|
116 | 135 |
|
117 | 136 | def value_str(x):
|
118 |
| - class NpPrintable: |
119 |
| - def __init__(self, x): |
120 |
| - self.x = x |
121 |
| - |
122 |
| - def __repr__(self): |
123 |
| - return f"np.{self.x.dtype}({self.x.item()})" |
124 |
| - |
125 |
| - # Numpy 1.x __repr__ returns a different format, see |
126 |
| - # https://numpy.org/doc/stable/release/2.0.0-notes.html#representation-of-numpy-scalars-changed # noqa: E501 |
127 |
| - # We only care about types and values of the elements, |
128 |
| - # so we accept this difference and implement our own repr to make |
129 |
| - # tests with numpy 1 return the same format as numpy 2. |
130 |
| - def convert_to_numpy_printable(x): |
131 |
| - import numpy as np |
132 |
| - |
133 |
| - if isinstance(x, Row): |
134 |
| - converted_values = tuple(convert_to_numpy_printable(v) for v in x) |
135 |
| - new_row = Row(*converted_values) |
136 |
| - new_row.__fields__ = x.__fields__ |
137 |
| - return new_row |
138 |
| - elif isinstance(x, (list)): |
139 |
| - return [convert_to_numpy_printable(elem) for elem in x] |
140 |
| - elif isinstance(x, tuple): |
141 |
| - return tuple(convert_to_numpy_printable(elem) for elem in x) |
142 |
| - elif isinstance(x, dict): |
143 |
| - return {k: convert_to_numpy_printable(v) for k, v in x.items()} |
144 |
| - elif isinstance(x, np.generic): |
145 |
| - return NpPrintable(x) |
146 |
| - else: |
147 |
| - return x |
148 |
| - |
149 |
| - return str(convert_to_numpy_printable(x)) |
| 137 | + return str(x) |
150 | 138 |
|
151 | 139 | type_test_udf = udf(type_udf, returnType=StringType(), useArrow=use_arrow)
|
152 | 140 | value_test_udf = udf(value_udf, returnType=spark_type, useArrow=use_arrow)
|
|
0 commit comments