Skip to content

Commit 153086f

Browse files
committed
Add _validate_setitem_value method to raise TypeError and fix tests
1 parent d6c9941 commit 153086f

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

pandas/core/arrays/numpy_.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,85 @@ def _validate_scalar(self, fill_value):
232232
fill_value = self.dtype.na_value
233233
return fill_value
234234

235+
def _validate_setitem_value(self, value):
236+
"""
237+
Check if we have a scalar that we can cast losslessly.
238+
239+
Raises
240+
------
241+
TypeError
242+
"""
243+
244+
kind = self.dtype.kind
245+
246+
if kind == "b":
247+
if lib.is_bool(value) or np.can_cast(type(value), self.dtype.type):
248+
return value
249+
if isinstance(value, NumpyExtensionArray) and (
250+
lib.is_bool_array(value.to_numpy())
251+
or lib.is_bool_list(value.to_numpy())
252+
):
253+
return value
254+
255+
elif kind == "i":
256+
if lib.is_integer(value) or np.can_cast(type(value), self.dtype.type):
257+
return value
258+
if isinstance(value, NumpyExtensionArray) and lib.is_integer_array(
259+
value.to_numpy()
260+
):
261+
return value
262+
263+
elif kind == "u":
264+
if (lib.is_integer(value) and value > -1) or np.can_cast(
265+
type(value), self.dtype.type
266+
):
267+
return value
268+
269+
elif kind == "c":
270+
if lib.is_complex(value) or np.can_cast(type(value), self.dtype.type):
271+
return value
272+
273+
elif kind == "S":
274+
if isinstance(value, str) or np.can_cast(type(value), self.dtype.type):
275+
return value
276+
if isinstance(value, NumpyExtensionArray) and lib.is_string_array(
277+
value.to_numpy()
278+
):
279+
return value
280+
281+
elif kind == "M":
282+
if isinstance(value, np.datetime64):
283+
return value
284+
if isinstance(value, NumpyExtensionArray) and (
285+
lib.is_date_array(value.to_numpy())
286+
or lib.is_datetime_array(value.to_numpy())
287+
or lib.is_datetime64_array(value.to_numpy())
288+
or lib.is_datetime_with_singletz_array(value.to_numpy())
289+
):
290+
return value
291+
292+
elif kind == "m":
293+
if isinstance(value, np.timedelta64):
294+
return value
295+
if isinstance(value, NumpyExtensionArray) and (
296+
lib.is_timedelta_or_timedelta64_array(value.to_numpy())
297+
or lib.is_time_array(value.to_numpy())
298+
):
299+
return value
300+
301+
elif kind == "f":
302+
if lib.is_float(value) or np.can_cast(type(value), self.dtype.type):
303+
return value
304+
if isinstance(value, NumpyExtensionArray) and lib.is_float_array(
305+
value.to_numpy()
306+
):
307+
return value
308+
309+
elif np.can_cast(type(value), self.dtype.type):
310+
return value
311+
312+
raise TypeError(f"Invalid value '{value!s}' for dtype {self.dtype}")
313+
235314
def _values_for_factorize(self) -> tuple[np.ndarray, float | None]:
236315
if self.dtype.kind in "iub":
237316
fv = None

pandas/tests/arrays/numpy_/test_numpy.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,15 @@ def test_setitem_object_typecode(dtype):
275275
def test_setitem_no_coercion():
276276
# https://github.com/pandas-dev/pandas/issues/28150
277277
arr = NumpyExtensionArray(np.array([1, 2, 3]))
278-
with pytest.raises(ValueError, match="int"):
278+
with pytest.raises(TypeError):
279279
arr[0] = "a"
280280

281281
# With a value that we do coerce, check that we coerce the value
282282
# and not the underlying array.
283-
arr[0] = 2.5
283+
with pytest.raises(TypeError):
284+
arr[0] = 2.5
285+
286+
arr[0] = 9
284287
assert isinstance(arr[0], (int, np.integer)), type(arr[0])
285288

286289

@@ -296,7 +299,10 @@ def test_setitem_preserves_views():
296299
assert view2[0] == 9
297300
assert view3[0] == 9
298301

299-
arr[-1] = 2.5
302+
with pytest.raises(TypeError):
303+
arr[-1] = 2.5
304+
305+
arr[-1] = 4
300306
view1[-1] = 5
301307
assert arr[-1] == 5
302308

0 commit comments

Comments
 (0)