Skip to content

Commit ae69206

Browse files
String dtype: avoid surfacing pyarrow excetion in binary operations
1 parent 360597c commit ae69206

File tree

4 files changed

+51
-61
lines changed

4 files changed

+51
-61
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,19 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
736736
)
737737
return ArrowExtensionArray(result)
738738

739+
def _op_method_error_message(self, other, op) -> str:
740+
if hasattr(other, "dtype"):
741+
other_type = f"dtype '{other.dtype}'"
742+
else:
743+
other_type = f"object of type {type(other)}"
744+
return (
745+
f"operation '{op.__name__}' not supported for "
746+
f"dtype '{self.dtype}' with {other_type}"
747+
)
748+
739749
def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
740750
pa_type = self._pa_array.type
751+
other_original = other
741752
other = self._box_pa(other)
742753

743754
if (
@@ -747,10 +758,15 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
747758
):
748759
if op in [operator.add, roperator.radd]:
749760
sep = pa.scalar("", type=pa_type)
750-
if op is operator.add:
751-
result = pc.binary_join_element_wise(self._pa_array, other, sep)
752-
elif op is roperator.radd:
753-
result = pc.binary_join_element_wise(other, self._pa_array, sep)
761+
try:
762+
if op is operator.add:
763+
result = pc.binary_join_element_wise(self._pa_array, other, sep)
764+
elif op is roperator.radd:
765+
result = pc.binary_join_element_wise(other, self._pa_array, sep)
766+
except pa.lib.ArrowNotImplementedError as err:
767+
raise TypeError(
768+
self._op_method_error_message(other_original, op)
769+
) from err
754770
return type(self)(result)
755771
elif op in [operator.mul, roperator.rmul]:
756772
binary = self._pa_array
@@ -782,9 +798,14 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
782798

783799
pc_func = arrow_funcs[op.__name__]
784800
if pc_func is NotImplemented:
801+
if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
802+
raise TypeError(self._op_method_error_message(other_original, op))
785803
raise NotImplementedError(f"{op.__name__} not implemented.")
786804

787-
result = pc_func(self._pa_array, other)
805+
try:
806+
result = pc_func(self._pa_array, other)
807+
except pa.lib.ArrowNotImplementedError as err:
808+
raise TypeError(self._op_method_error_message(other_original, op)) from err
788809
return type(self)(result)
789810

790811
def _logical_method(self, other, op) -> Self:

pandas/tests/arrays/boolean/test_arithmetic.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
8-
from pandas.compat import HAS_PYARROW
9-
106
import pandas as pd
117
import pandas._testing as tm
128

@@ -94,19 +90,8 @@ def test_op_int8(left_array, right_array, opname):
9490
# -----------------------------------------------------------------------------
9591

9692

97-
@pytest.mark.xfail(
98-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
99-
)
100-
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
93+
def test_error_invalid_values(data, all_arithmetic_operators):
10194
# invalid ops
102-
103-
if using_infer_string:
104-
import pyarrow as pa
105-
106-
err = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
107-
else:
108-
err = TypeError
109-
11095
op = all_arithmetic_operators
11196
s = pd.Series(data)
11297
ops = getattr(s, op)
@@ -116,7 +101,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
116101
"did not contain a loop with signature matching types|"
117102
"BooleanArray cannot perform the operation|"
118103
"not supported for the input types, and the inputs could not be safely coerced "
119-
"to any supported types according to the casting rule ''safe''"
104+
"to any supported types according to the casting rule ''safe''|"
105+
"not supported for dtype"
120106
)
121107
with pytest.raises(TypeError, match=msg):
122108
ops("foo")
@@ -125,9 +111,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
125111
r"unsupported operand type\(s\) for",
126112
"Concatenation operation is not implemented for NumPy arrays",
127113
"has no kernel",
114+
"not supported for dtype",
128115
]
129116
)
130-
with pytest.raises(err, match=msg):
117+
with pytest.raises(TypeError, match=msg):
131118
ops(pd.Timestamp("20180101"))
132119

133120
# invalid array-likes
@@ -140,7 +127,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
140127
"not all arguments converted during string formatting",
141128
"has no kernel",
142129
"not implemented",
130+
"not supported for dtype",
143131
]
144132
)
145-
with pytest.raises(err, match=msg):
133+
with pytest.raises(TypeError, match=msg):
146134
ops(pd.Series("foo", index=s.index))

