Skip to content

Commit 2000f18

Browse files
StringMethods get_dummies defers to pd.get_dummies
1 parent 566e592 commit 2000f18

File tree

4 files changed

+144
-37
lines changed

4 files changed

+144
-37
lines changed

pandas/core/strings/accessor.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,8 +2356,22 @@ def wrap(
23562356
)
23572357
return self._wrap_result(result)
23582358

2359+
from collections.abc import Iterable
2360+
from typing import TYPE_CHECKING
2361+
2362+
if TYPE_CHECKING:
2363+
from pandas._typing import NpDtype
2364+
23592365
@forbid_nonstring_types(["bytes"])
def get_dummies(
    self,
    sep: str = "|",
    prefix=None,
    prefix_sep: str | Iterable[str] | dict[str, str] = "_",
    dummy_na: bool = False,
    sparse: bool = False,
    dtype: NpDtype | None = int,
):
    """
    Return DataFrame of dummy/indicator variables for Series.

    Each non-missing element is converted to a string and split on
    ``sep``. The resulting tokens are encoded with
    :func:`pandas.get_dummies`, and the per-token rows are then collapsed
    back to one row per original element.

    Parameters
    ----------
    sep : str, default "|"
        Separator to split each element on. It is always matched
        literally, never as a regular expression.
    prefix : str, list of str, or dict of str, default None
        Text prepended to the generated column labels. A string is applied
        to every column, a list is matched to the columns by position, and
        a dict is looked up by column label. ``None`` leaves the labels
        unchanged.
    prefix_sep : str, default "_"
        Text placed between the prefix and the column label. Only used
        when ``prefix`` is given.
    dummy_na : bool, default False
        Forwarded to :func:`pandas.get_dummies`.
    sparse : bool, default False
        Forwarded to :func:`pandas.get_dummies`.
    dtype : dtype, default int
        Forwarded to :func:`pandas.get_dummies`. When it is ``bool`` the
        token rows are combined with ``any``; otherwise they are summed.

    Returns
    -------
    DataFrame or MultiIndex
        A DataFrame when called on a Series; a MultiIndex built from that
        frame when called on an Index.

    Raises
    ------
    ValueError
        If ``prefix`` is a list or dict whose length differs from the
        number of generated columns.
    """
    from pandas import (
        MultiIndex,
        Series,
    )
    from pandas.core.reshape.encoding import get_dummies

    is_index = isinstance(self._data, ABCIndex)
    input_series = Series(self._data) if is_index else self._data

    # Stringify non-missing values so every element can be split; missing
    # values pass through untouched.
    string_series = input_series.apply(lambda x: str(x) if not isna(x) else x)

    # regex=False: by default Series.str.split interprets a pattern longer
    # than one character as a regular expression, which would change the
    # meaning of separators such as "||" or ".*".
    split_series = string_series.str.split(sep, expand=True, regex=False).stack()

    # Drop the None padding that expand=True adds to shorter rows, and keep
    # at most one missing marker per original row.
    # NOTE(review): the string comparison also discards a genuine token whose
    # text is "None", and empty tokens (e.g. from "a||b") are not dropped the
    # way the previous implementation dropped them -- confirm intended.
    repeated_row = split_series.index.get_level_values(0).duplicated(keep="first")
    valid_split_series = split_series[
        (split_series.astype(str) != "None") & ~(repeated_row & split_series.isna())
    ]

    dummy_df = get_dummies(
        valid_split_series,
        prefix=None,
        dummy_na=dummy_na,
        sparse=sparse,
        drop_first=False,
        dtype=dtype,
    )

    # Collapse the one-row-per-token frame back to one row per element.
    # NOTE(review): groupby(level=0) sorts and merges equal index labels, and
    # sum() counts a token repeated within one element more than once.
    grouped_dummies = dummy_df.groupby(level=0)
    if dtype == bool:
        result_df = grouped_dummies.any()
    else:
        result_df = grouped_dummies.sum()

    if isinstance(prefix, str):
        result_df.columns = [
            f"{prefix}{prefix_sep}{col}" for col in result_df.columns
        ]
    elif isinstance(prefix, (dict, list)):
        if len(prefix) != len(result_df.columns):
            len_msg = (
                f"Length of 'prefix' ({len(prefix)}) did not match the "
                "length of the columns being encoded "
                f"({len(result_df.columns)})."
            )
            raise ValueError(len_msg)
        if isinstance(prefix, dict):
            # dict prefixes are keyed by the generated column label
            result_df.columns = [
                f"{prefix[col]}{prefix_sep}{col}" for col in result_df.columns
            ]
        else:
            # list prefixes are matched to the columns by position
            result_df.columns = [
                f"{pre}{prefix_sep}{col}"
                for pre, col in zip(prefix, result_df.columns)
            ]

    if is_index:
        return MultiIndex.from_frame(result_df)
    return result_df
24052473

24062474
@forbid_nonstring_types(["bytes"])
24072475
def translate(self, table):

