Skip to content

Commit 9ac8135

Browse files
authored
refactor: implements compile_selection method (#1672)
1 parent 3388191 commit 9ac8135

File tree

4 files changed

+63
-7
lines changed

4 files changed

+63
-7
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020

2121
from google.cloud import bigquery
2222
import pyarrow as pa
23+
import sqlglot.expressions as sge
2324

2425
from bigframes.core import expression, identifiers, nodes, rewrite
2526
from bigframes.core.compile import configs
27+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2628
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2729
import bigframes.core.ordering as bf_ordering
2830

@@ -38,7 +40,7 @@ def compile(
3840
ordered: bool = True,
3941
limit: typing.Optional[int] = None,
4042
) -> str:
41-
"""Compile node into sql where rows are sorted with ORDER BY."""
43+
"""Compiles node into sql where rows are sorted with ORDER BY."""
4244
request = configs.CompileRequest(node, sort_rows=ordered, peek_count=limit)
4345
return self._compile_sql(request).sql
4446

@@ -48,7 +50,7 @@ def compile_raw(
4850
) -> typing.Tuple[
4951
str, typing.Sequence[bigquery.SchemaField], bf_ordering.RowOrdering
5052
]:
51-
"""Compile node into sql that exposes all columns, including hidden
53+
"""Compiles node into sql that exposes all columns, including hidden
5254
ordering-only columns."""
5355
request = configs.CompileRequest(
5456
node, sort_rows=False, materialize_all_order_keys=True
@@ -163,6 +165,9 @@ def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
163165

164166

165167
@_compile_node.register
166-
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR):
167-
# TODO: add support for selection
168-
return child
168+
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
169+
select_cols: typing.Dict[str, sge.Expression] = {
170+
id.name: scalar_compiler.compile_scalar_expression(expr)
171+
for expr, id in node.input_output_pairs
172+
}
173+
return child.select(select_cols)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import functools
17+
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core import expression
21+
22+
23+
@functools.singledispatch
24+
def compile_scalar_expression(
25+
expression: expression.Expression,
26+
) -> sge.Expression:
27+
"""Compiles BigFrames scalar expression into SQLGlot expression."""
28+
raise ValueError(f"Can't compile unrecognized node: {expression}")
29+
30+
31+
@compile_scalar_expression.register
32+
def compile_deref_op(expr: expression.DerefOp):
33+
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
class SQLGlotIR:
3131
"""Helper class to build SQLGlot Query and generate SQL string."""
3232

33-
expr: sge.Expression = sge.Expression()
33+
expr: sge.Select = sg.select()
3434
"""The SQLGlot expression representing the query."""
3535

3636
dialect = sqlglot.dialects.bigquery.BigQuery
@@ -90,6 +90,20 @@ def from_pandas(
9090
)
9191
return cls(expr=sg.select(sge.Star()).from_(expr))
9292

93+
def select(
94+
self,
95+
select_cols: typing.Dict[str, sge.Expression],
96+
) -> SQLGlotIR:
97+
selected_cols = [
98+
sge.Alias(
99+
this=expr,
100+
alias=sge.to_identifier(id, quoted=self.quoted),
101+
)
102+
for id, expr in select_cols.items()
103+
]
104+
expr = self.expr.select(*selected_cols, append=False)
105+
return SQLGlotIR(expr=expr)
106+
93107

94108
def _literal(value: typing.Any, dtype: str) -> sge.Expression:
95109
if value is None:
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
SELECT
2-
*
2+
`bfcol_0` AS `bfcol_5`,
3+
`bfcol_1` AS `bfcol_6`,
4+
`bfcol_2` AS `bfcol_7`,
5+
`bfcol_3` AS `bfcol_8`,
6+
`bfcol_4` AS `bfcol_9`
37
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` INT64, `bfcol_2` BOOLEAN, `bfcol_3` STRING, `bfcol_4` INT64>>[(1, -10, TRUE, 'b', 0), (2, 20, CAST(NULL AS BOOLEAN), 'aa', 1), (3, 30, FALSE, 'ccc', 2)])

0 commit comments

Comments
 (0)