Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,13 @@ def _wrap_applied_output(

result = self.obj._constructor(index=res_index, columns=data.columns)
result = result.astype(data.dtypes)

# Preserve metadata for subclassed DataFrames
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result

# GH12824
Expand All @@ -2081,13 +2088,28 @@ def _wrap_applied_output(
# GH57775 - Ensure that columns and dtypes from original frame are kept.
result = self.obj._constructor(columns=data.columns)
result = result.astype(data.dtypes)

# Preserve metadata for subclassed DataFrames
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result
elif isinstance(first_not_none, DataFrame):
return self._concat_objects(
result = self._concat_objects(
values,
not_indexed_same=not_indexed_same,
is_transform=is_transform,
)

# Preserve metadata for subclassed DataFrames
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result

key_index = self._grouper.result_index if self.as_index else None

Expand All @@ -2105,27 +2127,58 @@ def _wrap_applied_output(
# (expression has type "Hashable", variable
# has type "Tuple[Any, ...]")
name = self._selection # type: ignore[assignment]
return self.obj._constructor_sliced(values, index=key_index, name=name)
result = self.obj._constructor_sliced(values, index=key_index, name=name)

# Copy the attributes listed in self.obj._metadata onto the sliced (Series) result
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result
elif not isinstance(first_not_none, Series):
# values are not series or array-like but scalars
# self._selection not passed through to Series as the
# result should not take the name of original selection
# of columns
if self.as_index:
return self.obj._constructor_sliced(values, index=key_index)
result = self.obj._constructor_sliced(values, index=key_index)

# Copy the attributes listed in self.obj._metadata onto the sliced (Series) result
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result
else:
result = self.obj._constructor(values, columns=[self._selection])
result = self._insert_inaxis_grouper(result)

# Preserve metadata for subclassed DataFrames
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result
else:
# values are Series
return self._wrap_applied_output_series(
result = self._wrap_applied_output_series(
values,
not_indexed_same,
first_not_none,
key_index,
is_transform,
)

# Preserve metadata for subclassed DataFrames/Series
if hasattr(self.obj, '_metadata'):
for attr in self.obj._metadata:
if hasattr(self.obj, attr):
setattr(result, attr, getattr(self.obj, attr))

return result

def _wrap_applied_output_series(
self,
Expand Down
32 changes: 32 additions & 0 deletions pandas/tests/groupby/test_groupby_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Tests for metadata preservation in groupby operations.
"""

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas import DataFrame
from pandas.tests.groupby import test_groupby_subclass


class TestGroupByMetadataPreservation:
def test_groupby_apply_preserves_metadata(self):
"""Test that groupby.apply() preserves _metadata from subclassed DataFrame."""
# Create a subclassed DataFrame with metadata
subdf = tm.SubclassedDataFrame(
{"X": [1, 1, 2, 2, 3], "Y": np.arange(0, 5), "Z": np.arange(10, 15)}
)
subdf.testattr = "test"

# Apply groupby operation
result = subdf.groupby("X").apply(np.sum, axis=0, include_groups=False)

# Check that metadata is preserved
assert hasattr(result, 'testattr'), "Metadata attribute 'testattr' should be preserved"
assert result.testattr == "test", "Metadata value should be preserved"

# Compare with equivalent operation that preserves metadata
expected = subdf.groupby("X").sum()
assert expected.testattr == "test", "Equivalent operation should preserve metadata"
Loading