pandas/core/strings/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,6 @@ def _str_translate(self, table):
160160
def _str_wrap(self, width: int, **kwargs):
161161
pass
162162

163-
@abc.abstractmethod
def _str_get_dummies(self, sep: str = "|"):
    # Abstract hook with no behavior of its own: the body is intentionally
    # empty and concrete string-array classes supply the implementation.
    # ``sep`` is the separator string, "|" by default.
    pass
166-
167163
@abc.abstractmethod
168164
def _str_isalnum(self):
169165
pass

pandas/core/strings/object_array.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -372,32 +372,6 @@ def _str_wrap(self, width: int, **kwargs):
372372
tw = textwrap.TextWrapper(**kwargs)
373373
return self._str_map(lambda s: "\n".join(tw.wrap(s)))
374374

375-
def _str_get_dummies(self, sep: str = "|"):
376-
from pandas import Series
377-
378-
arr = Series(self).fillna("")
379-
try:
380-
arr = sep + arr + sep
381-
except (TypeError, NotImplementedError):
382-
arr = sep + arr.astype(str) + sep
383-
384-
tags: set[str] = set()
385-
for ts in Series(arr, copy=False).str.split(sep):
386-
tags.update(ts)
387-
tags2 = sorted(tags - {""})
388-
389-
dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
390-
391-
def _isin(test_elements: str, element: str) -> bool:
392-
return element in test_elements
393-
394-
for i, t in enumerate(tags2):
395-
pat = sep + t + sep
396-
dummies[:, i] = lib.map_infer(
397-
arr.to_numpy(), functools.partial(_isin, element=pat)
398-
)
399-
return dummies, tags2
400-
401375
def _str_upper(self):
402376
return self._str_map(lambda x: x.upper())
403377

pandas/tests/strings/test_get_dummies.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Index,
66
MultiIndex,
77
Series,
8+
SparseDtype,
89
_testing as tm,
910
)
1011

@@ -51,3 +52,71 @@ def test_get_dummies_with_name_dummy_index():
5152
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")
5253
)
5354
tm.assert_index_equal(result, expected)
55+
56+
def test_get_dummies_with_prefix(any_string_dtype):
    """A single string prefix is prepended to every dummy column label."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    want = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["prefix_a", "prefix_b", "prefix_c"],
    )

    got = ser.str.get_dummies(sep="|", prefix="prefix")

    tm.assert_frame_equal(got, want)
64+
65+
66+
def test_get_dummies_with_prefix_sep(any_string_dtype):
    """``prefix_sep`` only shows up in the labels when a prefix is given."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)

    # No prefix: the separator has nothing to join, so labels stay bare.
    bare = ser.str.get_dummies(sep="|", prefix=None, prefix_sep="__")
    want_bare = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=["a", "b", "c"]
    )
    tm.assert_frame_equal(bare, want_bare)

    # With a prefix: the custom separator sits between prefix and label.
    joined = ser.str.get_dummies(sep="|", prefix="col", prefix_sep="__")
    want_joined = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["col__a", "col__b", "col__c"],
    )
    tm.assert_frame_equal(joined, want_joined)
78+
79+
80+
def test_get_dummies_with_dummy_na(any_string_dtype):
    """``dummy_na=True`` adds a NaN-labelled column that flags missing rows."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    want = DataFrame(
        [[1, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 1]],
        columns=["a", "b", "c", np.nan],
    )

    got = ser.str.get_dummies(sep="|", dummy_na=True)

    tm.assert_frame_equal(got, want)
88+
89+
90+
def test_get_dummies_with_sparse(any_string_dtype):
    """``sparse=True`` yields sparse integer columns with the usual values."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    want = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["a", "b", "c"],
        dtype="Sparse[int]",
    )

    got = ser.str.get_dummies(sep="|", sparse=True)

    tm.assert_frame_equal(got, want)
    # Every column, not just the frame as a whole, must be sparse-backed.
    for col_dtype in got.dtypes:
        assert isinstance(col_dtype, SparseDtype)
100+
101+
102+
def test_get_dummies_with_dtype(any_string_dtype):
    """``dtype=bool`` produces boolean indicator columns."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    want = DataFrame(
        [[True, True, False], [True, False, True], [False, False, False]],
        columns=["a", "b", "c"],
    )

    got = ser.str.get_dummies(sep="|", dtype=bool)

    tm.assert_frame_equal(got, want)
    all_bool = (got.dtypes == bool).all()
    assert all_bool
111+
112+
113+
def test_get_dummies_with_prefix_dict(any_string_dtype):
    """A dict prefix is looked up per dummy column label."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    label_to_prefix = {"a": "alpha", "b": "beta", "c": "gamma"}
    want = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["alpha_a", "beta_b", "gamma_c"],
    )

    got = ser.str.get_dummies(sep="|", prefix=label_to_prefix)

    tm.assert_frame_equal(got, want)
122+

0 commit comments

Comments
 (0)