17
17
is_integer_dtype ,
18
18
is_list_like ,
19
19
is_object_dtype ,
20
- is_string_dtype ,
21
20
pandas_dtype ,
22
21
)
23
22
from pandas .core .dtypes .dtypes import (
24
23
ArrowDtype ,
25
24
CategoricalDtype ,
26
25
)
27
- from pandas .core .dtypes .missing import isna
28
26
29
27
from pandas .core .arrays import SparseArray
30
28
from pandas .core .arrays .categorical import factorize_from_iterable
38
36
39
37
if TYPE_CHECKING :
40
38
from pandas ._typing import (
41
- DtypeObj ,
42
39
NpDtype ,
43
40
)
44
41
@@ -395,7 +392,9 @@ def from_dummies(
395
392
The default category is the implied category when a value has none of the
396
393
listed categories specified with a one, i.e. if all dummies in a row are
397
394
zero. Can be a single value for all variables or a dict directly mapping
398
- the default categories to a prefix of a variable.
395
+ the default categories to a prefix of a variable. The default category
396
+ will be coerced to the dtype of ``data.columns`` if such coercion is
397
+ lossless, and will raise otherwise.
399
398
400
399
Returns
401
400
-------
@@ -560,20 +559,9 @@ def from_dummies(
560
559
"Dummy DataFrame contains multi-assignment(s); "
561
560
f"First instance in row: { assigned .idxmax ()} "
562
561
)
563
- dtype : str | DtypeObj = data .columns .dtype
564
562
if any (assigned == 0 ):
565
563
if isinstance (default_category , dict ):
566
- value = default_category [prefix ]
567
- if (
568
- is_string_dtype (data .columns .dtype )
569
- and not isinstance (value , str )
570
- and (is_list_like (value ) or not isna (value ))
571
- ):
572
- # https://github.com/pandas-dev/pandas/pull/60694
573
- # `value` is not a string or NA.
574
- # Using data.columns.dtype would coerce `value` into a string.
575
- dtype = "object"
576
- cats .append (value )
564
+ cats .append (default_category [prefix ])
577
565
else :
578
566
raise ValueError (
579
567
"Dummy DataFrame contains unassigned value(s); "
@@ -584,7 +572,8 @@ def from_dummies(
584
572
)
585
573
else :
586
574
data_slice = data_to_decode .loc [:, prefix_slice ]
587
- cats_array = data ._constructor_sliced (cats , dtype = dtype )
575
+ # cats_array = data._constructor_sliced(cats, dtype=dtype)
576
+ cats_array = data ._constructor_sliced (cats , dtype = data .columns .dtype )
588
577
# get indices of True entries along axis=1
589
578
true_values = data_slice .idxmax (axis = 1 )
590
579
indexer = data_slice .columns .get_indexer_for (true_values )
0 commit comments