Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Deprecations
Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
-
- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions (:issue:`41598`)

.. ---------------------------------------------------------------------------

Expand Down
70 changes: 44 additions & 26 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,42 +1308,35 @@ def _transform_general(self, func, *args, **kwargs):
gen = self.grouper.get_iterator(obj, axis=self.axis)
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

for name, group in gen:
if group.size == 0:
continue
# Determine whether to use slow or fast path by evaluating on the first group.
# Need to handle the case of an empty generator and process the result so that
# it does not need to be computed again.
try:
name, group = next(gen)
except StopIteration:
pass
else:
object.__setattr__(group, "name", name)

# Try slow path and fast path.
try:
path, res = self._choose_path(fast_path, slow_path, group)
except TypeError:
return self._transform_item_by_item(obj, fast_path)
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err

if isinstance(res, Series):

# we need to broadcast across the
# other dimension; this will preserve dtypes
# GH14457
if res.index.is_(obj.index):
r = concat([res] * len(group.columns), axis=1)
r.columns = group.columns
r.index = group.index
else:
r = self.obj._constructor(
np.concatenate([res.values] * len(group.index)).reshape(
group.shape
),
columns=group.columns,
index=group.index,
)

applied.append(r)
else:
if group.size > 0:
res = _wrap_transform_general_frame(self.obj, group, res)
applied.append(res)

# Compute and process with the remaining groups
for name, group in gen:
if group.size == 0:
continue
object.__setattr__(group, "name", name)
res = path(group)
res = _wrap_transform_general_frame(self.obj, group, res)
applied.append(res)

concat_index = obj.columns if self.axis == 0 else obj.index
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
concatenated = concat(applied, axis=self.axis, verify_integrity=False)
Expand Down Expand Up @@ -1853,3 +1846,28 @@ def func(df):
return self._python_apply_general(func, self._obj_with_exclusions)

boxplot = boxplot_frame_groupby


def _wrap_transform_general_frame(
obj: DataFrame, group: DataFrame, res: DataFrame | Series
) -> DataFrame:
from pandas import concat

if isinstance(res, Series):
# we need to broadcast across the
# other dimension; this will preserve dtypes
# GH14457
if res.index.is_(obj.index):
r = concat([res] * len(group.columns), axis=1)
r.columns = group.columns
r.index = group.index
else:
r = obj._constructor(
np.concatenate([res.values] * len(group.index)).reshape(group.shape),
columns=group.columns,
index=group.index,
)
assert isinstance(r, DataFrame)
return r
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you call this e.g. res_frame? i.e. not 1 letter

else:
return res