Skip to content

Commit 38f3784

Browse files
committed
Fix DataFrame.aggregate to preserve extension dtypes with callable functions
1 parent d5f97ed commit 38f3784

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

pandas/core/apply.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ def agg(self) -> DataFrame | Series | None:
291291
elif is_list_like(func):
292292
# we require a list, but not a 'str'
293293
return self.agg_list_like()
294+
elif callable(func):
295+
return self.agg_callable()
294296

295297
# caller can react
296298
return None
@@ -797,6 +799,86 @@ def _apply_str(self, obj, func: str, *args, **kwargs):
797799
msg = f"'{func}' is not a valid function for '{type(obj).__name__}' object"
798800
raise AttributeError(msg)
799801

802+
def agg_callable(self) -> DataFrame | Series:
803+
"""
804+
Compute aggregation in the case of a callable argument.
805+
806+
This method handles callable functions while preserving extension dtypes
807+
by delegating to the same infrastructure used for string aggregations.
808+
809+
Returns
810+
-------
811+
Result of aggregation.
812+
"""
813+
obj = self.obj
814+
func = self.func
815+
816+
if obj.ndim == 1:
817+
return func(obj, *self.args, **self.kwargs)
818+
819+
# Use _reduce to preserve extension dtypes like on string aggregation
820+
try:
821+
result = obj._reduce(
822+
func,
823+
name=getattr(func, '__name__', '<lambda>'),
824+
axis=self.axis,
825+
skipna=True,
826+
numeric_only=False,
827+
**self.kwargs
828+
)
829+
return result
830+
831+
except (AttributeError, TypeError):
832+
# If _reduce fails, fallback to column-wise
833+
return self._agg_callable_fallback()
834+
835+
def _agg_callable_fallback(self) -> DataFrame | Series:
836+
"""
837+
Fallback method for callable aggregation when _reduce fails.
838+
839+
This method applies the function column-wise while preserving dtypes,
840+
but avoids the performance overhead of row-by-row processing.
841+
"""
842+
obj = self.obj
843+
func = self.func
844+
845+
if self.axis == 1:
846+
# For row-wise aggregation, transpose and recurse
847+
transposed_result = obj.T._aggregate(func, axis=0, *self.args, **self.kwargs)
848+
return transposed_result
849+
850+
from pandas import Series
851+
852+
try:
853+
# Apply function to each column
854+
results = {}
855+
for name in obj.columns:
856+
col = obj._get_column_reference(name)
857+
result_val = func(col, *self.args, **self.kwargs)
858+
results[name] = result_val
859+
860+
result = Series(results, name=None)
861+
862+
# Preserve extension dtypes where possible
863+
for name in result.index:
864+
if name in obj.columns:
865+
original_dtype = obj.dtypes[name]
866+
if hasattr(original_dtype, 'construct_array_type'):
867+
try:
868+
array_type = original_dtype.construct_array_type()
869+
if hasattr(array_type, '_from_sequence'):
870+
preserved_val = array_type._from_sequence(
871+
[result[name]], dtype=original_dtype
872+
)[0]
873+
result.loc[name] = preserved_val
874+
except Exception:
875+
# If dtype preservation fails, keep the computed value
876+
pass
877+
878+
return result
879+
880+
except Exception:
881+
return None
800882

801883
class NDFrameApply(Apply):
802884
"""

0 commit comments

Comments
 (0)