Skip to content

Commit 39662d2

Browse files
fix select_dtypes test + more test changes
1 parent d816476 commit 39662d2

File tree

12 files changed

+40 -25 lines changed

12 files changed

+40 -25 lines changed

pandas/core/frame.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4807,7 +4807,9 @@ def select_dtypes(self, include=None, exclude=None) -> DataFrame:
48074807
-----
48084808
* To select all *numeric* types, use ``np.number`` or ``'number'``
48094809
* To select strings you must use the ``object`` dtype, but note that
4810-
this will return *all* object dtype columns
4810+
this will return *all* object dtype columns. With
4811+
``pd.options.future.infer_string`` enabled, using ``"str"`` will
4812+
work to select all string columns.
48114813
* See the `numpy dtype hierarchy
48124814
<https://numpy.org/doc/stable/reference/arrays.scalars.html>`__
48134815
* To select datetimes, use ``np.datetime64``, ``'datetime'`` or

pandas/tests/frame/methods/test_astype.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def test_astype_str_float(self):
186186
tm.assert_frame_equal(result, expected)
187187

188188
@pytest.mark.parametrize("dtype_class", [dict, Series])
189-
def test_astype_dict_like(self, dtype_class, using_infer_string):
189+
def test_astype_dict_like(self, dtype_class):
190190
# GH7271 & GH16717
191191
a = Series(date_range("2010-01-04", periods=5))
192192
b = Series(range(5))
@@ -201,10 +201,7 @@ def test_astype_dict_like(self, dtype_class, using_infer_string):
201201
expected = DataFrame(
202202
{
203203
"a": a,
204-
"b": Series(
205-
["0", "1", "2", "3", "4"],
206-
dtype="str" if using_infer_string else "object",
207-
),
204+
"b": Series(["0", "1", "2", "3", "4"], dtype="str"),
208205
"c": c,
209206
"d": Series([1.0, 2.0, 3.14, 4.0, 5.4], dtype="float32"),
210207
}

pandas/tests/frame/methods/test_reset_index.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -662,7 +662,7 @@ def test_reset_index_dtypes_on_empty_frame_with_multiindex(
662662
idx = MultiIndex.from_product([[0, 1], [0.5, 1.0], array])
663663
result = DataFrame(index=idx)[:0].reset_index().dtypes
664664
if using_infer_string and dtype == object:
665-
dtype = "str"
665+
dtype = pd.StringDtype(na_value=np.nan)
666666
expected = Series({"level_0": np.int64, "level_1": np.float64, "level_2": dtype})
667667
tm.assert_series_equal(result, expected)
668668

pandas/tests/frame/methods/test_select_dtypes.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
64
from pandas.core.dtypes.dtypes import ExtensionDtype
75

86
import pandas as pd
@@ -52,7 +50,7 @@ def copy(self):
5250

5351

5452
class TestSelectDtypes:
55-
def test_select_dtypes_include_using_list_like(self):
53+
def test_select_dtypes_include_using_list_like(self, using_infer_string):
5654
df = DataFrame(
5755
{
5856
"a": list("abc"),
@@ -96,6 +94,11 @@ def test_select_dtypes_include_using_list_like(self):
9694
with pytest.raises(NotImplementedError, match=r"^$"):
9795
df.select_dtypes(include=["period"])
9896

97+
if using_infer_string:
98+
ri = df.select_dtypes(include=["str"])
99+
ei = df[["a"]]
100+
tm.assert_frame_equal(ri, ei)
101+
99102
def test_select_dtypes_exclude_using_list_like(self):
100103
df = DataFrame(
101104
{
@@ -153,7 +156,7 @@ def test_select_dtypes_exclude_include_int(self, include):
153156
expected = df[["b", "c", "e"]]
154157
tm.assert_frame_equal(result, expected)
155158

156-
def test_select_dtypes_include_using_scalars(self):
159+
def test_select_dtypes_include_using_scalars(self, using_infer_string):
157160
df = DataFrame(
158161
{
159162
"a": list("abc"),
@@ -189,6 +192,11 @@ def test_select_dtypes_include_using_scalars(self):
189192
with pytest.raises(NotImplementedError, match=r"^$"):
190193
df.select_dtypes(include="period")
191194

195+
if using_infer_string:
196+
ri = df.select_dtypes(include="str")
197+
ei = df[["a"]]
198+
tm.assert_frame_equal(ri, ei)
199+
192200
def test_select_dtypes_exclude_using_scalars(self):
193201
df = DataFrame(
194202
{
@@ -347,10 +355,12 @@ def test_select_dtypes_datetime_with_tz(self):
347355
expected = df3.reindex(columns=[])
348356
tm.assert_frame_equal(result, expected)
349357

350-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
351358
@pytest.mark.parametrize("dtype", [str, "str", np.bytes_, "S1", np.str_, "U1"])
352359
@pytest.mark.parametrize("arg", ["include", "exclude"])
353-
def test_select_dtypes_str_raises(self, dtype, arg):
360+
def test_select_dtypes_str_raises(self, dtype, arg, using_infer_string):
361+
if using_infer_string and dtype == "str":
362+
# this is tested below
363+
pytest.skip("Selecting string columns works with future strings")
354364
df = DataFrame(
355365
{
356366
"a": list("abc"),

pandas/tests/frame/test_block_internals.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,9 @@ def test_construction_with_mixed(self, float_string_frame, using_infer_string):
186186
expected = Series(
187187
[np.dtype("float64")] * 4
188188
+ [
189-
np.dtype("object") if not using_infer_string else "str",
189+
np.dtype("object")
190+
if not using_infer_string
191+
else pd.StringDtype(na_value=np.nan),
190192
np.dtype("datetime64[us]"),
191193
np.dtype("timedelta64[us]"),
192194
],

pandas/tests/frame/test_stack_unstack.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -658,7 +658,11 @@ def test_unstack_dtypes(self, using_infer_string):
658658
df2["D"] = "foo"
659659
df3 = df2.unstack("B")
660660
result = df3.dtypes
661-
dtype = "str" if using_infer_string else np.dtype("object")
661+
dtype = (
662+
pd.StringDtype(na_value=np.nan)
663+
if using_infer_string
664+
else np.dtype("object")
665+
)
662666
expected = Series(
663667
[np.dtype("float64")] * 2 + [dtype] * 2,
664668
index=MultiIndex.from_arrays(

pandas/tests/groupby/test_apply.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1013,7 +1013,7 @@ def test_groupby_apply_datetime_result_dtypes(using_infer_string):
10131013
msg = "DataFrameGroupBy.apply operated on the grouping columns"
10141014
with tm.assert_produces_warning(DeprecationWarning, match=msg):
10151015
result = data.groupby("color").apply(lambda g: g.iloc[0]).dtypes
1016-
dtype = "str" if using_infer_string else object
1016+
dtype = pd.StringDtype(na_value=np.nan) if using_infer_string else object
10171017
expected = Series(
10181018
[np.dtype("datetime64[us]"), dtype, dtype, np.int64, dtype],
10191019
index=["observation", "color", "mood", "intensity", "score"],

pandas/tests/groupby/test_categorical.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -134,8 +134,8 @@ def f(x):
134134
result = g.apply(f)
135135
expected = x.iloc[[0, 1]].copy()
136136
expected.index = Index([1, 2], name="person_id")
137-
dtype = "str" if using_infer_string else object
138-
expected["person_name"] = expected["person_name"].astype(dtype)
137+
# dtype = "str" if using_infer_string else object
138+
# expected["person_name"] = expected["person_name"].astype(dtype)
139139
tm.assert_frame_equal(result, expected)
140140

141141

pandas/tests/indexes/multi/test_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -850,7 +850,7 @@ def test_dtype_representation(using_infer_string):
850850
# GH#46900
851851
pmidx = MultiIndex.from_arrays([[1], ["a"]], names=[("a", "b"), ("c", "d")])
852852
result = pmidx.dtypes
853-
exp = "object" if not using_infer_string else "str"
853+
exp = "object" if not using_infer_string else pd.StringDtype(na_value=np.nan)
854854
expected = Series(
855855
["int64", exp],
856856
index=MultiIndex.from_tuples([("a", "b"), ("c", "d")]),

pandas/tests/indexes/object/test_indexing.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,7 @@ def test_get_indexer_non_unique_np_nats(self, np_nat_fixture, np_nat_fixture2):
171171

172172

173173
class TestSliceLocs:
174+
# TODO(infer_string) parametrize over multiple string dtypes
174175
@pytest.mark.parametrize(
175176
"dtype",
176177
[
@@ -209,6 +210,7 @@ def test_slice_locs_negative_step(self, in_slice, expected, dtype):
209210
expected = Index(list(expected), dtype=dtype)
210211
tm.assert_index_equal(result, expected)
211212

213+
# TODO(infer_string) parametrize over multiple string dtypes
212214
@td.skip_if_no("pyarrow")
213215
def test_slice_locs_negative_step_oob(self):
214216
index = Index(list("bcdxy"), dtype="string[pyarrow_numpy]")

0 commit comments

Comments
 (0)