Skip to content

Commit 33ab2b8

Browse files
authored
chore: inject dtypes to SQLGlot scalar expr compiler (#1821)
* chore: inject dtypes to SQLGlot scalar expr compiler * fix format
1 parent 3abc02e commit 33ab2b8

File tree

8 files changed

+140
-48
lines changed

8 files changed

+140
-48
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,16 @@ def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
119119
return typing.cast(nodes.ResultNode, result_node)
120120

121121
def _compile_result_node(self, root: nodes.ResultNode) -> str:
122-
sqlglot_ir = self.compile_node(root.child)
123-
122+
# Have to bind schema as the final step before compilation.
123+
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
124124
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
125125
(name, scalar_compiler.compile_scalar_expression(ref))
126126
for ref, name in root.output_cols
127127
)
128128
# Skip squashing selections to ensure the right ordering and limit keys
129-
sqlglot_ir = sqlglot_ir.select(selected_cols, squash_selections=False)
129+
sqlglot_ir = self.compile_node(root.child).select(
130+
selected_cols, squash_selections=False
131+
)
130132

131133
if root.order_by is not None:
132134
ordering_cols = tuple(
@@ -220,6 +222,5 @@ def compile_concat(
220222

221223
def _replace_unsupported_ops(node: nodes.BigFrameNode):
222224
node = nodes.bottom_up(node, rewrite.rewrite_slice)
223-
node = nodes.bottom_up(node, schema_binding.bind_schema_to_expressions)
224225
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
225226
return node

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,25 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import dataclasses
1617
import functools
1718

1819
import sqlglot.expressions as sge
1920

21+
from bigframes import dtypes
2022
from bigframes.core import expression
2123
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2224
import bigframes.operations as ops
2325

2426

27+
@dataclasses.dataclass(frozen=True)
28+
class TypedExpr:
29+
"""SQLGlot expression with type."""
30+
31+
expr: sge.Expression
32+
dtype: dtypes.ExpressionType
33+
34+
2535
@functools.singledispatch
2636
def compile_scalar_expression(
2737
expression: expression.Expression,
@@ -50,9 +60,12 @@ def compile_constant_expression(
5060

5161

5262
@compile_scalar_expression.register
53-
def compile_op_expression(expr: expression.OpExpression):
63+
def compile_op_expression(expr: expression.OpExpression) -> sge.Expression:
5464
# Non-recursively compiles the children scalar expressions.
55-
args = tuple(map(compile_scalar_expression, expr.inputs))
65+
args = tuple(
66+
TypedExpr(compile_scalar_expression(input), input.output_type)
67+
for input in expr.inputs
68+
)
5669

5770
op = expr.op
5871
op_name = expr.op.__class__.__name__
@@ -79,8 +92,10 @@ def compile_op_expression(expr: expression.OpExpression):
7992

8093

8194
# TODO: add parenthesize for operators
82-
def compile_addop(
83-
op: ops.AddOp, left: sge.Expression, right: sge.Expression
84-
) -> sge.Expression:
85-
# TODO: support addop for string dtype.
86-
return sge.Add(this=left, expression=right)
95+
def compile_addop(op: ops.AddOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
96+
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
97+
# String addition
98+
return sge.Concat(expressions=[left.expr, right.expr])
99+
100+
# Numerical addition
101+
return sge.Add(this=left.expr, expression=right.expr)

bigframes/core/rewrite/schema_binding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
from bigframes.core import nodes
2020

2121

22-
def bind_schema_to_expressions(
22+
def bind_schema_to_tree(
23+
node: bigframe_node.BigFrameNode,
24+
) -> bigframe_node.BigFrameNode:
25+
return nodes.bottom_up(node, bind_schema_to_node)
26+
27+
28+
def bind_schema_to_node(
2329
node: bigframe_node.BigFrameNode,
2430
) -> bigframe_node.BigFrameNode:
2531
if isinstance(node, nodes.ProjectionNode):

tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_limit/out.sql

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,14 @@ WITH `bfcte_0` AS (
66
`float64_col` AS `bfcol_3`,
77
`bool_col` AS `bfcol_4`
88
FROM `test-project`.`test_dataset`.`test_table`
9-
), `bfcte_1` AS (
10-
SELECT
11-
*,
12-
`bfcol_1` AS `bfcol_5`
13-
FROM `bfcte_0`
149
)
1510
SELECT
1611
`bfcol_0` AS `rowindex`,
1712
`bfcol_1` AS `int64_col`,
1813
`bfcol_2` AS `string_col`,
1914
`bfcol_3` AS `float64_col`,
2015
`bfcol_4` AS `bool_col`
21-
FROM `bfcte_1`
16+
FROM `bfcte_0`
2217
ORDER BY
23-
`bfcol_5` ASC NULLS LAST
18+
`bfcol_1` ASC NULLS LAST
2419
LIMIT 10

tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_ordering/out.sql

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,13 @@ WITH `bfcte_0` AS (
66
`float64_col` AS `bfcol_3`,
77
`bool_col` AS `bfcol_4`
88
FROM `test-project`.`test_dataset`.`test_table`
9-
), `bfcte_1` AS (
10-
SELECT
11-
`bfcol_0` AS `bfcol_5`,
12-
`bfcol_1` AS `bfcol_6`,
13-
`bfcol_2` AS `bfcol_7`,
14-
`bfcol_3` AS `bfcol_8`,
15-
`bfcol_4` AS `bfcol_9`
16-
FROM `bfcte_0`
17-
), `bfcte_2` AS (
18-
SELECT
19-
*,
20-
`bfcol_5` AS `bfcol_10`
21-
FROM `bfcte_1`
22-
), `bfcte_3` AS (
23-
SELECT
24-
`bfcol_5` AS `bfcol_11`,
25-
`bfcol_6` AS `bfcol_12`,
26-
`bfcol_7` AS `bfcol_13`,
27-
`bfcol_8` AS `bfcol_14`,
28-
`bfcol_9` AS `bfcol_15`,
29-
`bfcol_10` AS `bfcol_16`
30-
FROM `bfcte_2`
319
)
3210
SELECT
33-
`bfcol_11` AS `rowindex`,
34-
`bfcol_12` AS `int64_col`,
35-
`bfcol_13` AS `string_col`,
36-
`bfcol_14` AS `float64_col`,
37-
`bfcol_15` AS `bool_col`
38-
FROM `bfcte_3`
11+
`bfcol_0` AS `rowindex`,
12+
`bfcol_1` AS `int64_col`,
13+
`bfcol_2` AS `string_col`,
14+
`bfcol_3` AS `float64_col`,
15+
`bfcol_4` AS `bool_col`
16+
FROM `bfcte_0`
3917
ORDER BY
40-
`bfcol_16` ASC NULLS LAST
18+
`bfcol_0` ASC NULLS LAST
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`string_col` AS `bfcol_2`,
6+
`float64_col` AS `bfcol_3`,
7+
`bool_col` AS `bfcol_4`
8+
FROM `test-project`.`test_dataset`.`test_table`
9+
), `bfcte_1` AS (
10+
SELECT
11+
*,
12+
`bfcol_0` AS `bfcol_5`,
13+
`bfcol_2` AS `bfcol_6`,
14+
`bfcol_3` AS `bfcol_7`,
15+
`bfcol_4` AS `bfcol_8`,
16+
`bfcol_1` + `bfcol_1` AS `bfcol_9`
17+
FROM `bfcte_0`
18+
), `bfcte_2` AS (
19+
SELECT
20+
`bfcol_5` AS `bfcol_10`,
21+
`bfcol_9` AS `bfcol_11`,
22+
`bfcol_6` AS `bfcol_12`,
23+
`bfcol_7` AS `bfcol_13`,
24+
`bfcol_8` AS `bfcol_14`
25+
FROM `bfcte_1`
26+
)
27+
SELECT
28+
`bfcol_10` AS `rowindex`,
29+
`bfcol_11` AS `int64_col`,
30+
`bfcol_12` AS `string_col`,
31+
`bfcol_13` AS `float64_col`,
32+
`bfcol_14` AS `bool_col`
33+
FROM `bfcte_2`
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`string_col` AS `bfcol_2`,
6+
`float64_col` AS `bfcol_3`,
7+
`bool_col` AS `bfcol_4`
8+
FROM `test-project`.`test_dataset`.`test_table`
9+
), `bfcte_1` AS (
10+
SELECT
11+
*,
12+
`bfcol_0` AS `bfcol_5`,
13+
`bfcol_1` AS `bfcol_6`,
14+
`bfcol_3` AS `bfcol_7`,
15+
`bfcol_4` AS `bfcol_8`,
16+
CONCAT(`bfcol_2`, 'a') AS `bfcol_9`
17+
FROM `bfcte_0`
18+
), `bfcte_2` AS (
19+
SELECT
20+
`bfcol_5` AS `bfcol_10`,
21+
`bfcol_6` AS `bfcol_11`,
22+
`bfcol_9` AS `bfcol_12`,
23+
`bfcol_7` AS `bfcol_13`,
24+
`bfcol_8` AS `bfcol_14`
25+
FROM `bfcte_1`
26+
)
27+
SELECT
28+
`bfcol_10` AS `rowindex`,
29+
`bfcol_11` AS `int64_col`,
30+
`bfcol_12` AS `string_col`,
31+
`bfcol_13` AS `float64_col`,
32+
`bfcol_14` AS `bool_col`
33+
FROM `bfcte_2`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes
18+
19+
pytest.importorskip("pytest_snapshot")
20+
21+
22+
def test_compile_numerical_add(compiler_session: bigframes.Session, snapshot):
23+
bf_df = compiler_session.read_gbq_table("test-project.test_dataset.test_table")
24+
bf_df["int64_col"] = bf_df["int64_col"] + bf_df["int64_col"]
25+
snapshot.assert_match(bf_df.sql, "out.sql")
26+
27+
28+
def test_compile_string_add(compiler_session: bigframes.Session, snapshot):
29+
bf_df = compiler_session.read_gbq_table("test-project.test_dataset.test_table")
30+
bf_df["string_col"] = bf_df["string_col"] + "a"
31+
snapshot.assert_match(bf_df.sql, "out.sql")

0 commit comments

Comments
 (0)