Skip to content

Commit 2929562

Browse files
committed
Revert to a doc update
1 parent 0efbe9c commit 2929562

File tree

2 files changed

+13
-23
lines changed

2 files changed

+13
-23
lines changed

pandas/core/reshape/encoding.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
is_integer_dtype,
1818
is_list_like,
1919
is_object_dtype,
20-
is_string_dtype,
2120
pandas_dtype,
2221
)
2322
from pandas.core.dtypes.dtypes import (
2423
ArrowDtype,
2524
CategoricalDtype,
2625
)
27-
from pandas.core.dtypes.missing import isna
2826

2927
from pandas.core.arrays import SparseArray
3028
from pandas.core.arrays.categorical import factorize_from_iterable
@@ -38,7 +36,6 @@
3836

3937
if TYPE_CHECKING:
4038
from pandas._typing import (
41-
DtypeObj,
4239
NpDtype,
4340
)
4441

@@ -395,7 +392,9 @@ def from_dummies(
395392
The default category is the implied category when a value has none of the
396393
listed categories specified with a one, i.e. if all dummies in a row are
397394
zero. Can be a single value for all variables or a dict directly mapping
398-
the default categories to a prefix of a variable.
395+
the default categories to a prefix of a variable. The default category
396+
will be coerced to the dtype of ``data.columns`` if such coercion is
397+
lossless, and will raise otherwise.
399398
400399
Returns
401400
-------
@@ -560,20 +559,9 @@ def from_dummies(
560559
"Dummy DataFrame contains multi-assignment(s); "
561560
f"First instance in row: {assigned.idxmax()}"
562561
)
563-
dtype: str | DtypeObj = data.columns.dtype
564562
if any(assigned == 0):
565563
if isinstance(default_category, dict):
566-
value = default_category[prefix]
567-
if (
568-
is_string_dtype(data.columns.dtype)
569-
and not isinstance(value, str)
570-
and (is_list_like(value) or not isna(value))
571-
):
572-
# https://github.com/pandas-dev/pandas/pull/60694
573-
# `value` is not a string or NA.
574-
# Using data.columns.dtype would coerce `value` into a string.
575-
dtype = "object"
576-
cats.append(value)
564+
cats.append(default_category[prefix])
577565
else:
578566
raise ValueError(
579567
"Dummy DataFrame contains unassigned value(s); "
@@ -584,7 +572,8 @@ def from_dummies(
584572
)
585573
else:
586574
data_slice = data_to_decode.loc[:, prefix_slice]
587-
cats_array = data._constructor_sliced(cats, dtype=dtype)
575+
# cats_array = data._constructor_sliced(cats, dtype=dtype)
576+
cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype)
588577
# get indices of True entries along axis=1
589578
true_values = data_slice.idxmax(axis=1)
590579
indexer = data_slice.columns.get_indexer_for(true_values)

pandas/tests/reshape/test_from_dummies.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import pytest
33

4-
import pandas as pd
54
from pandas import (
65
DataFrame,
76
Series,
@@ -334,7 +333,7 @@ def test_no_prefix_string_cats_default_category(
334333
):
335334
dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]})
336335
result = from_dummies(dummies, default_category=default_category)
337-
expected = DataFrame(expected)
336+
expected = DataFrame(expected, dtype=dummies.columns.dtype)
338337
tm.assert_frame_equal(result, expected)
339338

340339

@@ -466,14 +465,16 @@ def test_object_dtype_preserved():
466465
# https://github.com/pandas-dev/pandas/pull/60694
467466
# When the input has object dtype, the result should as
468467
# well even when infer_string is True.
468+
import pandas as pd
469+
470+
assert pd.get_option("future.infer_string")
469471
df = DataFrame(
470472
{
471473
"x": [1, 0, 0],
472474
"y": [0, 1, 0],
473475
},
474476
)
475477
df.columns = df.columns.astype("object")
476-
with pd.option_context("future.infer_string", True):
477-
result = from_dummies(df, default_category="z")
478-
expected = DataFrame({"": ["x", "y", "z"]}, dtype="object")
479-
tm.assert_frame_equal(result, expected)
478+
result = from_dummies(df, default_category="z")
479+
expected = DataFrame({"": ["x", "y", "z"]}, dtype="object")
480+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)