pandas/tests/arrays/floating/test_arithmetic.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
import pandas as pd
97
import pandas._testing as tm
108
from pandas.core.arrays import FloatingArray
@@ -124,19 +122,11 @@ def test_arith_zero_dim_ndarray(other):
124122
# -----------------------------------------------------------------------------
125123

126124

127-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
128-
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
125+
def test_error_invalid_values(data, all_arithmetic_operators):
129126
op = all_arithmetic_operators
130127
s = pd.Series(data)
131128
ops = getattr(s, op)
132129

133-
if using_infer_string:
134-
import pyarrow as pa
135-
136-
errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
137-
else:
138-
errs = TypeError
139-
140130
# invalid scalars
141131
msg = "|".join(
142132
[
@@ -152,15 +142,17 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
152142
"Concatenation operation is not implemented for NumPy arrays",
153143
"has no kernel",
154144
"not implemented",
145+
"not supported for dtype",
146+
"Can only string multiply by an integer",
155147
]
156148
)
157-
with pytest.raises(errs, match=msg):
149+
with pytest.raises(TypeError, match=msg):
158150
ops("foo")
159-
with pytest.raises(errs, match=msg):
151+
with pytest.raises(TypeError, match=msg):
160152
ops(pd.Timestamp("20180101"))
161153

162154
# invalid array-likes
163-
with pytest.raises(errs, match=msg):
155+
with pytest.raises(TypeError, match=msg):
164156
ops(pd.Series("foo", index=s.index))
165157

166158
msg = "|".join(
@@ -181,9 +173,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
181173
"cannot subtract DatetimeArray from ndarray",
182174
"has no kernel",
183175
"not implemented",
176+
"not supported for dtype",
184177
]
185178
)
186-
with pytest.raises(errs, match=msg):
179+
with pytest.raises(TypeError, match=msg):
187180
ops(pd.Series(pd.date_range("20180101", periods=len(s))))
188181

189182

pandas/tests/arrays/integer/test_arithmetic.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
import pandas as pd
97
import pandas._testing as tm
108
from pandas.core import ops
@@ -174,19 +172,11 @@ def test_numpy_zero_dim_ndarray(other):
174172
# -----------------------------------------------------------------------------
175173

176174

177-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
178175
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
179176
op = all_arithmetic_operators
180177
s = pd.Series(data)
181178
ops = getattr(s, op)
182179

183-
if using_infer_string:
184-
import pyarrow as pa
185-
186-
errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
187-
else:
188-
errs = TypeError
189-
190180
# invalid scalars
191181
msg = "|".join(
192182
[
@@ -201,24 +191,21 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
201191
"has no kernel",
202192
"not implemented",
203193
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.",
194+
"not supported for dtype",
204195
]
205196
)
206-
with pytest.raises(errs, match=msg):
197+
with pytest.raises(TypeError, match=msg):
207198
ops("foo")
208-
with pytest.raises(errs, match=msg):
199+
with pytest.raises(TypeError, match=msg):
209200
ops(pd.Timestamp("20180101"))
210201

211202
# invalid array-likes
212203
str_ser = pd.Series("foo", index=s.index)
213204
# with pytest.raises(TypeError, match=msg):
214-
if (
215-
all_arithmetic_operators
216-
in [
217-
"__mul__",
218-
"__rmul__",
219-
]
220-
and not using_infer_string
221-
): # (data[~data.isna()] >= 0).all():
205+
if all_arithmetic_operators in [
206+
"__mul__",
207+
"__rmul__",
208+
]: # (data[~data.isna()] >= 0).all():
222209
res = ops(str_ser)
223210
expected = pd.Series(["foo" * x for x in data], index=s.index)
224211
expected = expected.fillna(np.nan)
@@ -227,7 +214,7 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
227214
# more-correct than np.nan here.
228215
tm.assert_series_equal(res, expected)
229216
else:
230-
with pytest.raises(errs, match=msg):
217+
with pytest.raises(TypeError, match=msg):
231218
ops(str_ser)
232219

233220
msg = "|".join(
@@ -242,9 +229,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
242229
"cannot subtract DatetimeArray from ndarray",
243230
"has no kernel",
244231
"not implemented",
232+
"not supported for dtype",
245233
]
246234
)
247-
with pytest.raises(errs, match=msg):
235+
with pytest.raises(TypeError, match=msg):
248236
ops(pd.Series(pd.date_range("20180101", periods=len(s))))
249237

250238

0 commit comments

Comments
 (0)