Skip to content

Commit 0426e59

Browse files
Backport PR pandas-dev#62323 on branch 2.3.x (String dtype: keep select_dtypes(include=object) selecting string columns) (pandas-dev#62400)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent fd2f7eb commit 0426e59

File tree

4 files changed

+60
-13
lines changed

4 files changed

+60
-13
lines changed

doc/source/whatsnew/v2.3.3.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ Most changes in this release are related to :class:`StringDtype` which will
1818
become the default string dtype in pandas 3.0. See
1919
:ref:`whatsnew_230.upcoming_changes` for more details.
2020

21+
.. _whatsnew_233.string_fixes.improvements:
22+
23+
Improvements
24+
^^^^^^^^^^^^
25+
- Update :meth:`DataFrame.select_dtypes` to keep selecting ``str`` columns when
26+
specifying ``include=["object"]`` for backwards compatibility. In a future
27+
release, this behavior will be deprecated, and code for pandas 3+ should be updated to
28+
use ``include=["str"]`` instead (:issue:`61916`)
29+
30+
2131
.. _whatsnew_233.string_fixes.bugs:
2232

2333
Bug fixes

pandas/core/dtypes/cast.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,9 @@ def invalidate_string_dtypes(dtype_set: set[DtypeObj]) -> None:
966966
np.dtype("<U").type, # type: ignore[arg-type]
967967
}
968968
if non_string_dtypes != dtype_set:
969-
raise TypeError("string dtypes are not allowed, use 'object' instead")
969+
raise TypeError(
970+
"numpy string dtypes are not allowed, use 'str' or 'object' instead"
971+
)
970972

971973

972974
def coerce_indexer_dtype(indexer, categories) -> np.ndarray:

pandas/core/frame.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
TimedeltaArray,
145145
)
146146
from pandas.core.arrays.sparse import SparseFrameAccessor
147+
from pandas.core.arrays.string_ import StringDtype
147148
from pandas.core.construction import (
148149
ensure_wrapped_if_datetimelike,
149150
sanitize_array,
@@ -5080,10 +5081,19 @@ def check_int_infer_dtype(dtypes):
50805081
def dtype_predicate(dtype: DtypeObj, dtypes_set) -> bool:
50815082
# GH 46870: BooleanDtype._is_numeric == True but should be excluded
50825083
dtype = dtype if not isinstance(dtype, ArrowDtype) else dtype.numpy_dtype
5083-
return issubclass(dtype.type, tuple(dtypes_set)) or (
5084-
np.number in dtypes_set
5085-
and getattr(dtype, "_is_numeric", False)
5086-
and not is_bool_dtype(dtype)
5084+
return (
5085+
issubclass(dtype.type, tuple(dtypes_set))
5086+
or (
5087+
np.number in dtypes_set
5088+
and getattr(dtype, "_is_numeric", False)
5089+
and not is_bool_dtype(dtype)
5090+
)
5091+
# backwards compat for the default `str` dtype being selected by object
5092+
or (
5093+
isinstance(dtype, StringDtype)
5094+
and dtype.na_value is np.nan
5095+
and np.object_ in dtypes_set
5096+
)
50875097
)
50885098

50895099
def predicate(arr: ArrayLike) -> bool:

pandas/tests/frame/methods/test_select_dtypes.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def test_select_dtypes_include_using_list_like(self, using_infer_string):
102102
ri = df.select_dtypes(include=[str])
103103
tm.assert_frame_equal(ri, ei)
104104

105+
ri = df.select_dtypes(include=["object"])
106+
ei = df[["a"]]
107+
tm.assert_frame_equal(ri, ei)
108+
105109
def test_select_dtypes_exclude_using_list_like(self):
106110
df = DataFrame(
107111
{
@@ -309,17 +313,15 @@ def test_select_dtypes_not_an_attr_but_still_valid_dtype(self, using_infer_strin
309313
df["g"] = df.f.diff()
310314
assert not hasattr(np, "u8")
311315
r = df.select_dtypes(include=["i8", "O"], exclude=["timedelta"])
312-
if using_infer_string:
313-
e = df[["b"]]
314-
else:
315-
e = df[["a", "b"]]
316+
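# TODO: once selecting ``str`` columns via "O"/object is deprecated, assert
317+
# that a warning is raised here when using_infer_string is enabled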
318+
e = df[["a", "b"]]
316319
tm.assert_frame_equal(r, e)
317320

318321
r = df.select_dtypes(include=["i8", "O", "timedelta64[ns]"])
319-
if using_infer_string:
320-
e = df[["b", "g"]]
321-
else:
322-
e = df[["a", "b", "g"]]
322+
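# TODO: once selecting ``str`` columns via "O"/object is deprecated, assert
323+
# that a warning is raised here when using_infer_string is enabled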
324+
e = df[["a", "b", "g"]]
323325
tm.assert_frame_equal(r, e)
324326

325327
def test_select_dtypes_empty(self):
@@ -483,3 +485,26 @@ def test_select_dtypes_no_view(self):
483485
result = df.select_dtypes(include=["number"])
484486
result.iloc[0, 0] = 0
485487
tm.assert_frame_equal(df, df_orig)
488+
489+
def test_select_dtype_object_and_str(self, using_infer_string):
490+
# https://github.com/pandas-dev/pandas/issues/61916
491+
df = DataFrame(
492+
{
493+
"a": ["a", "b", "c"],
494+
"b": [1, 2, 3],
495+
"c": pd.array(["a", "b", "c"], dtype="string"),
496+
}
497+
)
498+
499+
# with "object" -> select only the object / default 'str' dtype column ("a"), not the nullable 'string' column ("c")
500+
result = df.select_dtypes(include=["object"])
501+
expected = df[["a"]]
502+
tm.assert_frame_equal(result, expected)
503+
504+
# with "string" -> select both the default 'str' and the nullable 'string'
505+
result = df.select_dtypes(include=["string"])
506+
if using_infer_string:
507+
expected = df[["a", "c"]]
508+
else:
509+
expected = df[["c"]]
510+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)