Skip to content

Commit c644df8

Browse files
authored
refactor: add compile_aggregate (#1904)
* refactor: add compile_aggregate
* resolve aggregation nodes for dtype and support dropna
* generate more compact aggregation SQL
1 parent f377cf6 commit c644df8

File tree

14 files changed

+354
-24
lines changed

14 files changed

+354
-24
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import functools
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import expression, window_spec
22+
from bigframes.core.compile.sqlglot.expressions import typed_expr
23+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
24+
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
25+
import bigframes.operations as ops
26+
27+
28+
def compile_aggregate(
    aggregate: expression.Aggregation,
    order_by: tuple[sge.Expression, ...],
) -> sge.Expression:
    """Translate a BigFrames aggregation expression into a SQLGlot expression.

    The aggregation is routed by arity to the matching single-dispatch
    compiler: nullary, unary (ordered or order-independent) or binary.

    Args:
        aggregate: The aggregation expression to compile.
        order_by: Ordering expressions; only forwarded to unary aggregations
            whose op is not order independent.

    Returns:
        The SQLGlot expression produced by the selected compiler.

    Raises:
        ValueError: If ``aggregate`` is not a nullary, unary or binary
            aggregation.
    """
    if isinstance(aggregate, expression.NullaryAggregation):
        return compile_nullary_agg(aggregate.op)

    if isinstance(aggregate, expression.UnaryAggregation):
        operand = typed_expr.TypedExpr(
            scalar_compiler.compile_scalar_expression(aggregate.arg),
            aggregate.arg.output_type,
        )
        if aggregate.op.order_independent:
            return compile_unary_agg(aggregate.op, operand)
        return compile_ordered_unary_agg(aggregate.op, operand, order_by=order_by)

    if isinstance(aggregate, expression.BinaryAggregation):
        # Compile the left operand first, then the right one.
        lhs, rhs = (
            typed_expr.TypedExpr(
                scalar_compiler.compile_scalar_expression(side),
                side.output_type,
            )
            for side in (aggregate.left, aggregate.right)
        )
        return compile_binary_agg(aggregate.op, lhs, rhs)

    raise ValueError(f"Unexpected aggregation: {aggregate}")
56+
57+
58+
@functools.singledispatch
def compile_nullary_agg(
    op: ops.aggregations.WindowOp,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback for nullary aggregation ops with no registered compiler.

    Concrete ops are expected to be registered on this single-dispatch
    function; reaching this body means none matched ``type(op)``.

    Raises:
        ValueError: Always, naming the unrecognized operation.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)
64+
65+
66+
@functools.singledispatch
def compile_binary_agg(
    op: ops.aggregations.WindowOp,
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback for binary aggregation ops with no registered compiler.

    Concrete ops are expected to be registered on this single-dispatch
    function; reaching this body means none matched ``type(op)``.

    Raises:
        ValueError: Always, naming the unrecognized operation.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)
74+
75+
76+
@functools.singledispatch
def compile_unary_agg(
    op: ops.aggregations.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback for order-independent unary ops with no registered compiler.

    Concrete ops are expected to be registered on this single-dispatch
    function; reaching this body means none matched ``type(op)``.

    Raises:
        ValueError: Always, naming the unrecognized operation.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)
83+
84+
85+
@functools.singledispatch
def compile_ordered_unary_agg(
    op: ops.aggregations.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
    # Immutable default: a list literal here would be one shared object that
    # any registered implementation could mutate across calls.
    order_by: typing.Sequence[sge.Expression] = (),
) -> sge.Expression:
    """Fallback for order-dependent unary ops with no registered compiler.

    Concrete ops are expected to be registered on this single-dispatch
    function; reaching this body means none matched ``type(op)``.

    Args:
        op: The aggregation operation to compile.
        column: The typed input column expression.
        window: Optional window specification.
        order_by: Ordering expressions for the aggregation; defaults to an
            empty sequence.

    Raises:
        ValueError: Always, naming the unrecognized operation.
    """
    raise ValueError(f"Can't compile unrecognized operation: {op}")
93+
94+
95+
@compile_unary_agg.register
def _(
    op: ops.aggregations.SumOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``SumOp`` as ``IFNULL(SUM(column), 0)``.

    SUM is null when every input is null, whereas pandas defaults an
    all-null sum to zero, so the result is wrapped with a zero literal of
    the column's dtype.
    """
    total = _apply_window_if_present(sge.func("SUM", column.expr), window)
    zero = ir._literal(0, column.dtype)
    return sge.func("IFNULL", total, zero)
104+
105+
106+
def _apply_window_if_present(
    value: sge.Expression,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Return ``value`` unchanged when no window is given.

    Raises:
        NotImplementedError: If ``window`` is not ``None``; applying a window
            is not supported by this helper.
    """
    if window is None:
        return value
    raise NotImplementedError("Can't apply window to the expression.")

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
25+
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
2526
from bigframes.core.compile.sqlglot.expressions import typed_expr
2627
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2728
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -217,7 +218,7 @@ def compile_filter(
217218
self, node: nodes.FilterNode, child: ir.SQLGlotIR
218219
) -> ir.SQLGlotIR:
219220
condition = scalar_compiler.compile_scalar_expression(node.predicate)
220-
return child.filter(condition)
221+
return child.filter(tuple([condition]))
221222

222223
@_compile_node.register
223224
def compile_join(
@@ -267,6 +268,37 @@ def compile_random_sample(
267268
) -> ir.SQLGlotIR:
268269
return child.sample(node.fraction)
269270

271+
@_compile_node.register
def compile_aggregate(
    self, node: nodes.AggregateNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compile an ``AggregateNode`` on top of the already-compiled child IR.

    Builds the ordering expressions, the (output id, aggregation expression)
    pairs and the group-by key expressions, then delegates to
    ``child.aggregate``. When ``node.dropna`` is set, only keys whose field
    in the child schema is nullable are passed as dropna columns.
    """
    ordering_cols = tuple(
        sge.Ordered(
            this=scalar_compiler.compile_scalar_expression(
                order_expr.scalar_expression
            ),
            desc=order_expr.direction.is_ascending is False,
            nulls_first=order_expr.na_last is False,
        )
        for order_expr in node.order_by
    )
    aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
        (
            out_id.sql,
            aggregate_compiler.compile_aggregate(agg_expr, order_by=ordering_cols),
        )
        for agg_expr, out_id in node.aggregations
    )
    by_cols: tuple[sge.Expression, ...] = tuple(
        scalar_compiler.compile_scalar_expression(key) for key in node.by_column_ids
    )

    dropna_cols: tuple[sge.Expression, ...] = ()
    if node.dropna:
        # Non-nullable keys can never be NULL, so they need no filter.
        dropna_cols = tuple(
            compiled_key
            for key, compiled_key in zip(node.by_column_ids, by_cols)
            if node.child.field_by_id[key.id].nullable
        )

    return child.aggregate(aggregations, by_cols, dropna_cols)
301+
270302

271303
def _replace_unsupported_ops(node: nodes.BigFrameNode):
272304
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import dataclasses
18+
import functools
1819
import typing
1920

2021
from google.cloud import bigquery
@@ -25,11 +26,9 @@
2526
import sqlglot.expressions as sge
2627

2728
from bigframes import dtypes
28-
from bigframes.core import guid, utils
29+
from bigframes.core import guid, local_data, schema, utils
2930
from bigframes.core.compile.sqlglot.expressions import typed_expr
3031
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
31-
import bigframes.core.local_data as local_data
32-
import bigframes.core.schema as bf_schema
3332

3433
# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
3534
try:
@@ -68,7 +67,7 @@ def sql(self) -> str:
6867
def from_pyarrow(
6968
cls,
7069
pa_table: pa.Table,
71-
schema: bf_schema.ArraySchema,
70+
schema: schema.ArraySchema,
7271
uid_gen: guid.SequentialUIDGenerator,
7372
) -> SQLGlotIR:
7473
"""Builds SQLGlot expression from a pyarrow table.
@@ -280,9 +279,13 @@ def limit(
280279

281280
def filter(
282281
self,
283-
condition: sge.Expression,
282+
conditions: tuple[sge.Expression, ...],
284283
) -> SQLGlotIR:
285284
"""Filters the query by adding a WHERE clause."""
285+
condition = _and(conditions)
286+
if condition is None:
287+
return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)
288+
286289
new_expr = _select_to_cte(
287290
self.expr,
288291
sge.to_identifier(
@@ -316,10 +319,11 @@ def join(
316319
right_ctes = right_select.args.pop("with", [])
317320
merged_ctes = [*left_ctes, *right_ctes]
318321

319-
join_conditions = [
320-
_join_condition(left, right, joins_nulls) for left, right in conditions
321-
]
322-
join_on = sge.And(expressions=join_conditions) if join_conditions else None
322+
join_on = _and(
323+
tuple(
324+
_join_condition(left, right, joins_nulls) for left, right in conditions
325+
)
326+
)
323327

324328
join_type_str = join_type if join_type != "outer" else "full outer"
325329
new_expr = (
@@ -364,6 +368,47 @@ def sample(self, fraction: float) -> SQLGlotIR:
364368
).where(condition, append=False)
365369
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
366370

371+
def aggregate(
    self,
    aggregations: tuple[tuple[str, sge.Expression], ...],
    by_cols: tuple[sge.Expression, ...],
    dropna_cols: tuple[sge.Expression, ...],
) -> SQLGlotIR:
    """Applies the aggregation expressions.

    Args:
        aggregations: (output_column_id, aggregation_expr) pairs; each
            expression is aliased to its output column id.
        by_cols: column expressions to group by; they are also selected
            ahead of the aggregations.
        dropna_cols: key expressions that must be non-NULL; rows where any
            of them is NULL are excluded through a WHERE clause.
    """
    aliased_aggs = [
        sge.Alias(
            this=agg_expr,
            alias=sge.to_identifier(col_id, quoted=self.quoted),
        )
        for col_id, agg_expr in aggregations
    ]

    cte_name = sge.to_identifier(
        next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
    )
    grouped = _select_to_cte(self.expr, cte_name)
    grouped = grouped.group_by(*by_cols).select(
        *by_cols, *aliased_aggs, append=False
    )

    not_null_checks = tuple(
        sg.not_(sge.Is(this=key, expression=sge.Null())) for key in dropna_cols
    )
    keep_rows = _and(not_null_checks)
    if keep_rows is not None:
        grouped = grouped.where(keep_rows, append=False)
    return SQLGlotIR(expr=grouped, uid_gen=self.uid_gen)
411+
367412
def insert(
368413
self,
369414
destination: bigquery.TableReference,
@@ -552,6 +597,16 @@ def _table(table: bigquery.TableReference) -> sge.Table:
552597
)
553598

554599

600+
def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]:
    """Chains multiple expressions together using a logical AND.

    Returns ``None`` for an empty tuple and the sole expression itself for a
    single-element tuple; otherwise a left-nested chain of ``sge.And`` nodes.
    """
    if not conditions:
        return None

    first, *remaining = conditions
    combined = first
    for condition in remaining:
        combined = sge.And(this=combined, expression=condition)
    return combined
608+
609+
555610
def _join_condition(
556611
left: typed_expr.TypedExpr,
557612
right: typed_expr.TypedExpr,

bigframes/core/rewrite/schema_binding.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import typing
1617

1718
from bigframes.core import bigframe_node
1819
from bigframes.core import expression as ex
@@ -65,4 +66,49 @@ def bind_schema_to_node(
6566
conditions=conditions,
6667
)
6768

69+
if isinstance(node, nodes.AggregateNode):
70+
aggregations = []
71+
for aggregation, id in node.aggregations:
72+
if isinstance(aggregation, ex.UnaryAggregation):
73+
replaced = typing.cast(
74+
ex.Aggregation,
75+
dataclasses.replace(
76+
aggregation,
77+
arg=typing.cast(
78+
ex.RefOrConstant,
79+
ex.bind_schema_fields(
80+
aggregation.arg, node.child.field_by_id
81+
),
82+
),
83+
),
84+
)
85+
aggregations.append((replaced, id))
86+
elif isinstance(aggregation, ex.BinaryAggregation):
87+
replaced = typing.cast(
88+
ex.Aggregation,
89+
dataclasses.replace(
90+
aggregation,
91+
left=typing.cast(
92+
ex.RefOrConstant,
93+
ex.bind_schema_fields(
94+
aggregation.left, node.child.field_by_id
95+
),
96+
),
97+
right=typing.cast(
98+
ex.RefOrConstant,
99+
ex.bind_schema_fields(
100+
aggregation.right, node.child.field_by_id
101+
),
102+
),
103+
),
104+
)
105+
aggregations.append((replaced, id))
106+
else:
107+
aggregations.append((aggregation, id))
108+
109+
return dataclasses.replace(
110+
node,
111+
aggregations=tuple(aggregations),
112+
)
113+
68114
return node
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_too` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_1` AS `bfcol_2`,
10+
`bfcol_0` AS `bfcol_3`
11+
FROM `bfcte_0`
12+
), `bfcte_2` AS (
13+
SELECT
14+
`bfcol_3`,
15+
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
16+
FROM `bfcte_1`
17+
WHERE
18+
NOT `bfcol_3` IS NULL
19+
GROUP BY
20+
`bfcol_3`
21+
)
22+
SELECT
23+
`bfcol_3` AS `bool_col`,
24+
`bfcol_6` AS `int64_too`
25+
FROM `bfcte_2`
26+
ORDER BY
27+
`bfcol_3` ASC NULLS LAST

0 commit comments

Comments
 (0)