@@ -291,6 +291,8 @@ def agg(self) -> DataFrame | Series | None:
291
291
elif is_list_like (func ):
292
292
# we require a list, but not a 'str'
293
293
return self .agg_list_like ()
294
+ elif callable (func ):
295
+ return self .agg_callable ()
294
296
295
297
# caller can react
296
298
return None
@@ -797,6 +799,86 @@ def _apply_str(self, obj, func: str, *args, **kwargs):
797
799
msg = f"'{ func } ' is not a valid function for '{ type (obj ).__name__ } ' object"
798
800
raise AttributeError (msg )
799
801
802
def agg_callable(self) -> DataFrame | Series | None:
    """
    Compute aggregation in the case of a callable argument.

    One-dimensional objects are passed straight to the callable together
    with ``self.args`` and ``self.kwargs``.  For two-dimensional objects the
    callable is first handed to ``obj._reduce``; if that raises
    ``AttributeError`` or ``TypeError`` the column-wise fallback
    (``_agg_callable_fallback``) is used instead.

    Returns
    -------
    Result of aggregation, or None when the column-wise fallback could not
    compute one.
    """
    obj = self.obj
    func = self.func
    args = self.args
    kwargs = self.kwargs

    if obj.ndim == 1:
        return func(obj, *args, **kwargs)

    if args:
        # ``_reduce`` is only given keyword arguments below, so positional
        # arguments have to be bound to the callable up front; otherwise
        # they would be dropped on this path while being honoured on the
        # one-dimensional path above.
        def reducer(values, **reduce_kwargs):
            return func(values, *args, **reduce_kwargs)

    else:
        reducer = func

    # Try _reduce first (intended to preserve extension dtypes, as for
    # string aggregations); NOTE(review): confirm _reduce accepts arbitrary
    # callables for every block type.
    try:
        return obj._reduce(
            reducer,
            name=getattr(func, "__name__", "<lambda>"),
            axis=self.axis,
            skipna=True,
            numeric_only=False,
            **kwargs,
        )
    except (AttributeError, TypeError):
        # If _reduce fails, fall back to applying the callable column-wise.
        return self._agg_callable_fallback()
834
+
835
def _agg_callable_fallback(self) -> DataFrame | Series | None:
    """
    Fallback for callable aggregation when ``_reduce`` fails.

    For ``axis == 1`` the object is transposed and aggregated along axis 0.
    Otherwise the callable is applied to every column (with ``self.args``
    and ``self.kwargs``) and the per-column results are collected into a
    Series indexed by column label.  For columns whose dtype exposes
    ``construct_array_type``, the result is round-tripped through that
    array type on a best-effort basis.

    Returns
    -------
    Series of per-column results (or whatever the transposed aggregation
    returns for ``axis == 1``), or None if the column-wise computation
    raised, so that the caller can react.
    """
    obj = self.obj
    func = self.func

    if self.axis == 1:
        # Row-wise aggregation is column-wise aggregation of the transpose.
        # ``axis`` is passed positionally, ahead of ``self.args``, so that the
        # first extra positional argument cannot be bound to ``axis``.
        return obj.T.agg(func, 0, *self.args, **self.kwargs)

    from pandas import Series

    try:
        # Apply the callable to each column via the public ``items()``
        # iterator, which yields (label, column) pairs.
        results = {
            name: func(col, *self.args, **self.kwargs) for name, col in obj.items()
        }
        result = Series(results, name=None)

        # Preserve extension dtypes where possible.
        for name in result.index:
            if name not in obj.columns:
                continue
            original_dtype = obj.dtypes[name]
            if not hasattr(original_dtype, "construct_array_type"):
                continue
            try:
                array_type = original_dtype.construct_array_type()
                if hasattr(array_type, "_from_sequence"):
                    preserved_val = array_type._from_sequence(
                        [result[name]], dtype=original_dtype
                    )[0]
                    result.loc[name] = preserved_val
            except Exception:
                # Deliberately best-effort: if the value cannot be
                # represented in the original dtype, keep the computed value.
                pass

        return result

    except Exception:
        # Best-effort: signal "no result" so the caller can react instead
        # of propagating the failure.
        return None
800
882
801
883
class NDFrameApply (Apply ):
802
884
"""
0 commit comments