Skip to content

Commit 21f2e70

Browse files
committed
change default dtype of str.get_dummies() to bool
1 parent a8a84c8 commit 21f2e70

File tree

5 files changed

+73
-19
lines changed

5 files changed

+73
-19
lines changed

pandas/core/arrays/string_arrow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,19 +397,19 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
397397

398398
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
399399
if dtype is None:
400-
dtype = np.int64
400+
dtype = np.bool_
401401
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
402402
sep, dtype
403403
)
404-
if len(labels) == 0:
405-
return np.empty(shape=(0, 0), dtype=dtype), labels
406-
dummies = np.vstack(dummies_pa.to_numpy())
407404
_dtype = pandas_dtype(dtype)
408405
dummies_dtype: NpDtype
409406
if isinstance(_dtype, np.dtype):
410407
dummies_dtype = _dtype
411408
else:
412409
dummies_dtype = np.bool_
410+
if len(labels) == 0:
411+
return np.empty(shape=(0, 0), dtype=dummies_dtype), labels
412+
dummies = np.vstack(dummies_pa.to_numpy())
413413
return dummies.astype(dummies_dtype, copy=False), labels
414414

415415
def _convert_int_result(self, result):

pandas/core/strings/accessor.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,7 @@ def get_dummies(
24892489
----------
24902490
sep : str, default "|"
24912491
String to split on.
2492-
dtype : dtype, default np.int64
2492+
dtype : dtype, default bool
24932493
Data type for new columns. Only a single dtype is allowed.
24942494
24952495
Returns
@@ -2505,27 +2505,48 @@ def get_dummies(
25052505
Examples
25062506
--------
25072507
>>> pd.Series(["a|b", "a", "a|c"]).str.get_dummies()
2508-
a b c
2509-
0 1 1 0
2510-
1 1 0 0
2511-
2 1 0 1
2508+
a b c
2509+
0 True True False
2510+
1 True False False
2511+
2 True False True
25122512
25132513
>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies()
2514+
a b c
2515+
0 True True False
2516+
1 False False False
2517+
2 True False True
2518+
2519+
>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=np.int64)
25142520
a b c
25152521
0 1 1 0
25162522
1 0 0 0
25172523
2 1 0 1
2518-
2519-
>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
2520-
a b c
2521-
0 True True False
2522-
1 False False False
2523-
2 True False True
25242524
"""
25252525
from pandas.core.frame import DataFrame
25262526

25272527
# we need to cast to Series of strings as only that has all
25282528
# methods available for making the dummies...
2529+
input_dtype = self._data.dtype
2530+
if dtype is None and not isinstance(input_dtype, ArrowDtype):
2531+
from pandas.core.arrays.string_ import StringDtype
2532+
2533+
if isinstance(input_dtype, CategoricalDtype):
2534+
input_dtype = input_dtype.categories.dtype
2535+
2536+
if isinstance(input_dtype, ArrowDtype):
2537+
import pyarrow as pa
2538+
2539+
dtype = ArrowDtype(pa.bool_())
2540+
elif (
2541+
isinstance(input_dtype, StringDtype)
2542+
and input_dtype.na_value is not np.nan
2543+
):
2544+
from pandas.core.dtypes.common import pandas_dtype
2545+
2546+
dtype = pandas_dtype("boolean")
2547+
else:
2548+
dtype = np.bool_
2549+
25292550
result, name = self._data.array._str_get_dummies(sep, dtype)
25302551
if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype):
25312552
return self._wrap_result(

pandas/core/strings/object_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
416416
from pandas import Series
417417

418418
if dtype is None:
419-
dtype = np.int64
419+
dtype = np.bool_
420420
arr = Series(self).fillna("")
421421
try:
422422
arr = sep + arr + sep

pandas/tests/extension/test_arrow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,16 @@ def test_get_dummies():
22062206
)
22072207
tm.assert_frame_equal(result, expected)
22082208

2209+
ser = pd.Series(
2210+
["a", "b"],
2211+
dtype=pd.CategoricalDtype(pd.Index(["a", "b"], dtype=ArrowDtype(pa.string()))),
2212+
)
2213+
result = ser.str.get_dummies()
2214+
expected = pd.DataFrame(
2215+
[[True, False], [False, True]], dtype=ArrowDtype(pa.bool_()), columns=["a", "b"]
2216+
)
2217+
tm.assert_frame_equal(result, expected)
2218+
22092219

22102220
def test_str_partition():
22112221
ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))

pandas/tests/strings/test_get_dummies.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pandas.util._test_decorators as td
77

88
from pandas import (
9+
NA,
10+
CategoricalDtype,
911
DataFrame,
1012
Index,
1113
MultiIndex,
@@ -22,19 +24,28 @@
2224
def test_get_dummies(any_string_dtype):
2325
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
2426
result = s.str.get_dummies("|")
25-
expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"))
27+
exp_dtype = (
28+
"boolean"
29+
if any_string_dtype == "string" and any_string_dtype.na_value is NA
30+
else "bool"
31+
)
32+
expected = DataFrame(
33+
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=exp_dtype
34+
)
2635
tm.assert_frame_equal(result, expected)
2736

2837
s = Series(["a;b", "a", 7], dtype=any_string_dtype)
2938
result = s.str.get_dummies(";")
30-
expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab"))
39+
expected = DataFrame(
40+
[[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab"), dtype=exp_dtype
41+
)
3142
tm.assert_frame_equal(result, expected)
3243

3344

3445
def test_get_dummies_index():
3546
# GH9980, GH8028
3647
idx = Index(["a|b", "a|c", "b|c"])
37-
result = idx.str.get_dummies("|")
48+
result = idx.str.get_dummies("|", dtype=np.int64)
3849

3950
expected = MultiIndex.from_tuples(
4051
[(1, 1, 0), (1, 0, 1), (0, 1, 1)], names=("a", "b", "c")
@@ -125,3 +136,15 @@ def test_get_dummies_with_pa_str_dtype(any_string_dtype):
125136
dtype="str[pyarrow]",
126137
)
127138
tm.assert_frame_equal(result, expected)
139+
140+
141+
@pytest.mark.parametrize("dtype_type", ["string", "category"])
142+
def test_get_dummies_ea_dtype(dtype_type, string_dtype_no_object):
143+
dtype = string_dtype_no_object
144+
exp_dtype = "boolean" if dtype.na_value is NA else "bool"
145+
if dtype_type == "category":
146+
dtype = CategoricalDtype(Index(["a", "b"], dtype))
147+
s = Series(["a", "b"], dtype=dtype)
148+
result = s.str.get_dummies()
149+
expected = DataFrame([[1, 0], [0, 1]], columns=list("ab"), dtype=exp_dtype)
150+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)