Skip to content

Commit 752d332

Browse files
authored
refactor: add compile_window to the sqlglot compiler (#1889)
Fixes internal issue 430350912
1 parent 0bd5e1b commit 752d332

File tree

13 files changed

+428
-41
lines changed

13 files changed

+428
-41
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import sqlglot.expressions as sge
1717

18-
from bigframes.core import expression
18+
from bigframes.core import expression, window_spec
1919
from bigframes.core.compile.sqlglot.aggregations import (
2020
binary_compiler,
2121
nullary_compiler,
@@ -56,3 +56,21 @@ def compile_aggregate(
5656
return binary_compiler.compile(aggregate.op, left, right)
5757
else:
5858
raise ValueError(f"Unexpected aggregation: {aggregate}")
59+
60+
61+
def compile_analytic(
62+
aggregate: expression.Aggregation,
63+
window: window_spec.WindowSpec,
64+
) -> sge.Expression:
65+
if isinstance(aggregate, expression.NullaryAggregation):
66+
return nullary_compiler.compile(aggregate.op)
67+
if isinstance(aggregate, expression.UnaryAggregation):
68+
column = typed_expr.TypedExpr(
69+
scalar_compiler.compile_scalar_expression(aggregate.arg),
70+
aggregate.arg.output_type,
71+
)
72+
return unary_compiler.compile(aggregate.op, column, window)
73+
elif isinstance(aggregate, expression.BinaryAggregation):
74+
raise NotImplementedError("binary analytic operations not yet supported")
75+
else:
76+
raise ValueError(f"Unexpected analytic operation: {aggregate}")

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import sqlglot.expressions as sge
2020

21+
from bigframes import dtypes
2122
from bigframes.core import window_spec
2223
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2324
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
@@ -42,8 +43,11 @@ def _(
4243
column: typed_expr.TypedExpr,
4344
window: typing.Optional[window_spec.WindowSpec] = None,
4445
) -> sge.Expression:
46+
expr = column.expr
47+
if column.dtype == dtypes.BOOL_DTYPE:
48+
expr = sge.Cast(this=column.expr, to="INT64")
4549
# Will be null if all inputs are null. Pandas defaults to zero sum though.
46-
expr = apply_window_if_present(sge.func("SUM", column.expr), window)
50+
expr = apply_window_if_present(sge.func("SUM", expr), window)
4751
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
4852

4953

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,72 @@ def compile_aggregate(
298298

299299
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
300300

301+
@_compile_node.register
302+
def compile_window(
303+
self, node: nodes.WindowOpNode, child: ir.SQLGlotIR
304+
) -> ir.SQLGlotIR:
305+
window_spec = node.window_spec
306+
if node.expression.op.order_independent and window_spec.is_unbounded:
307+
# notably percentile_cont does not support ordering clause
308+
window_spec = window_spec.without_order()
309+
310+
window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
311+
312+
inputs: tuple[sge.Expression, ...] = tuple(
313+
scalar_compiler.compile_scalar_expression(expression.DerefOp(column))
314+
for column in node.expression.column_references
315+
)
316+
317+
clauses: list[tuple[sge.Expression, sge.Expression]] = []
318+
if node.expression.op.skips_nulls and not node.never_skip_nulls:
319+
for column in inputs:
320+
clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))
321+
322+
if window_spec.min_periods and len(inputs) > 0:
323+
if node.expression.op.skips_nulls:
324+
# Most operations do not count NULL values towards min_periods
325+
not_null_columns = [
326+
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
327+
for column in inputs
328+
]
329+
# All inputs must be non-null for observation to count
330+
if not not_null_columns:
331+
is_observation_expr: sge.Expression = sge.convert(True)
332+
else:
333+
is_observation_expr = not_null_columns[0]
334+
for expr in not_null_columns[1:]:
335+
is_observation_expr = sge.And(
336+
this=is_observation_expr, expression=expr
337+
)
338+
is_observation = ir._cast(is_observation_expr, "INT64")
339+
else:
340+
# Operations like count treat even NULLs as valid observations
341+
# for the sake of min_periods notnull is just used to convert
342+
# null values to non-null (FALSE) values to be counted.
343+
is_observation = ir._cast(
344+
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
345+
"INT64",
346+
)
347+
348+
observation_count = windows.apply_window_if_present(
349+
sge.func("SUM", is_observation), window_spec
350+
)
351+
clauses.append(
352+
(
353+
observation_count < sge.convert(window_spec.min_periods),
354+
sge.Null(),
355+
)
356+
)
357+
if clauses:
358+
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
359+
window_op = sge.Case(ifs=when_expressions, default=window_op)
360+
361+
# TODO: check if we can directly window the expression.
362+
return child.window(
363+
window_op=window_op,
364+
output_column_id=node.output_name.sql,
365+
)
366+
301367

302368
def _replace_unsupported_ops(node: nodes.BigFrameNode):
303369
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
@functools.singledispatch
3333
def compile_scalar_expression(
34-
expression: expression.Expression,
34+
expr: expression.Expression,
3535
) -> sge.Expression:
3636
"""Compiles BigFrames scalar expression into SQLGlot expression."""
3737
raise ValueError(f"Can't compile unrecognized node: {expression}")

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ def aggregate(
409409
new_expr = new_expr.where(condition, append=False)
410410
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
411411

412+
def window(
413+
self,
414+
window_op: sge.Expression,
415+
output_column_id: str,
416+
) -> SQLGlotIR:
417+
return self.project(((output_column_id, window_op),))
418+
412419
def insert(
413420
self,
414421
destination: bigquery.TableReference,

bigframes/core/rewrite/schema_binding.py

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from bigframes.core import bigframe_node
1919
from bigframes.core import expression as ex
20-
from bigframes.core import nodes
20+
from bigframes.core import nodes, ordering
2121

2222

2323
def bind_schema_to_tree(
@@ -79,46 +79,77 @@ def bind_schema_to_node(
7979
if isinstance(node, nodes.AggregateNode):
8080
aggregations = []
8181
for aggregation, id in node.aggregations:
82-
if isinstance(aggregation, ex.UnaryAggregation):
83-
replaced = typing.cast(
84-
ex.Aggregation,
85-
dataclasses.replace(
86-
aggregation,
87-
arg=typing.cast(
88-
ex.RefOrConstant,
89-
ex.bind_schema_fields(
90-
aggregation.arg, node.child.field_by_id
91-
),
92-
),
93-
),
82+
aggregations.append(
83+
(_bind_schema_to_aggregation_expr(aggregation, node.child), id)
84+
)
85+
86+
return dataclasses.replace(
87+
node,
88+
aggregations=tuple(aggregations),
89+
)
90+
91+
if isinstance(node, nodes.WindowOpNode):
92+
window_spec = dataclasses.replace(
93+
node.window_spec,
94+
grouping_keys=tuple(
95+
typing.cast(
96+
ex.DerefOp, ex.bind_schema_fields(expr, node.child.field_by_id)
9497
)
95-
aggregations.append((replaced, id))
96-
elif isinstance(aggregation, ex.BinaryAggregation):
97-
replaced = typing.cast(
98-
ex.Aggregation,
99-
dataclasses.replace(
100-
aggregation,
101-
left=typing.cast(
102-
ex.RefOrConstant,
103-
ex.bind_schema_fields(
104-
aggregation.left, node.child.field_by_id
105-
),
106-
),
107-
right=typing.cast(
108-
ex.RefOrConstant,
109-
ex.bind_schema_fields(
110-
aggregation.right, node.child.field_by_id
111-
),
112-
),
98+
for expr in node.window_spec.grouping_keys
99+
),
100+
ordering=tuple(
101+
ordering.OrderingExpression(
102+
scalar_expression=ex.bind_schema_fields(
103+
expr.scalar_expression, node.child.field_by_id
113104
),
105+
direction=expr.direction,
106+
na_last=expr.na_last,
114107
)
115-
aggregations.append((replaced, id))
116-
else:
117-
aggregations.append((aggregation, id))
118-
108+
for expr in node.window_spec.ordering
109+
),
110+
)
119111
return dataclasses.replace(
120112
node,
121-
aggregations=tuple(aggregations),
113+
expression=_bind_schema_to_aggregation_expr(node.expression, node.child),
114+
window_spec=window_spec,
122115
)
123116

124117
return node
118+
119+
120+
def _bind_schema_to_aggregation_expr(
121+
aggregation: ex.Aggregation,
122+
child: bigframe_node.BigFrameNode,
123+
) -> ex.Aggregation:
124+
assert isinstance(
125+
aggregation, ex.Aggregation
126+
), f"Expected Aggregation, got {type(aggregation)}"
127+
128+
if isinstance(aggregation, ex.UnaryAggregation):
129+
return typing.cast(
130+
ex.Aggregation,
131+
dataclasses.replace(
132+
aggregation,
133+
arg=typing.cast(
134+
ex.RefOrConstant,
135+
ex.bind_schema_fields(aggregation.arg, child.field_by_id),
136+
),
137+
),
138+
)
139+
elif isinstance(aggregation, ex.BinaryAggregation):
140+
return typing.cast(
141+
ex.Aggregation,
142+
dataclasses.replace(
143+
aggregation,
144+
left=typing.cast(
145+
ex.RefOrConstant,
146+
ex.bind_schema_fields(aggregation.left, child.field_by_id),
147+
),
148+
right=typing.cast(
149+
ex.RefOrConstant,
150+
ex.bind_schema_fields(aggregation.right, child.field_by_id),
151+
),
152+
),
153+
)
154+
else:
155+
return aggregation

bigframes/operations/aggregations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ def skips_nulls(self):
517517

518518
@dataclasses.dataclass(frozen=True)
519519
class DiffOp(UnaryWindowOp):
520+
name: ClassVar[str] = "diff"
520521
periods: int
521522

522523
@property

tests/system/small/engines/test_windowing.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.cloud import bigquery
1516
import pytest
1617

17-
from bigframes.core import array_value
18-
from bigframes.session import polars_executor
18+
from bigframes.core import array_value, expression, identifiers, nodes, window_spec
19+
import bigframes.operations.aggregations as agg_ops
20+
from bigframes.session import direct_gbq_execution, polars_executor
1921
from bigframes.testing.engine_utils import assert_equivalence_execution
2022

2123
pytest.importorskip("polars")
@@ -31,3 +33,28 @@ def test_engines_with_offsets(
3133
):
3234
result, _ = scalars_array_value.promote_offsets()
3335
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
36+
37+
38+
def test_engines_with_rows_window(
39+
scalars_array_value: array_value.ArrayValue,
40+
bigquery_client: bigquery.Client,
41+
):
42+
window = window_spec.WindowSpec(
43+
bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"),
44+
)
45+
window_node = nodes.WindowOpNode(
46+
child=scalars_array_value.node,
47+
expression=expression.UnaryAggregation(
48+
agg_ops.sum_op, expression.deref("int64_too")
49+
),
50+
window_spec=window,
51+
output_name=identifiers.ColumnId("sum_int64"),
52+
never_skip_nulls=False,
53+
skip_reproject_unsafe=False,
54+
)
55+
56+
bq_executor = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
57+
bq_sqlgot_executor = direct_gbq_execution.DirectGbqExecutor(
58+
bigquery_client, compiler="sqlglot"
59+
)
60+
assert_equivalence_execution(window_node, bq_executor, bq_sqlgot_executor)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_0` AS `bfcol_7`,
12+
`bfcol_1` AS `bfcol_8`,
13+
`bfcol_0` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*
18+
FROM `bfcte_1`
19+
WHERE
20+
NOT `bfcol_9` IS NULL
21+
), `bfcte_3` AS (
22+
SELECT
23+
*,
24+
CASE
25+
WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
26+
PARTITION BY `bfcol_9`
27+
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
28+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
29+
) < 3
30+
THEN NULL
31+
ELSE COALESCE(
32+
SUM(CAST(`bfcol_7` AS INT64)) OVER (
33+
PARTITION BY `bfcol_9`
34+
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
35+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
36+
),
37+
0
38+
)
39+
END AS `bfcol_15`
40+
FROM `bfcte_2`
41+
), `bfcte_4` AS (
42+
SELECT
43+
*
44+
FROM `bfcte_3`
45+
WHERE
46+
NOT `bfcol_9` IS NULL
47+
), `bfcte_5` AS (
48+
SELECT
49+
*,
50+
CASE
51+
WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
52+
PARTITION BY `bfcol_9`
53+
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
54+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
55+
) < 3
56+
THEN NULL
57+
ELSE COALESCE(
58+
SUM(`bfcol_8`) OVER (
59+
PARTITION BY `bfcol_9`
60+
ORDER BY `bfcol_9` IS NULL ASC NULLS LAST, `bfcol_9` ASC NULLS LAST, `bfcol_2` IS NULL ASC NULLS LAST, `bfcol_2` ASC NULLS LAST
61+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
62+
),
63+
0
64+
)
65+
END AS `bfcol_21`
66+
FROM `bfcte_4`
67+
)
68+
SELECT
69+
`bfcol_9` AS `bool_col`,
70+
`bfcol_6` AS `rowindex`,
71+
`bfcol_15` AS `bool_col_1`,
72+
`bfcol_21` AS `int64_col`
73+
FROM `bfcte_5`
74+
ORDER BY
75+
`bfcol_9` ASC NULLS LAST,
76+
`bfcol_2` ASC NULLS LAST

0 commit comments

Comments
 (0)