Skip to content

Commit 8f115e7

Browse files
authored
refactor: support to compile project and add_op (#1677)
1 parent 1a658b2 commit 8f115e7

File tree

6 files changed

+116
-2
lines changed

6 files changed

+116
-2
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ def compile_selection(
163163
)
164164
return child.select(selected_cols)
165165

166+
@_compile_node.register
167+
def compile_projection(
168+
self, node: nodes.ProjectionNode, child: ir.SQLGlotIR
169+
) -> ir.SQLGlotIR:
170+
projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
171+
(id.sql, scalar_compiler.compile_scalar_expression(expr))
172+
for expr, id in node.assignments
173+
)
174+
return child.project(projected_cols)
175+
166176

167177
def _replace_unsupported_ops(node: nodes.BigFrameNode):
168178
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import sqlglot.expressions as sge
1919

2020
from bigframes.core import expression
21+
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
22+
import bigframes.operations as ops
2123

2224

2325
@functools.singledispatch
@@ -29,5 +31,47 @@ def compile_scalar_expression(
2931

3032

3133
@compile_scalar_expression.register
32-
def compile_deref_op(expr: expression.DerefOp):
34+
def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
3335
return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
36+
37+
38+
@compile_scalar_expression.register
39+
def compile_constant_expression(
40+
expr: expression.ScalarConstantExpression,
41+
) -> sge.Expression:
42+
return ir._literal(expr.value, expr.dtype)
43+
44+
45+
@compile_scalar_expression.register
46+
def compile_op_expression(expr: expression.OpExpression):
47+
# Non-recursively compiles the children scalar expressions.
48+
args = tuple(map(compile_scalar_expression, expr.inputs))
49+
50+
op = expr.op
51+
op_name = expr.op.__class__.__name__
52+
method_name = f"compile_{op_name.lower()}"
53+
method = globals().get(method_name, None)
54+
if method is None:
55+
raise ValueError(
56+
f"Compilation method '{method_name}' not found for operator '{op_name}'."
57+
)
58+
59+
if isinstance(op, ops.UnaryOp):
60+
return method(op, args[0])
61+
elif isinstance(op, ops.BinaryOp):
62+
return method(op, args[0], args[1])
63+
elif isinstance(op, ops.TernaryOp):
64+
return method(op, args[0], args[1], args[2])
65+
elif isinstance(op, ops.NaryOp):
66+
return method(op, *args)
67+
else:
68+
raise TypeError(
69+
f"Operator '{op_name}' has an unrecognized arity or type "
70+
"and cannot be compiled."
71+
)
72+
73+
74+
# TODO: add parenthesize for operators
75+
def compile_addop(op: ops.AddOp, left: sge.Expression, right: sge.Expression):
76+
# TODO: support addop for string dtype.
77+
return sge.Add(this=left, expression=right)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,21 @@ def select(
118118
new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False)
119119
return SQLGlotIR(expr=new_expr)
120120

121+
def project(
122+
self,
123+
projected_cols: tuple[tuple[str, sge.Expression], ...],
124+
) -> SQLGlotIR:
125+
projected_cols_expr = [
126+
sge.Alias(
127+
this=expr,
128+
alias=sge.to_identifier(id, quoted=self.quoted),
129+
)
130+
for id, expr in projected_cols
131+
]
132+
# TODO: some columns are not able to be projected into the same select.
133+
select_expr = self.expr.select(*projected_cols_expr, append=True)
134+
return SQLGlotIR(expr=select_expr)
135+
121136
def _encapsulate_as_cte(
122137
self,
123138
) -> sge.Select:

bigframes/dataframe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,9 @@ def __getitem__(
609609
def _getitem_label(self, key: blocks.Label):
610610
col_ids = self._block.cols_matching_label(key)
611611
if len(col_ids) == 0:
612-
raise KeyError(key)
612+
raise KeyError(
613+
f"{key} not found in DataFrame columns: {self._block.column_labels}"
614+
)
613615
block = self._block.select_columns(col_ids)
614616
if isinstance(self.columns, pandas.MultiIndex):
615617
# Multiindex should drop-level if not selecting entire
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
*,
4+
`bfcol_0` AS `bfcol_3`,
5+
`bfcol_1` + 1 AS `bfcol_4`
6+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` INT64, `bfcol_2` INT64>>[STRUCT(0, 123456789, 0), STRUCT(1, -987654321, 1), STRUCT(2, 314159, 2), STRUCT(3, CAST(NULL AS INT64), 3), STRUCT(4, -234892, 4), STRUCT(5, 55555, 5), STRUCT(6, 101202303, 6), STRUCT(7, -214748367, 7), STRUCT(8, 2, 8)])
7+
)
8+
SELECT
9+
`bfcol_3` AS `bfcol_5`,
10+
`bfcol_4` AS `bfcol_6`,
11+
`bfcol_2` AS `bfcol_7`
12+
FROM `bfcte_0`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pytest
17+
18+
import bigframes
19+
import bigframes.pandas as bpd
20+
21+
pytest.importorskip("pytest_snapshot")
22+
23+
24+
def test_compile_projection(
25+
scalars_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot
26+
):
27+
bf_df = bpd.DataFrame(
28+
scalars_types_pandas_df[["int64_col"]], session=compiler_session
29+
)
30+
bf_df["int64_col"] = bf_df["int64_col"] + 1
31+
snapshot.assert_match(bf_df.sql, "out.sql")

0 commit comments

Comments
 (0)