@@ -860,7 +860,7 @@ def filter(self, column_id: str, keep_null: bool = False):
860860
861861 def aggregate_all_and_stack (
862862 self ,
863- operation : agg_ops .AggregateOp ,
863+ operation : agg_ops .UnaryAggregateOp ,
864864 * ,
865865 axis : int | str = 0 ,
866866 value_col_id : str = "values" ,
@@ -872,7 +872,8 @@ def aggregate_all_and_stack(
872872 axis_n = utils .get_axis_number (axis )
873873 if axis_n == 0 :
874874 aggregations = [
875- (col_id , operation , col_id ) for col_id in self .value_columns
875+ (ex .UnaryAggregation (operation , ex .free_var (col_id )), col_id )
876+ for col_id in self .value_columns
876877 ]
877878 index_col_ids = [
878879 guid .generate_guid () for i in range (self .column_labels .nlevels )
@@ -902,10 +903,13 @@ def aggregate_all_and_stack(
902903 dtype = dtype ,
903904 )
904905 index_aggregations = [
905- (col_id , agg_ops .AnyValueOp (), col_id )
906 (ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
906907 for col_id in [* self .index_columns ]
907908 ]
908- main_aggregation = (value_col_id , operation , value_col_id )
909+ main_aggregation = (
910+ ex .UnaryAggregation (operation , ex .free_var (value_col_id )),
911+ value_col_id ,
912+ )
909913 result_expr = stacked_expr .aggregate (
910914 [* index_aggregations , main_aggregation ],
911915 by_column_ids = [offset_col ],
@@ -966,7 +970,7 @@ def remap_f(x):
966970 def aggregate (
967971 self ,
968972 by_column_ids : typing .Sequence [str ] = (),
969- aggregations : typing .Sequence [typing .Tuple [str , agg_ops .AggregateOp ]] = (),
973+ aggregations : typing .Sequence [typing .Tuple [str , agg_ops .UnaryAggregateOp ]] = (),
970974 * ,
971975 dropna : bool = True ,
972976 ) -> typing .Tuple [Block , typing .Sequence [str ]]:
@@ -979,10 +983,13 @@ def aggregate(
979983 dropna: whether null keys should be dropped
980984 """
981985 agg_specs = [
982- (input_id , operation , guid .generate_guid ())
986+ (
987+ ex .UnaryAggregation (operation , ex .free_var (input_id )),
988+ guid .generate_guid (),
989+ )
983990 for input_id , operation in aggregations
984991 ]
985- output_col_ids = [agg_spec [2 ] for agg_spec in agg_specs ]
992+ output_col_ids = [agg_spec [1 ] for agg_spec in agg_specs ]
986993 result_expr = self .expr .aggregate (agg_specs , by_column_ids , dropna = dropna )
987994
988995 aggregate_labels = self ._get_labels_for_columns (
@@ -1004,7 +1011,7 @@ def aggregate(
10041011 output_col_ids ,
10051012 )
10061013
1007- def get_stat (self , column_id : str , stat : agg_ops .AggregateOp ):
1014+ def get_stat (self , column_id : str , stat : agg_ops .UnaryAggregateOp ):
10081015 """Gets aggregates immediately, and caches it"""
10091016 if stat .name in self ._stats_cache [column_id ]:
10101017 return self ._stats_cache [column_id ][stat .name ]
@@ -1014,7 +1021,10 @@ def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
10141021 standard_stats = self ._standard_stats (column_id )
10151022 stats_to_fetch = standard_stats if stat in standard_stats else [stat ]
10161023
1017- aggregations = [(column_id , stat , stat .name ) for stat in stats_to_fetch ]
1024+ aggregations = [
1025+ (ex .UnaryAggregation (stat , ex .free_var (column_id )), stat .name )
1026+ for stat in stats_to_fetch
1027+ ]
10181028 expr = self .expr .aggregate (aggregations )
10191029 offset_index_id = guid .generate_guid ()
10201030 expr = expr .promote_offsets (offset_index_id )
@@ -1054,13 +1064,13 @@ def get_corr_stat(self, column_id_left: str, column_id_right: str):
10541064 def summarize (
10551065 self ,
10561066 column_ids : typing .Sequence [str ],
1057- stats : typing .Sequence [agg_ops .AggregateOp ],
1067+ stats : typing .Sequence [agg_ops .UnaryAggregateOp ],
10581068 ):
10591069 """Get a list of stats as a deferred block object."""
10601070 label_col_id = guid .generate_guid ()
10611071 labels = [stat .name for stat in stats ]
10621072 aggregations = [
1063- (col_id, stat, f"{col_id}-{stat.name}")
1073 (ex.UnaryAggregation(stat, ex.free_var(col_id)), f"{col_id}-{stat.name}")
10641074 for stat in stats
10651075 for col_id in column_ids
10661076 ]
@@ -1076,7 +1086,7 @@ def summarize(
10761086 labels = self ._get_labels_for_columns (column_ids )
10771087 return Block (expr , column_labels = labels , index_columns = [label_col_id ])
10781088
1079- def _standard_stats (self , column_id ) -> typing .Sequence [agg_ops .AggregateOp ]:
1089+ def _standard_stats (self , column_id ) -> typing .Sequence [agg_ops .UnaryAggregateOp ]:
10801090 """
10811091 Gets a standard set of stats to preemptively fetch for a column if
10821092 any other stat is fetched.
@@ -1087,7 +1097,7 @@ def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
10871097 """
10881098 # TODO: annotate aggregations themself with this information
10891099 dtype = self .expr .get_column_type (column_id )
1090- stats : list [agg_ops .AggregateOp ] = [agg_ops .count_op ]
1100+ stats : list [agg_ops .UnaryAggregateOp ] = [agg_ops .count_op ]
10911101 if dtype not in bigframes .dtypes .UNORDERED_DTYPES :
10921102 stats += [agg_ops .min_op , agg_ops .max_op ]
10931103 if dtype in bigframes .dtypes .NUMERIC_BIGFRAMES_TYPES_PERMISSIVE :
0 commit comments