@@ -802,10 +802,10 @@ def _apply_str(self, obj, func: str, *args, **kwargs):
802
802
def agg_callable (self ) -> DataFrame | Series :
803
803
"""
804
804
Compute aggregation in the case of a callable argument.
805
-
805
+
806
806
This method handles callable functions while preserving extension dtypes
807
807
by delegating to the same infrastructure used for string aggregations.
808
-
808
+
809
809
Returns
810
810
-------
811
811
Result of aggregation.
@@ -815,68 +815,73 @@ def agg_callable(self) -> DataFrame | Series:
815
815
816
816
if obj .ndim == 1 :
817
817
return func (obj , * self .args , ** self .kwargs )
818
-
818
+
819
819
# Use _reduce to preserve extension dtypes like on string aggregation
820
820
try :
821
821
result = obj ._reduce (
822
- func ,
823
- name = getattr (func , ' __name__' , ' <lambda>' ),
822
+ func ,
823
+ name = getattr (func , " __name__" , " <lambda>" ),
824
824
axis = self .axis ,
825
825
skipna = True ,
826
826
numeric_only = False ,
827
827
** self .kwargs
828
828
)
829
829
return result
830
-
830
+
831
831
except (AttributeError , TypeError ):
832
832
# If _reduce fails, fallback to column-wise
833
833
return self ._agg_callable_fallback ()
834
834
835
835
def _agg_callable_fallback (self ) -> DataFrame | Series :
836
836
"""
837
837
Fallback method for callable aggregation when _reduce fails.
838
-
838
+
839
839
This method applies the function column-wise while preserving dtypes,
840
840
but avoids the performance overhead of row-by-row processing.
841
841
"""
842
842
obj = self .obj
843
843
func = self .func
844
-
844
+
845
845
if self .axis == 1 :
846
846
# For row-wise aggregation, transpose and recurse
847
- transposed_result = obj .T ._aggregate (func , axis = 0 , * self .args , ** self .kwargs )
847
+ transposed_result = obj .T ._aggregate (
848
+ func ,
849
+ * self .args ,
850
+ axis = 0 ,
851
+ ** self .kwargs
852
+ )
848
853
return transposed_result
849
-
854
+
850
855
from pandas import Series
851
-
856
+
852
857
try :
853
858
# Apply function to each column
854
859
results = {}
855
860
for name in obj .columns :
856
861
col = obj ._get_column_reference (name )
857
862
result_val = func (col , * self .args , ** self .kwargs )
858
863
results [name ] = result_val
859
-
864
+
860
865
result = Series (results , name = None )
861
-
866
+
862
867
# Preserve extension dtypes where possible
863
868
for name in result .index :
864
869
if name in obj .columns :
865
870
original_dtype = obj .dtypes [name ]
866
- if hasattr (original_dtype , ' construct_array_type' ):
871
+ if hasattr (original_dtype , " construct_array_type" ):
867
872
try :
868
873
array_type = original_dtype .construct_array_type ()
869
- if hasattr (array_type , ' _from_sequence' ):
874
+ if hasattr (array_type , " _from_sequence" ):
870
875
preserved_val = array_type ._from_sequence (
871
876
[result [name ]], dtype = original_dtype
872
877
)[0 ]
873
878
result .loc [name ] = preserved_val
874
879
except Exception :
875
880
# If dtype preservation fails, keep the computed value
876
881
pass
877
-
882
+
878
883
return result
879
-
884
+
880
885
except Exception :
881
886
return None
882
887
0 commit comments