Skip to content

#61382: FIX MultiIndex.difference for pyarrow-backed Timestamps #62127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4135,6 +4135,33 @@ def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
# base class "Index" defined the type as "Callable[[Index, Any, bool], Any]")
rename = Index.set_names # type: ignore[assignment]

def difference(self, other, sort=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems really unlikely to be the right place to handle this. if you step through the existing implementation, where is the first step that goes wrong?

"""
Return a new MultiIndex with elements in self that are not in other.
Fixed to work with pyarrow-backed Timestamps.
"""
if isinstance(other, type(self)):
# Convert pyarrow-backed Timestamps to pandas Timestamps for comparison
self_arrays = [level.to_pandas() if hasattr(level, "to_pandas") else level
for level in self.levels]
other_arrays = [level.to_pandas() if hasattr(level, "to_pandas") else level
for level in other.levels]
self_conv = pd.MultiIndex.from_arrays(self_arrays, names=self.names)
other_conv = pd.MultiIndex.from_arrays(other_arrays, names=other.names)
result = self_conv.difference(other_conv, sort=sort)
# Preserve pyarrow dtypes if present
for i, level in enumerate(self.levels):
if getattr(level, "dtype", None) == "timestamp[ns][pyarrow]":
result = pd.MultiIndex.from_arrays(
[pd.Series(arr, dtype="timestamp[ns][pyarrow]") if i==idx else arr
for idx, arr in enumerate(result.levels)],
names=result.names
)
return result
else:
return super(type(self), self).difference(other, sort=sort)


# ---------------------------------------------------------------
# Arithmetic/Numeric Methods - Disabled

Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/indexes/multi/test_timestamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pandas as pd
import pytest

pytest.importorskip("pyarrow")

def test_difference_with_pyarrow_timestamp():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would go in tests/indexes/multi/test_setops.py

dates = pd.Series(
["2024-01-01", "2024-01-02"], dtype="timestamp[ns][pyarrow]"
)
ids = [1, 2]

mi = pd.MultiIndex.from_arrays([ids, dates], names=["id", "date"])
to_remove = mi[:1]

result = mi.difference(to_remove)

expected_dates = pd.Series(
["2024-01-02"], dtype="timestamp[ns][pyarrow]"
)
expected_ids = [2]
expected = pd.MultiIndex.from_arrays(
[expected_ids, expected_dates], names=["id", "date"]
)

pd.testing.assert_index_equal(result, expected)
Loading