Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,29 +1308,9 @@ def _transform_general(self, func, *args, **kwargs):
gen = self.grouper.get_iterator(obj, axis=self.axis)
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

def process_result(
    group: DataFrame | Series, res: DataFrame | Series
) -> DataFrame | Series:
    """
    Broadcast a per-group Series result to the shape of ``group``.

    Anything that is not a Series is handed back untouched.
    """
    if not isinstance(res, Series):
        return res

    # we need to broadcast across the other dimension;
    # this will preserve dtypes (GH14457)
    if res.index.is_(obj.index):
        # one copy of ``res`` per column of the group, relabelled to
        # carry the group's own columns and index
        broadcast = concat([res] * len(group.columns), axis=1)
        broadcast.columns = group.columns
        broadcast.index = group.index
        return broadcast

    # repeat the values once per row of the group, then reshape to
    # the group's shape
    tiled = np.concatenate([res.values] * len(group.index))
    return self.obj._constructor(
        tiled.reshape(group.shape),
        columns=group.columns,
        index=group.index,
    )

# Determine whether to use slow or fast path by evaluating on the first group.
# Need to handle the case of an empty generator and process the result so that
# it does not need to be computed again.
try:
name, group = next(gen)
except StopIteration:
Expand All @@ -1345,14 +1325,17 @@ def process_result(
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err
if group.size > 0:
applied.append(process_result(group, res))
res = _wrap_transform_general_frame(self.obj, group, res)
applied.append(res)

# Compute and process with the remaining groups
for name, group in gen:
if group.size == 0:
continue
object.__setattr__(group, "name", name)
res = path(group)
applied.append(process_result(group, res))
res = _wrap_transform_general_frame(self.obj, group, res)
applied.append(res)

concat_index = obj.columns if self.axis == 0 else obj.index
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
Expand Down Expand Up @@ -1863,3 +1846,28 @@ def func(df):
return self._python_apply_general(func, self._obj_with_exclusions)

boxplot = boxplot_frame_groupby


def _wrap_transform_general_frame(
    obj: DataFrame, group: DataFrame, res: DataFrame | Series
) -> DataFrame:
    """
    Broadcast the result of a per-group transform to the shape of ``group``.

    Parameters
    ----------
    obj : DataFrame
        The object being grouped; its index is compared with ``res.index``
        and its ``_constructor`` is used to build the broadcast frame.
    group : DataFrame
        The group the transform was applied to; supplies the shape,
        columns and index of the broadcast result.
    res : DataFrame or Series
        The value returned by the transform for ``group``.

    Returns
    -------
    DataFrame
        ``res`` itself when it is not a Series; otherwise a frame with the
        columns and index of ``group`` built from the values of ``res``.
    """
    # NOTE(review): imported locally, presumably to avoid a circular
    # import at module load time -- confirm before hoisting.
    from pandas import concat

    if not isinstance(res, Series):
        return res

    # we need to broadcast across the
    # other dimension; this will preserve dtypes
    # GH14457
    if res.index.is_(obj.index):
        # ``res`` carries the index of the whole object: repeat it once
        # per column of the group and relabel with the group's axes.
        res_frame = concat([res] * len(group.columns), axis=1)
        res_frame.columns = group.columns
        res_frame.index = group.index
    else:
        # Otherwise repeat the values once per row of the group and
        # reshape them to the group's shape.
        res_frame = obj._constructor(
            np.concatenate([res.values] * len(group.index)).reshape(group.shape),
            columns=group.columns,
            index=group.index,
        )
    # Narrows the type for static checkers; not input validation.
    assert isinstance(res_frame, DataFrame)
    return res_frame