Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.pandas/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ MultiIndex Modifying and computations
:toctree: api/

MultiIndex.equals
MultiIndex.equal_levels
MultiIndex.identical
MultiIndex.insert
MultiIndex.drop
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/pandas/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,43 @@ def intersection(self, other: Union[DataFrame, Series, Index, List]) -> "MultiIn
)
return cast(MultiIndex, DataFrame(internal).index)

def equal_levels(self, other: "MultiIndex") -> bool:
    """
    Return True if the levels of both MultiIndex objects are the same.

    The two objects are considered to have equal levels when they have the
    same number of levels and, for every level position, the sorted unique
    values of that level are equal.

    Parameters
    ----------
    other : MultiIndex
        The MultiIndex whose levels are compared with this one.

    Returns
    -------
    bool
        False if the number of levels differs or if any level's sorted unique
        values differ; True otherwise.

    Notes
    -----
    This API can be expensive since it has logic to sort and compare the values of
    all levels of indices that belong to MultiIndex.

    The examples below enable the ``compute.ops_on_diff_frames`` option before
    comparing two separately created MultiIndex objects.

    Examples
    --------
    >>> from pyspark.pandas.config import set_option, reset_option
    >>> set_option("compute.ops_on_diff_frames", True)

    >>> psmidx1 = ps.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
    >>> psmidx2 = ps.MultiIndex.from_tuples([("b", "y"), ("a", "x"), ("c", "z")])
    >>> psmidx1.equal_levels(psmidx2)
    True

    >>> psmidx2 = ps.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "j")])
    >>> psmidx1.equal_levels(psmidx2)
    False

    >>> reset_option("compute.ops_on_diff_frames")
    """
    nlevels = self.nlevels
    # A different number of levels can never match; answer before doing any
    # per-level work.
    if nlevels != other.nlevels:
        return False

    # all() consumes the generator lazily, so the comparison stops at the
    # first level whose values differ and later levels are never computed.
    return all(
        self.get_level_values(level)
        .unique()
        .sort_values()
        .equals(other.get_level_values(level).unique().sort_values())
        for level in range(nlevels)
    )

@property
def hasnans(self) -> bool:
    """
    Unsupported on MultiIndex; accessing this property always raises.

    Raises
    ------
    NotImplementedError
        On every access.
    """
    raise NotImplementedError("hasnans is not defined for MultiIndex")
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,6 +2388,13 @@ def test_map(self):
lambda: psidx.map({1: 1, 2: 2.0, 3: "three"}),
)

def test_multiindex_equal_levels(self):
    # A two-level index against a three-level index: the level counts differ,
    # and the pandas-on-Spark result must agree with pandas.
    two_level = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
    three_level = pd.MultiIndex.from_tuples(
        [("a", "x", "q"), ("b", "y", "w"), ("c", "z", "e")]
    )
    ps_two_level = ps.from_pandas(two_level)
    ps_three_level = ps.from_pandas(three_level)

    expected = two_level.equal_levels(three_level)
    actual = ps_two_level.equal_levels(ps_three_level)
    self.assert_eq(expected, actual)

def test_to_numpy(self):
pidx = pd.Index([1, 2, 3, 4])
psidx = ps.from_pandas(pidx)
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/pandas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,35 @@ def _test_cov(self, pser1, pser2):
pscov = psser1.cov(psser2, min_periods=3)
self.assert_eq(pcov, pscov, almost=True)

def test_multiindex_equal_levels(self):
    # Compare MultiIndex.equal_levels between pandas and pandas-on-Spark for
    # several indexes measured against the same three-row base index.
    pmidx1 = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
    psmidx1 = ps.from_pandas(pmidx1)

    other_tuples = [
        # Same tuples as the base index, in a different order.
        [("b", "y"), ("a", "x"), ("c", "z")],
        # One value changed in the second level.
        [("a", "x"), ("b", "y"), ("c", "j")],
        # A repeated tuple in place of ("c", "z").
        [("a", "x"), ("b", "y"), ("a", "x")],
        # Fewer rows than the base index.
        [("a", "x"), ("b", "y")],
        # Same values in each level, paired differently across rows.
        [("a", "y"), ("b", "x"), ("c", "z")],
    ]
    for tuples in other_tuples:
        pmidx2 = pd.MultiIndex.from_tuples(tuples)
        psmidx2 = ps.from_pandas(pmidx2)
        self.assert_eq(pmidx1.equal_levels(pmidx2), psmidx1.equal_levels(psmidx2))

    # Both sides carry repeated values within their levels.
    pmidx1 = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z"), ("a", "y")])
    pmidx2 = pd.MultiIndex.from_tuples([("a", "y"), ("b", "x"), ("c", "z"), ("c", "x")])
    psmidx1 = ps.from_pandas(pmidx1)
    psmidx2 = ps.from_pandas(pmidx2)
    self.assert_eq(pmidx1.equal_levels(pmidx2), psmidx1.equal_levels(psmidx2))


class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
Expand Down Expand Up @@ -2039,6 +2068,14 @@ def test_combine_first(self):
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
psdf1.combine_first(psdf2)

def test_multiindex_equal_levels(self):
    # Two separately created MultiIndex objects: equal_levels is expected to
    # be rejected with a ValueError here rather than return a result.
    psmidx1 = ps.from_pandas(
        pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
    )
    psmidx2 = ps.from_pandas(
        pd.MultiIndex.from_tuples([("b", "y"), ("a", "x"), ("c", "z")])
    )
    with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
        psmidx1.equal_levels(psmidx2)


if __name__ == "__main__":
from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401
Expand Down