Skip to content

Commit 382da83

Browse files
committed
fix(DataFrame.unstack): fix bug when indexes contains nan
Fix bux when indexes contains `nan` and is not sorting would raise an `IndexError` or `ValueError`.
1 parent 2f26644 commit 382da83

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

pandas/core/reshape/reshape.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,11 @@ def __init__(
128128

129129
self.level = self.index._get_level_number(level)
130130

131-
# when index includes `nan`, need to lift levels/strides by 1
132-
self.lift = 1 if -1 in self.index.codes[self.level] else 0
131+
# `nan` values have code `-1`, when sorting, we lift to assign them
132+
# at index 0
133+
self.has_nan = -1 in self.index.codes[self.level]
134+
should_lift = self.has_nan and self.sort
135+
self.lift = 1 if should_lift else 0
133136

134137
# Note: the "pop" below alters these in-place.
135138
self.new_index_levels = list(self.index.levels)
@@ -138,8 +141,22 @@ def __init__(
138141
self.removed_name = self.new_index_names.pop(self.level)
139142
self.removed_level = self.new_index_levels.pop(self.level)
140143
self.removed_level_full = index.levels[self.level]
144+
self.unique_nan_index: int = -1
141145
if not self.sort:
142146
unique_codes = unique(self.index.codes[self.level])
147+
if self.has_nan:
148+
# drop nan codes, because they are not represented in level
149+
nan_mask = unique_codes == -1
150+
151+
if TYPE_CHECKING:
152+
# make explicit that nan_mask is an array
153+
# to remove this pyright diagnostic:
154+
# The method "__invert__" in class "bool" is deprecated
155+
nan_mask = cast(ArrayLike, nan_mask)
156+
157+
unique_codes = unique_codes[~nan_mask]
158+
self.unique_nan_index = np.flatnonzero(nan_mask)[0]
159+
143160
self.removed_level = self.removed_level.take(unique_codes)
144161
self.removed_level_full = self.removed_level_full.take(unique_codes)
145162

@@ -210,7 +227,7 @@ def _make_selectors(self) -> None:
210227
ngroups = len(obs_ids)
211228

212229
comp_index = ensure_platform_int(comp_index)
213-
stride = self.index.levshape[self.level] + self.lift
230+
stride = self.index.levshape[self.level] + self.has_nan
214231
self.full_shape = ngroups, stride
215232

216233
selector = self.sorted_labels[-1] + stride * comp_index + self.lift
@@ -362,13 +379,13 @@ def get_new_values(self, values, fill_value=None):
362379

363380
def get_new_columns(self, value_columns: Index | None):
364381
if value_columns is None:
365-
if self.lift == 0:
382+
if not self.has_nan:
366383
return self.removed_level._rename(name=self.removed_name)
367384

368385
lev = self.removed_level.insert(0, item=self.removed_level._na_value)
369386
return lev.rename(self.removed_name)
370387

371-
stride = len(self.removed_level) + self.lift
388+
stride = len(self.removed_level) + self.has_nan
372389
width = len(value_columns)
373390
propagator = np.repeat(np.arange(width), stride)
374391

@@ -401,12 +418,18 @@ def _repeater(self) -> np.ndarray:
401418
if len(self.removed_level_full) != len(self.removed_level):
402419
# In this case, we remap the new codes to the original level:
403420
repeater = self.removed_level_full.get_indexer(self.removed_level)
404-
if self.lift:
421+
if self.has_nan:
422+
# insert nan index at first position
405423
repeater = np.insert(repeater, 0, -1)
406424
else:
407425
# Otherwise, we just use each level item exactly once:
408-
stride = len(self.removed_level) + self.lift
426+
stride = len(self.removed_level) + self.has_nan
409427
repeater = np.arange(stride) - self.lift
428+
if self.has_nan and self.lift == 0:
429+
# assign -1 where should be nan according to the unique values.
430+
repeater[self.unique_nan_index] = -1
431+
# compensate for the removed index level
432+
repeater[self.unique_nan_index + 1 :] -= 1
410433

411434
return repeater
412435

pandas/tests/frame/test_stack_unstack.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,46 @@ def test_unstack_sort_false(frame_or_series, dtype):
13861386
tm.assert_frame_equal(result, expected)
13871387

13881388

1389+
@pytest.mark.parametrize(
1390+
"levels2, expected_columns, expected_data",
1391+
[
1392+
(
1393+
Index([None, 1, 2, 3]),
1394+
[("value", np.nan), ("value", 1.0), ("value", 2.0), ("value", 3.0)],
1395+
[[0, 4], [1, 5], [2, 6], [3, 7]],
1396+
),
1397+
(
1398+
Index([1, None, 2, 3]),
1399+
[("value", 1.0), ("value", np.nan), ("value", 2.0), ("value", 3.0)],
1400+
[[0, 4], [1, 5], [2, 6], [3, 7]],
1401+
),
1402+
(
1403+
Index([1, 2, None, 3]),
1404+
[("value", 1.0), ("value", 2.0), ("value", np.nan), ("value", 3.0)],
1405+
[[0, 4], [1, 5], [2, 6], [3, 7]],
1406+
),
1407+
(
1408+
Index([1, 2, 3, None]),
1409+
[("value", 1.0), ("value", 2.0), ("value", 3.0), ("value", np.nan)],
1410+
[[0, 4], [1, 5], [2, 6], [3, 7]],
1411+
),
1412+
],
1413+
ids=["nan=first", "nan=second", "nan=third", "nan=last"],
1414+
)
1415+
def test_unstack_sort_false_nan(levels2, expected_columns, expected_data):
1416+
# GH#61221
1417+
levels1 = ["b", "a"]
1418+
index = MultiIndex.from_product([levels1, levels2], names=["level1", "level2"])
1419+
df = DataFrame({"value": [0, 1, 2, 3, 4, 5, 6, 7]}, index=index)
1420+
result = df.unstack(level="level2", sort=False)
1421+
expected = DataFrame(
1422+
dict(zip(expected_columns, expected_data)),
1423+
index=Index(["b", "a"], name="level1"),
1424+
columns=MultiIndex.from_tuples(expected_columns, names=[None, "level2"]),
1425+
)
1426+
tm.assert_frame_equal(result, expected)
1427+
1428+
13891429
def test_unstack_fill_frame_object():
13901430
# GH12815 Test unstacking with object.
13911431
data = Series(["a", "b", "c", "a"], dtype="object")

0 commit comments

Comments
 (0)