Skip to content

Commit 0e979ef

Browse files
committed
#61382: FIX MultiIndex.difference for pyarrow-backed Timestamps
- Override MultiIndex.difference to handle timestamp[ns][pyarrow] levels
- Ensure proper comparison without converting all levels to pandas types
- Add pytest test for difference with pyarrow-backed MultiIndex
1 parent 1bd75cc commit 0e979ef

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

pandas/core/indexes/multi.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4135,6 +4135,33 @@ def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
41354135
# base class "Index" defined the type as "Callable[[Index, Any, bool], Any]")
41364136
rename = Index.set_names # type: ignore[assignment]
41374137

4138+
def difference(self, other, sort=None):
    """
    Return a new MultiIndex with elements of self that are not in other.

    Levels backed by pyarrow timestamps are compared through object-dtype
    copies of their level values; the matching levels of the result are
    cast back to the original pyarrow dtype. Only the level values are
    converted, the codes are left untouched.

    Parameters
    ----------
    other : MultiIndex or list-like of tuples
        Entries to remove from ``self``.
    sort : bool or None, default None
        Forwarded unchanged to the base ``Index.difference``.

    Returns
    -------
    MultiIndex
        Entries of ``self`` that do not appear in ``other``.
    """

    def _is_arrow_timestamp(dtype) -> bool:
        # Duck-typed so no extra import is needed here: ArrowDtype exposes
        # ``pyarrow_dtype`` and reports kind "M" for pyarrow timestamps,
        # whatever their unit or timezone.
        return (
            getattr(dtype, "pyarrow_dtype", None) is not None
            and dtype.kind == "M"
        )

    def _with_object_timestamp_levels(index):
        # Swap only the affected levels; return the very same object when
        # nothing needs converting so the caller can detect the no-op.
        levels = index.levels
        if not any(_is_arrow_timestamp(lvl.dtype) for lvl in levels):
            return index
        return index.set_levels(
            [
                lvl.astype(object) if _is_arrow_timestamp(lvl.dtype) else lvl
                for lvl in levels
            ]
        )

    left = _with_object_timestamp_levels(self)
    right = other
    if isinstance(other, MultiIndex):
        right = _with_object_timestamp_levels(other)

    # Name the class explicitly: ``super(type(self), self)`` recurses forever
    # for subclasses, and calling ``left.difference`` would re-enter this
    # override instead of reaching the base implementation.
    result = super(MultiIndex, left).difference(right, sort=sort)

    if left is self or not isinstance(result, MultiIndex):
        return result

    # Restore the original pyarrow dtype on every level that was converted.
    restored = [
        lvl.astype(orig.dtype) if _is_arrow_timestamp(orig.dtype) else lvl
        for lvl, orig in zip(result.levels, self.levels)
    ]
    return result.set_levels(restored)
4163+
4164+
41384165
# ---------------------------------------------------------------
41394166
# Arithmetic/Numeric Methods - Disabled
41404167

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pandas as pd
2+
import pytest
3+
4+
pytest.importorskip("pyarrow")
5+
6+
def test_difference_with_pyarrow_timestamp():
    """Removing the first row of a two-row MultiIndex leaves the second row.

    The ``date`` level uses the ``timestamp[ns][pyarrow]`` dtype, and the
    expected index is built with the same dtype.
    """
    arrow_ts = "timestamp[ns][pyarrow]"
    level_names = ["id", "date"]

    full = pd.MultiIndex.from_arrays(
        [[1, 2], pd.Series(["2024-01-01", "2024-01-02"], dtype=arrow_ts)],
        names=level_names,
    )
    expected = pd.MultiIndex.from_arrays(
        [[2], pd.Series(["2024-01-02"], dtype=arrow_ts)],
        names=level_names,
    )

    # Drop the first entry by differencing against a one-row slice.
    result = full.difference(full[:1])

    pd.testing.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)