Skip to content

Commit 1e59822

Browse files
author
samuel.oranyeli
committed
create groupby accessor and method
1 parent a3bd755 commit 1e59822

File tree

4 files changed

+160
-1
lines changed

4 files changed

+160
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
rev: 1.5.0
1818
hooks:
1919
- id: interrogate
20-
args: [-c, pyproject.toml, -vv]
20+
args: [-c, pyproject.toml, -vv, --fail-under=70]
2121
- repo: https://github.com/terrencepreilly/darglint
2222
rev: v1.8.1
2323
hooks:

pandas_flavor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
register_dataframe_method,
55
register_series_accessor,
66
register_series_method,
7+
register_groupby_accessor,
8+
register_groupby_method,
79
)
810
from .xarray import (
911
register_xarray_dataarray_method,
@@ -15,6 +17,8 @@
1517
"register_series_accessor",
1618
"register_dataframe_method",
1719
"register_dataframe_accessor",
20+
"register_groupby_accessor",
21+
"register_groupby_method",
1822
"register_xarray_dataarray_method",
1923
"register_xarray_dataset_method",
2024
]

pandas_flavor/register.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
"""Register functions as methods of Pandas DataFrame and Series."""
2+
from __future__ import annotations
3+
4+
import warnings
25
from functools import wraps
6+
from typing import Callable
7+
8+
from pandas.core.groupby.generic import DataFrameGroupBy
9+
from pandas.util._exceptions import find_stack_level
310
from pandas.api.extensions import (
411
register_series_accessor,
512
register_dataframe_accessor,
@@ -228,3 +235,124 @@ def __call__(self, *args, **kwargs):
228235
return method
229236

230237
return inner()
238+
239+
240+
# Variant of pandas' accessor machinery, adapted here for DataFrameGroupBy.
241+
242+
# copied from pandas' accessor file - pandas/pandas/core/accessor.py
243+
"""
244+
245+
accessor.py contains base classes for implementing accessor properties
246+
that can be mixed into or pinned onto other pandas classes.
247+
248+
"""
249+
250+
251+
class CachedAccessor:
    """
    Custom property-like object.

    A descriptor that builds an accessor on first access and caches it on
    the instance it was looked up on.

    Parameters
    ----------
    name : str
        Namespace that will be accessed under, e.g. ``grouped.foo``.
    accessor : type
        Class with the extension methods. It is called with the object the
        attribute was looked up on as its single positional argument.

    Notes
    -----
    Only ``__get__`` is defined, so once the built accessor has been stored
    in the instance's ``__dict__`` under ``name``, later lookups on that
    instance return the stored object without calling this descriptor again.
    """

    def __init__(self, name: str, accessor: type) -> None:
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        if obj is None:
            # Looked up on the class itself rather than on an instance:
            # hand back the accessor class.
            return self._accessor
        accessor_obj = self._accessor(obj)
        # Shadow this descriptor with the accessor object on the instance.
        # Inspired by: https://www.pydanny.com/cached-property.html
        # object.__setattr__ is used so that any __setattr__ defined by the
        # host class is bypassed.
        object.__setattr__(obj, self._name, accessor_obj)
        return accessor_obj
286+
287+
288+
def _register_accessor(name: str, cls: type[DataFrameGroupBy]) -> Callable:
    """
    Register a custom accessor on DataFrameGroupBy objects.

    Args:
        name : str
            Name under which the accessor should be registered.
            A warning is issued
            if this name conflicts with a preexisting attribute.
        cls: type
            Class the accessor is attached to
            (``DataFrameGroupBy`` for the public wrapper in this module).

    Returns:
        A class decorator.
    """

    def decorator(accessor):
        if hasattr(cls, name):
            warnings.warn(
                f"registration of accessor {accessor!r} under name "
                f"{name!r} for type {cls!r} "
                "is overriding a preexisting "
                "attribute with the same name.",
                UserWarning,
                # NOTE(review): find_stack_level comes from the private
                # module pandas.util._exceptions - confirm it attributes
                # the warning to the intended frame for non-pandas callers.
                stacklevel=find_stack_level(),
            )
        setattr(cls, name, CachedAccessor(name, accessor))
        # Only mutate a set that belongs to ``cls`` itself. If ``_accessors``
        # is merely inherited, adding to it in place would also change the
        # base class (and every other subclass sharing that set), so take a
        # private copy first.
        if "_accessors" not in vars(cls):
            cls._accessors = set(getattr(cls, "_accessors", ()))
        cls._accessors.add(name)
        return accessor

    return decorator
320+
321+
322+
def register_groupby_accessor(name: str) -> Callable:
    """Register a custom accessor on pandas DataFrameGroupBy objects.

    Example:
        >>> @register_groupby_accessor("stats")  # doctest: +SKIP
        ... class StatsAccessor:
        ...     def __init__(self, grouped):
        ...         self._grouped = grouped

    Args:
        name: Attribute name under which the accessor is registered
            on DataFrameGroupBy.

    Returns:
        A class decorator that registers the decorated class
        as the accessor.
    """
    return _register_accessor(name, DataFrameGroupBy)
324+
325+
326+
def register_groupby_method(method: Callable) -> Callable:
    """Register a function as a method attached to the pandas DataFrameGroupBy.

    Example:
        >>> @register_groupby_method  # doctest: +SKIP
        ... def print_column(grp, col):
        ...     '''Print the dataframe column given'''
        ...     print(grp[col])

    !!! info "New in version 0.7.0"

    Args:
        method: Function to be registered as a method on the
            DataFrameGroupBy. It is called with the groupby object
            as its first argument.

    Returns:
        The function passed in as `method`.
    """

    class AccessorMethod:
        """Callable accessor that forwards to `method`."""

        # Expose the wrapped function's documentation on the accessor class.
        __doc__ = method.__doc__

        def __init__(self, obj):
            # ``obj`` is the object the attribute was looked up on.
            self._obj = obj

        @wraps(method)
        def __call__(self, *args, **kwargs):
            return method(self._obj, *args, **kwargs)

    # Registered under the function's own name, so it is invoked as
    # ``grouped.<method name>(...)``.
    register_groupby_accessor(method.__name__)(AccessorMethod)
    return method

tests/test_pandas_register.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Tests for pandas series and dataframe method registration."""
2+
23
import pandas_flavor as pf
34
import pandas as pd
5+
from pandas.core.groupby.generic import DataFrameGroupBy
46

57

68
def test_register_dataframe_method():
@@ -39,3 +41,28 @@ def dummy_func(s: pd.Series) -> pd.Series:
3941

4042
ser = pd.Series()
4143
ser.dummy_func()
44+
45+
46+
def test_register_groupby_method():
    """Test register_groupby_method."""

    @pf.register_groupby_method
    def dummy_func(by: DataFrameGroupBy) -> DataFrameGroupBy:
        """Dummy func.

        Args:
            by: A DataFrameGroupBy object.

        Returns:
            DataFrameGroupBy.
        """
        return by

    df = pd.DataFrame(
        {
            "Animal": ["Falcon"],
            "Max Speed": [380.0],
        }
    )
    by = df.groupby("Animal")
    # The registered method must be reachable on the groupby object and must
    # receive that same object as its first argument.
    assert by.dummy_func() is by

0 commit comments

Comments
 (0)