Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions modin/pandas/api/extensions/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import inspect
from collections import defaultdict
from functools import cached_property
from types import MethodType, ModuleType
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -94,6 +95,10 @@ def decorator(new_attr: Any):
original_attr = getattr(pd, name)
_reexport_classes[name] = original_attr
delattr(pd, name)
# If the attribute is an instance of functools.cached_property, we must manually call __set_name__ on it.
# https://stackoverflow.com/a/62161136
if isinstance(new_attr, cached_property):
new_attr.__set_name__(obj, name)
extensions[None if backend is None else Backend.normalize(backend)][
name
] = new_attr
Expand Down
32 changes: 19 additions & 13 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,28 @@
"_wrap_aggregation",
}

GROUPBY_EXTENSION_NO_LOOKUP = EXTENSION_NO_LOOKUP | {
"_axis",
"_idx_name",
"_df",
"_query_compiler",
"_columns",
"_by",
"_drop",
"_return_tuple_when_iterating",
"_is_multi_by",
"_level",
"_kwargs",
"_get_query_compiler",
}


@_inherit_docstrings(pandas.core.groupby.DataFrameGroupBy)
class DataFrameGroupBy(ClassLogger, QueryCompilerCaster): # noqa: GL08
_pandas_class = pandas.core.groupby.DataFrameGroupBy
_return_tuple_when_iterating = False
_df: Union[DataFrame, Series]
_query_compiler: BaseQueryCompiler
# TODO(https://github.com/modin-project/modin/issues/7543):
# Currently this _extensions dict doesn't do anything, but we should
# add methods to register groupby accessors and make the groupby classes
# use this _extensions dict.
_extensions: EXTENSION_DICT_TYPE = EXTENSION_DICT_TYPE(dict)

def __init__(
Expand Down Expand Up @@ -248,7 +259,7 @@ def __getattr__(self, key):
try:
return self._getattr__from_extension_impl(
key=key,
default_behavior_attributes=_DEFAULT_BEHAVIOUR,
default_behavior_attributes=GROUPBY_EXTENSION_NO_LOOKUP,
extensions=__class__._extensions,
)
except AttributeError as err:
Expand All @@ -275,7 +286,7 @@ def __getattribute__(self, item: str) -> Any:
Any
The value of the attribute.
"""
if item not in _DEFAULT_BEHAVIOUR:
if item not in GROUPBY_EXTENSION_NO_LOOKUP:
extensions_result = self._getattribute__from_extension_impl(
item, __class__._extensions
)
Expand Down Expand Up @@ -1863,11 +1874,6 @@ def groupby_on_multiple_columns(df, *args, **kwargs):
@_inherit_docstrings(pandas.core.groupby.SeriesGroupBy)
class SeriesGroupBy(DataFrameGroupBy): # noqa: GL08
_pandas_class = pandas.core.groupby.SeriesGroupBy

# TODO(https://github.com/modin-project/modin/issues/7543):
# Currently this _extensions dict doesn't do anything, but we should
# add methods to register groupby accessors and make the groupby classes
# use this _extensions dict.
_extensions: EXTENSION_DICT_TYPE = EXTENSION_DICT_TYPE(dict)

@disable_logging
Expand All @@ -1888,7 +1894,7 @@ def __getattribute__(self, item: str) -> Any:
Any
The value of the attribute.
"""
if item not in _DEFAULT_BEHAVIOUR:
if item not in GROUPBY_EXTENSION_NO_LOOKUP:
extensions_result = self._getattribute__from_extension_impl(
item, __class__._extensions
)
Expand All @@ -1901,7 +1907,7 @@ def __getattribute__(self, item: str) -> Any:
def __getattr__(self, key: str) -> Any:
return self._getattr__from_extension_impl(
key=key,
default_behavior_attributes=_DEFAULT_BEHAVIOUR,
default_behavior_attributes=GROUPBY_EXTENSION_NO_LOOKUP,
extensions=__class__._extensions,
)

Expand Down
12 changes: 12 additions & 0 deletions modin/tests/pandas/extensions/test_groupby_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

from functools import cached_property

import pytest

import modin.pandas as pd
Expand Down Expand Up @@ -241,6 +243,16 @@ def del_property(self):
delattr(pandas_groupby, public_property_name)
assert not hasattr(pandas_groupby, private_property_name)

@pytest.mark.filterwarnings(default_to_pandas_ignore_string)
def test_override_cached_property(self, get_groupby, register_accessor):
@cached_property
def groups(self):
return {"group": pd.Index(["test"])}

register_accessor("groups", backend="Pandas")(groups)
pandas_df = pd.DataFrame({"col0": [1], "col1": [2]}).move_to("pandas")
assert get_groupby(pandas_df).groups == {"group": pd.Index(["test"])}


def test_deleting_extension_that_is_not_property_raises_attribute_error():
expected_string_val = "Some string value"
Expand Down
Loading