@@ -802,10 +802,10 @@ def _apply_str(self, obj, func: str, *args, **kwargs):
802802 def agg_callable (self ) -> DataFrame | Series :
803803 """
804804 Compute aggregation in the case of a callable argument.
805-
805+
806806 This method handles callable functions while preserving extension dtypes
807807 by delegating to the same infrastructure used for string aggregations.
808-
808+
809809 Returns
810810 -------
811811 Result of aggregation.
@@ -815,68 +815,73 @@ def agg_callable(self) -> DataFrame | Series:
815815
816816 if obj .ndim == 1 :
817817 return func (obj , * self .args , ** self .kwargs )
818-
818+
819819 # Use _reduce to preserve extension dtypes like on string aggregation
820820 try :
821821 result = obj ._reduce (
822- func ,
823- name = getattr (func , ' __name__' , ' <lambda>' ),
822+ func ,
823+ name = getattr (func , " __name__" , " <lambda>" ),
824824 axis = self .axis ,
825825 skipna = True ,
826826 numeric_only = False ,
827827 ** self .kwargs
828828 )
829829 return result
830-
830+
831831 except (AttributeError , TypeError ):
832832 # If _reduce fails, fallback to column-wise
833833 return self ._agg_callable_fallback ()
834834
835835 def _agg_callable_fallback (self ) -> DataFrame | Series :
836836 """
837837 Fallback method for callable aggregation when _reduce fails.
838-
838+
839839 This method applies the function column-wise while preserving dtypes,
840840 but avoids the performance overhead of row-by-row processing.
841841 """
842842 obj = self .obj
843843 func = self .func
844-
844+
845845 if self .axis == 1 :
846846 # For row-wise aggregation, transpose and recurse
847- transposed_result = obj .T ._aggregate (func , axis = 0 , * self .args , ** self .kwargs )
847+ transposed_result = obj .T ._aggregate (
848+ func ,
849+ * self .args ,
850+ axis = 0 ,
851+ ** self .kwargs
852+ )
848853 return transposed_result
849-
854+
850855 from pandas import Series
851-
856+
852857 try :
853858 # Apply function to each column
854859 results = {}
855860 for name in obj .columns :
856861 col = obj ._get_column_reference (name )
857862 result_val = func (col , * self .args , ** self .kwargs )
858863 results [name ] = result_val
859-
864+
860865 result = Series (results , name = None )
861-
866+
862867 # Preserve extension dtypes where possible
863868 for name in result .index :
864869 if name in obj .columns :
865870 original_dtype = obj .dtypes [name ]
866- if hasattr (original_dtype , ' construct_array_type' ):
871+ if hasattr (original_dtype , " construct_array_type" ):
867872 try :
868873 array_type = original_dtype .construct_array_type ()
869- if hasattr (array_type , ' _from_sequence' ):
874+ if hasattr (array_type , " _from_sequence" ):
870875 preserved_val = array_type ._from_sequence (
871876 [result [name ]], dtype = original_dtype
872877 )[0 ]
873878 result .loc [name ] = preserved_val
874879 except Exception :
875880 # If dtype preservation fails, keep the computed value
876881 pass
877-
882+
878883 return result
879-
884+
880885 except Exception :
881886 return None
882887
0 commit comments