|
6 | 6 | import pandas.util._test_decorators as td |
7 | 7 |
|
8 | 8 | from pandas import ( |
| 9 | + NA, |
| 10 | + CategoricalDtype, |
9 | 11 | DataFrame, |
10 | 12 | Index, |
11 | 13 | MultiIndex, |
|
22 | 24 | def test_get_dummies(any_string_dtype): |
23 | 25 | s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) |
24 | 26 | result = s.str.get_dummies("|") |
25 | | - expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc")) |
| 27 | + exp_dtype = ( |
| 28 | + "boolean" |
| 29 | + if any_string_dtype == "string" and any_string_dtype.na_value is NA |
| 30 | + else "bool" |
| 31 | + ) |
| 32 | + expected = DataFrame( |
| 33 | + [[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=exp_dtype |
| 34 | + ) |
26 | 35 | tm.assert_frame_equal(result, expected) |
27 | 36 |
|
28 | 37 | s = Series(["a;b", "a", 7], dtype=any_string_dtype) |
29 | 38 | result = s.str.get_dummies(";") |
30 | | - expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab")) |
| 39 | + expected = DataFrame( |
| 40 | + [[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab"), dtype=exp_dtype |
| 41 | + ) |
31 | 42 | tm.assert_frame_equal(result, expected) |
32 | 43 |
|
33 | 44 |
|
34 | 45 | def test_get_dummies_index(): |
35 | 46 | # GH9980, GH8028 |
36 | 47 | idx = Index(["a|b", "a|c", "b|c"]) |
37 | | - result = idx.str.get_dummies("|") |
| 48 | + result = idx.str.get_dummies("|", dtype=np.int64) |
38 | 49 |
|
39 | 50 | expected = MultiIndex.from_tuples( |
40 | 51 | [(1, 1, 0), (1, 0, 1), (0, 1, 1)], names=("a", "b", "c") |
@@ -125,3 +136,15 @@ def test_get_dummies_with_pa_str_dtype(any_string_dtype): |
125 | 136 | dtype="str[pyarrow]", |
126 | 137 | ) |
127 | 138 | tm.assert_frame_equal(result, expected) |
| 139 | + |
| 140 | + |
| 141 | +@pytest.mark.parametrize("dtype_type", ["string", "category"]) |
| 142 | +def test_get_dummies_ea_dtype(dtype_type, string_dtype_no_object): |
| 143 | + dtype = string_dtype_no_object |
| 144 | + exp_dtype = "boolean" if dtype.na_value is NA else "bool" |
| 145 | + if dtype_type == "category": |
| 146 | + dtype = CategoricalDtype(Index(["a", "b"], dtype)) |
| 147 | + s = Series(["a", "b"], dtype=dtype) |
| 148 | + result = s.str.get_dummies() |
| 149 | + expected = DataFrame([[1, 0], [0, 1]], columns=list("ab"), dtype=exp_dtype) |
| 150 | + tm.assert_frame_equal(result, expected) |
0 commit comments