Skip to content

Commit 656ddc1

Browse files
committed
add extension test
1 parent 3598fc4 commit 656ddc1

File tree

5 files changed

+46
-0
lines changed

5 files changed

+46
-0
lines changed

pandas/tests/extension/base/setitem.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.core.dtypes.common import is_hashable
5+
46
import pandas as pd
57
import pandas._testing as tm
68

@@ -310,6 +312,22 @@ def test_setitem_expand_with_extension(self, data):
310312
result.loc[:, "B"] = data
311313
tm.assert_frame_equal(result, expected)
312314

315+
def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data):
316+
# GH#41626 retain index.dtype in setitem-with-expansion
317+
if not is_hashable(data[0]):
318+
pytest.skip("Test does not apply to non-hashable data.")
319+
data = data.unique()
320+
expected = pd.DataFrame({"A": range(len(data))}, index=data)
321+
df = expected.iloc[:-1]
322+
ser = df["A"]
323+
item = data[-1]
324+
325+
df.loc[item] = len(data) - 1
326+
tm.assert_frame_equal(df, expected)
327+
328+
ser.loc[item] = len(data) - 1
329+
tm.assert_series_equal(ser, expected["A"])
330+
313331
def test_setitem_frame_invalid_length(self, data):
314332
df = pd.DataFrame({"A": [1] * len(data)})
315333
xpr = (

pandas/tests/extension/test_arrow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,15 @@ def test_comp_masked_numpy(self, masked_dtype, comparison_op):
10671067
expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
10681068
tm.assert_series_equal(result, expected)
10691069

1070+
def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request):
1071+
pa_dtype = data.dtype.pyarrow_dtype
1072+
if pa.types.is_date(pa_dtype):
1073+
mark = pytest.mark.xfail(
1074+
reason="GH#62343 incorrectly casts to timestamp[ms][pyarrow]"
1075+
)
1076+
request.applymarker(mark)
1077+
super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data)
1078+
10701079

10711080
class TestLogicalOps:
10721081
"""Various Series and DataFrame logical ops methods."""

pandas/tests/extension/test_interval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ def test_EA_types(self, engine, data, request):
126126
def test_astype_str(self, data):
127127
super().test_astype_str(data)
128128

129+
@pytest.mark.xfail(
130+
reason="Test is invalid for IntervalDtype, needs to be adapted for "
131+
"this dtype with an index with index._index_as_unique."
132+
)
133+
def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data):
134+
super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data)
135+
129136

130137
# TODO: either belongs in tests.arrays.interval or move into base tests.
131138
def test_fillna_non_scalar_raises(data_missing):

pandas/tests/extension/test_masked.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,9 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
360360
)
361361
)
362362
tm.assert_series_equal(result, expected)
363+
364+
def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request):
365+
if data.dtype.kind == "b":
366+
mark = pytest.mark.xfail(reason="GH#62344 incorrectly casts to object")
367+
request.applymarker(mark)
368+
super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data)

pandas/tests/extension/test_numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,12 @@ def test_index_from_listlike_with_dtype(self, data):
421421
def test_EA_types(self, engine, data, request):
422422
super().test_EA_types(engine, data, request)
423423

424+
def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request):
425+
if isinstance(data[-1], tuple):
426+
mark = pytest.mark.xfail(reason="Unpacks tuple")
427+
request.applymarker(mark)
428+
super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data)
429+
424430

425431
class Test2DCompat(base.NDArrayBacked2DTests):
426432
pass

0 commit comments

Comments
 (0)