15
15
from __future__ import annotations
16
16
17
17
import typing
18
- from typing import Sequence , Union
18
+ from typing import Sequence , Tuple , Union
19
19
20
20
import bigframes_vendored .constants as constants
21
21
import bigframes_vendored .pandas .core .groupby as vendored_pandas_groupby
26
26
import bigframes .core as core
27
27
import bigframes .core .block_transforms as block_ops
28
28
import bigframes .core .blocks as blocks
29
+ import bigframes .core .expression
29
30
import bigframes .core .ordering as order
30
31
import bigframes .core .utils as utils
31
32
import bigframes .core .validations as validations
@@ -334,24 +335,19 @@ def agg(self, func=None, **kwargs) -> typing.Union[df.DataFrame, series.Series]:
334
335
return self ._agg_named (** kwargs )
335
336
336
337
def _agg_string (self , func : str ) -> df .DataFrame :
337
- aggregations = [
338
- (col_id , agg_ops .lookup_agg_func (func ))
339
- for col_id in self ._aggregated_columns ()
340
- ]
338
+ ids , labels = self ._aggregated_columns ()
339
+ aggregations = [agg (col_id , agg_ops .lookup_agg_func (func )) for col_id in ids ]
341
340
agg_block , _ = self ._block .aggregate (
342
341
by_column_ids = self ._by_col_ids ,
343
342
aggregations = aggregations ,
344
343
dropna = self ._dropna ,
344
+ column_labels = labels ,
345
345
)
346
346
dataframe = df .DataFrame (agg_block )
347
347
return dataframe if self ._as_index else self ._convert_index (dataframe )
348
348
349
349
def _agg_dict (self , func : typing .Mapping ) -> df .DataFrame :
350
- aggregations : typing .List [
351
- typing .Tuple [
352
- str , typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
353
- ]
354
- ] = []
350
+ aggregations : typing .List [bigframes .core .expression .Aggregation ] = []
355
351
column_labels = []
356
352
357
353
want_aggfunc_level = any (utils .is_list_like (aggs ) for aggs in func .values ())
@@ -362,7 +358,7 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
362
358
funcs_for_id if utils .is_list_like (funcs_for_id ) else [funcs_for_id ]
363
359
)
364
360
for f in func_list :
365
- aggregations .append ((col_id , agg_ops .lookup_agg_func (f )))
361
+ aggregations .append (agg (col_id , agg_ops .lookup_agg_func (f )))
366
362
column_labels .append (label )
367
363
agg_block , _ = self ._block .aggregate (
368
364
by_column_ids = self ._by_col_ids ,
@@ -373,7 +369,10 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
373
369
agg_block = agg_block .with_column_labels (
374
370
utils .combine_indices (
375
371
pd .Index (column_labels ),
376
- pd .Index (agg [1 ].name for agg in aggregations ),
372
+ pd .Index (
373
+ typing .cast (agg_ops .AggregateOp , agg .op ).name
374
+ for agg in aggregations
375
+ ),
377
376
)
378
377
)
379
378
else :
@@ -382,34 +381,21 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
382
381
return dataframe if self ._as_index else self ._convert_index (dataframe )
383
382
384
383
def _agg_list (self , func : typing .Sequence ) -> df .DataFrame :
384
+ ids , labels = self ._aggregated_columns ()
385
385
aggregations = [
386
- (col_id , agg_ops .lookup_agg_func (f ))
387
- for col_id in self ._aggregated_columns ()
388
- for f in func
386
+ agg (col_id , agg_ops .lookup_agg_func (f )) for col_id in ids for f in func
389
387
]
390
388
391
389
if self ._block .column_labels .nlevels > 1 :
392
390
# Restructure MultiIndex for proper format: (idx1, idx2, func)
393
391
# rather than ((idx1, idx2), func).
394
- aggregated_columns = pd .MultiIndex .from_tuples (
395
- [
396
- self ._block .col_id_to_label [col_id ]
397
- for col_id in self ._aggregated_columns ()
398
- ],
399
- names = [* self ._block .column_labels .names ],
400
- ).to_frame (index = False )
401
-
402
392
column_labels = [
403
- tuple (col_id ) + (f ,)
404
- for col_id in aggregated_columns .to_numpy ()
405
- for f in func
406
- ]
407
- else :
408
- column_labels = [
409
- (self ._block .col_id_to_label [col_id ], f )
410
- for col_id in self ._aggregated_columns ()
393
+ tuple (label ) + (f ,)
394
+ for label in labels .to_frame (index = False ).to_numpy ()
411
395
for f in func
412
396
]
397
+ else : # Single-level index
398
+ column_labels = [(label , f ) for label in labels for f in func ]
413
399
414
400
agg_block , _ = self ._block .aggregate (
415
401
by_column_ids = self ._by_col_ids ,
@@ -435,7 +421,7 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
435
421
if not isinstance (v , tuple ) or (len (v ) != 2 ):
436
422
raise TypeError ("kwargs values must be 2-tuples of column, aggfunc" )
437
423
col_id = self ._resolve_label (v [0 ])
438
- aggregations .append ((col_id , agg_ops .lookup_agg_func (v [1 ])))
424
+ aggregations .append (agg (col_id , agg_ops .lookup_agg_func (v [1 ])))
439
425
column_labels .append (k )
440
426
agg_block , _ = self ._block .aggregate (
441
427
by_column_ids = self ._by_col_ids ,
@@ -470,15 +456,19 @@ def _raise_on_non_numeric(self, op: str):
470
456
)
471
457
return self
472
458
473
- def _aggregated_columns (self , numeric_only : bool = False ) -> typing .Sequence [str ]:
459
+ def _aggregated_columns (
460
+ self , numeric_only : bool = False
461
+ ) -> Tuple [typing .Sequence [str ], pd .Index ]:
474
462
valid_agg_cols : list [str ] = []
475
- for col_id in self ._selected_cols :
463
+ offsets : list [int ] = []
464
+ for i , col_id in enumerate (self ._block .value_columns ):
476
465
is_numeric = (
477
466
self ._column_type (col_id ) in dtypes .NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
478
467
)
479
- if is_numeric or not numeric_only :
468
+ if (col_id in self ._selected_cols ) and (is_numeric or not numeric_only ):
469
+ offsets .append (i )
480
470
valid_agg_cols .append (col_id )
481
- return valid_agg_cols
471
+ return valid_agg_cols , self . _block . column_labels . take ( offsets )
482
472
483
473
def _column_type (self , col_id : str ) -> dtypes .Dtype :
484
474
col_offset = self ._block .value_columns .index (col_id )
@@ -488,11 +478,12 @@ def _column_type(self, col_id: str) -> dtypes.Dtype:
488
478
def _aggregate_all (
489
479
self , aggregate_op : agg_ops .UnaryAggregateOp , numeric_only : bool = False
490
480
) -> df .DataFrame :
491
- aggregated_col_ids = self ._aggregated_columns (numeric_only = numeric_only )
492
- aggregations = [(col_id , aggregate_op ) for col_id in aggregated_col_ids ]
481
+ aggregated_col_ids , labels = self ._aggregated_columns (numeric_only = numeric_only )
482
+ aggregations = [agg (col_id , aggregate_op ) for col_id in aggregated_col_ids ]
493
483
result_block , _ = self ._block .aggregate (
494
484
by_column_ids = self ._by_col_ids ,
495
485
aggregations = aggregations ,
486
+ column_labels = labels ,
496
487
dropna = self ._dropna ,
497
488
)
498
489
dataframe = df .DataFrame (result_block )
@@ -508,7 +499,7 @@ def _apply_window_op(
508
499
window_spec = window or window_specs .cumulative_rows (
509
500
grouping_keys = tuple (self ._by_col_ids )
510
501
)
511
- columns = self ._aggregated_columns (numeric_only = numeric_only )
502
+ columns , _ = self ._aggregated_columns (numeric_only = numeric_only )
512
503
block , result_ids = self ._block .multi_apply_window_op (
513
504
columns , op , window_spec = window_spec
514
505
)
@@ -639,11 +630,11 @@ def prod(self, *args) -> series.Series:
639
630
def agg (self , func = None ) -> typing .Union [df .DataFrame , series .Series ]:
640
631
column_names : list [str ] = []
641
632
if isinstance (func , str ):
642
- aggregations = [(self ._value_column , agg_ops .lookup_agg_func (func ))]
633
+ aggregations = [agg (self ._value_column , agg_ops .lookup_agg_func (func ))]
643
634
column_names = [func ]
644
635
elif utils .is_list_like (func ):
645
636
aggregations = [
646
- (self ._value_column , agg_ops .lookup_agg_func (f )) for f in func
637
+ agg (self ._value_column , agg_ops .lookup_agg_func (f )) for f in func
647
638
]
648
639
column_names = list (func )
649
640
else :
@@ -756,7 +747,7 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
756
747
def _aggregate (self , aggregate_op : agg_ops .UnaryAggregateOp ) -> series .Series :
757
748
result_block , _ = self ._block .aggregate (
758
749
self ._by_col_ids ,
759
- ((self ._value_column , aggregate_op ),),
750
+ (agg (self ._value_column , aggregate_op ),),
760
751
dropna = self ._dropna ,
761
752
)
762
753
@@ -781,3 +772,13 @@ def _apply_window_op(
781
772
window_spec = window_spec ,
782
773
)
783
774
return series .Series (block .select_column (result_id ))
775
+
776
+
777
+ def agg (input : str , op : agg_ops .AggregateOp ) -> bigframes .core .expression .Aggregation :
778
+ if isinstance (op , agg_ops .UnaryAggregateOp ):
779
+ return bigframes .core .expression .UnaryAggregation (
780
+ op , bigframes .core .expression .deref (input )
781
+ )
782
+ else :
783
+ assert isinstance (op , agg_ops .NullaryAggregateOp )
784
+ return bigframes .core .expression .NullaryAggregation (op )
0 commit comments