Skip to content

Commit 7a22663

Browse files
committed
TST(string dtype): Resolve xfails in test_from_dummies
1 parent a81d52f commit 7a22663

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

pandas/core/reshape/encoding.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
is_integer_dtype,
1818
is_list_like,
1919
is_object_dtype,
20+
is_string_dtype,
2021
pandas_dtype,
2122
)
2223
from pandas.core.dtypes.dtypes import (
2324
ArrowDtype,
2425
CategoricalDtype,
2526
)
27+
from pandas.core.dtypes.missing import isna
2628

2729
from pandas.core.arrays import SparseArray
2830
from pandas.core.arrays.categorical import factorize_from_iterable
@@ -554,9 +556,20 @@ def from_dummies(
554556
"Dummy DataFrame contains multi-assignment(s); "
555557
f"First instance in row: {assigned.idxmax()}"
556558
)
559+
dtype = data.columns.dtype
557560
if any(assigned == 0):
558561
if isinstance(default_category, dict):
559-
cats.append(default_category[prefix])
562+
value = default_category[prefix]
563+
if (
564+
is_string_dtype(data.columns.dtype)
565+
and not isinstance(value, str)
566+
and (is_list_like(value) or not isna(value))
567+
):
568+
# GH#???
569+
# `value` is neither a string nor NA, so using
570+
# data.columns.dtype would coerce it into a string; fall back to object dtype.
571+
dtype = "object"
572+
cats.append(value)
560573
else:
561574
raise ValueError(
562575
"Dummy DataFrame contains unassigned value(s); "
@@ -567,7 +580,7 @@ def from_dummies(
567580
)
568581
else:
569582
data_slice = data_to_decode.loc[:, prefix_slice]
570-
cats_array = data._constructor_sliced(cats, dtype=data.columns.dtype)
583+
cats_array = data._constructor_sliced(cats, dtype=dtype)
571584
# get indices of True entries along axis=1
572585
true_values = data_slice.idxmax(axis=1)
573586
indexer = data_slice.columns.get_indexer_for(true_values)

pandas/tests/reshape/test_from_dummies.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from pandas._config import using_string_dtype
5-
4+
import pandas as pd
65
from pandas import (
76
DataFrame,
87
Series,
@@ -336,8 +335,6 @@ def test_no_prefix_string_cats_default_category(
336335
dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]})
337336
result = from_dummies(dummies, default_category=default_category)
338337
expected = DataFrame(expected)
339-
if using_infer_string:
340-
expected[""] = expected[""].astype("str")
341338
tm.assert_frame_equal(result, expected)
342339

343340

@@ -364,7 +361,6 @@ def test_with_prefix_contains_get_dummies_NaN_column():
364361
tm.assert_frame_equal(result, expected)
365362

366363

367-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
368364
@pytest.mark.parametrize(
369365
"default_category, expected",
370366
[
@@ -450,3 +446,32 @@ def test_maintain_original_index():
450446
result = from_dummies(df)
451447
expected = DataFrame({"": list("abca")}, index=list("abcd"))
452448
tm.assert_frame_equal(result, expected)
449+
450+
451+
def test_int_columns_with_float_default():
452+
# GH#???
453+
df = DataFrame(
454+
{
455+
3: [1, 0, 0],
456+
4: [0, 1, 0],
457+
},
458+
)
459+
with pytest.raises(ValueError, match="Trying to coerce float values to integers"):
460+
from_dummies(df, default_category=0.5)
461+
462+
463+
def test_object_dtype_preserved():
464+
# GH#???
465+
# When the input columns have object dtype, the result should have
466+
# object dtype as well, even when infer_string is True.
467+
df = DataFrame(
468+
{
469+
"x": [1, 0, 0],
470+
"y": [0, 1, 0],
471+
},
472+
)
473+
df.columns = df.columns.astype("object")
474+
with pd.option_context("future.infer_string", True):
475+
result = from_dummies(df, default_category="z")
476+
expected = DataFrame({"": ["x", "y", "z"]}, dtype="object")
477+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)