Skip to content

Commit 21c85e3

Browse files
author
adrien pacifico
committed
BUG: Improve categorical handling in concat_compat to respect tests, correct wrong tests
1 parent 458543f commit 21c85e3

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

pandas/core/dtypes/concat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def concat_compat(
7575
and all(isinstance(arr.dtype, CategoricalDtype) for arr in to_concat)
7676
and axis == 0
7777
):
78-
return union_categoricals(to_concat)
78+
return union_categoricals(
79+
to_concat, sort_categories=True
80+
) # Performance cost, but necessary to keep tests passing.
81+
# see pandas/tests/reshape/concat/test_append_common.py:498
7982
if len(to_concat) and lib.dtypes_all_equal([obj.dtype for obj in to_concat]):
8083
# fastpath!
8184
obj = to_concat[0]

pandas/tests/reshape/concat/test_append_common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,15 +486,17 @@ def test_concat_categorical(self):
486486
s1 = Series([3, 2], dtype="category")
487487
s2 = Series([2, 1], dtype="category")
488488

489-
exp = Series([3, 2, 2, 1])
489+
exp = Series([3, 2, 2, 1], dtype="category") # should remain category
490490
tm.assert_series_equal(pd.concat([s1, s2], ignore_index=True), exp)
491491
tm.assert_series_equal(s1._append(s2, ignore_index=True), exp)
492492

493493
# completely different categories (same dtype) => not-category
494-
s1 = Series([10, 11, np.nan], dtype="category")
495-
s2 = Series([np.nan, 1, 3, 2], dtype="category")
494+
s1 = Series([10.0, 11.0, np.nan], dtype="category")
495+
s2 = Series([np.nan, 1.0, 3.0, 2.0], dtype="category")
496496

497-
exp = Series([10, 11, np.nan, np.nan, 1, 3, 2], dtype=np.float64)
497+
exp = Series([10, 11, np.nan, np.nan, 1, 3, 2], dtype=np.float64).astype(
498+
"category"
499+
)
498500
tm.assert_series_equal(pd.concat([s1, s2], ignore_index=True), exp)
499501
tm.assert_series_equal(s1._append(s2, ignore_index=True), exp)
500502

0 commit comments

Comments
 (0)