Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Deprecations
Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
-
- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions (:issue:`41598`)

.. ---------------------------------------------------------------------------

Expand Down
46 changes: 28 additions & 18 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,22 +1308,10 @@ def _transform_general(self, func, *args, **kwargs):
gen = self.grouper.get_iterator(obj, axis=self.axis)
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

for name, group in gen:
if group.size == 0:
continue
object.__setattr__(group, "name", name)

# Try slow path and fast path.
try:
path, res = self._choose_path(fast_path, slow_path, group)
except TypeError:
return self._transform_item_by_item(obj, fast_path)
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err

def process_result(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be a module level function?

Copy link
Member Author

@rhshadrach rhshadrach Jul 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - can either make it a module-level by passing self.obj or make this into a class method. Do you see a reason to prefer one or the other?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think prefer a module level

group: DataFrame | Series, res: DataFrame | Series
) -> DataFrame | Series:
if isinstance(res, Series):

# we need to broadcast across the
# other dimension; this will preserve dtypes
# GH14457
Expand All @@ -1339,10 +1327,32 @@ def _transform_general(self, func, *args, **kwargs):
columns=group.columns,
index=group.index,
)

applied.append(r)
return r
else:
applied.append(res)
return res

try:
name, group = next(gen)
except StopIteration:
pass
else:
object.__setattr__(group, "name", name)
try:
path, res = self._choose_path(fast_path, slow_path, group)
except TypeError:
return self._transform_item_by_item(obj, fast_path)
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err
if group.size > 0:
applied.append(process_result(group, res))

for name, group in gen:
if group.size == 0:
continue
object.__setattr__(group, "name", name)
res = path(group)
applied.append(process_result(group, res))

concat_index = obj.columns if self.axis == 0 else obj.index
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
Expand Down