Skip to content

Commit dabac32

Browse files
authored
refactor: reorganize the sqlglot scalar compiler layout - part 1 (#2075)
1 parent 5ce5d63 commit dabac32

File tree

13 files changed

+636
-479
lines changed

13 files changed

+636
-479
lines changed

bigframes/core/compile/sqlglot/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,5 +14,7 @@
1414
from __future__ import annotations
1515

1616
from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler
17+
import bigframes.core.compile.sqlglot.expressions.binary_compiler # noqa: F401
18+
import bigframes.core.compile.sqlglot.expressions.unary_compiler # noqa: F401
1719

1820
__all__ = ["SQLGlotCompiler"]

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def compile_aggregate(
3535
return nullary_compiler.compile(aggregate.op)
3636
if isinstance(aggregate, agg_expressions.UnaryAggregation):
3737
column = typed_expr.TypedExpr(
38-
scalar_compiler.compile_scalar_expression(aggregate.arg),
38+
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),
3939
aggregate.arg.output_type,
4040
)
4141
if not aggregate.op.order_independent:
@@ -46,11 +46,11 @@ def compile_aggregate(
4646
return unary_compiler.compile(aggregate.op, column)
4747
elif isinstance(aggregate, agg_expressions.BinaryAggregation):
4848
left = typed_expr.TypedExpr(
49-
scalar_compiler.compile_scalar_expression(aggregate.left),
49+
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.left),
5050
aggregate.left.output_type,
5151
)
5252
right = typed_expr.TypedExpr(
53-
scalar_compiler.compile_scalar_expression(aggregate.right),
53+
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.right),
5454
aggregate.right.output_type,
5555
)
5656
return binary_compiler.compile(aggregate.op, left, right)
@@ -66,7 +66,7 @@ def compile_analytic(
6666
return nullary_compiler.compile(aggregate.op)
6767
if isinstance(aggregate, agg_expressions.UnaryAggregation):
6868
column = typed_expr.TypedExpr(
69-
scalar_compiler.compile_scalar_expression(aggregate.arg),
69+
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),
7070
aggregate.arg.output_type,
7171
)
7272
return unary_compiler.compile(aggregate.op, column, window)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,10 @@ def apply_window_if_present(
5151
order = sge.Order(expressions=order_by) if order_by else None
5252

5353
group_by = (
54-
[scalar_compiler.compile_scalar_expression(key) for key in window.grouping_keys]
54+
[
55+
scalar_compiler.scalar_op_compiler.compile_expression(key)
56+
for key in window.grouping_keys
57+
]
5558
if window.grouping_keys
5659
else None
5760
)
@@ -101,7 +104,7 @@ def get_window_order_by(
101104

102105
order_by = []
103106
for ordering_spec_item in ordering:
104-
expr = scalar_compiler.compile_scalar_expression(
107+
expr = scalar_compiler.scalar_op_compiler.compile_expression(
105108
ordering_spec_item.scalar_expression
106109
)
107110
desc = not ordering_spec_item.direction.is_ascending

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -131,15 +131,15 @@ def _compile_result_node(self, root: nodes.ResultNode) -> str:
131131
# Have to bind schema as the final step before compilation.
132132
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
133133
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
134-
(name, scalar_compiler.compile_scalar_expression(ref))
134+
(name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
135135
for ref, name in root.output_cols
136136
)
137137
sqlglot_ir = self.compile_node(root.child).select(selected_cols)
138138

139139
if root.order_by is not None:
140140
ordering_cols = tuple(
141141
sge.Ordered(
142-
this=scalar_compiler.compile_scalar_expression(
142+
this=scalar_compiler.scalar_op_compiler.compile_expression(
143143
ordering.scalar_expression
144144
),
145145
desc=ordering.direction.is_ascending is False,
@@ -199,7 +199,7 @@ def compile_selection(
199199
self, node: nodes.SelectionNode, child: ir.SQLGlotIR
200200
) -> ir.SQLGlotIR:
201201
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
202-
(id.sql, scalar_compiler.compile_scalar_expression(expr))
202+
(id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr))
203203
for expr, id in node.input_output_pairs
204204
)
205205
return child.select(selected_cols)
@@ -209,7 +209,7 @@ def compile_projection(
209209
self, node: nodes.ProjectionNode, child: ir.SQLGlotIR
210210
) -> ir.SQLGlotIR:
211211
projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
212-
(id.sql, scalar_compiler.compile_scalar_expression(expr))
212+
(id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr))
213213
for expr, id in node.assignments
214214
)
215215
return child.project(projected_cols)
@@ -218,7 +218,9 @@ def compile_projection(
218218
def compile_filter(
219219
self, node: nodes.FilterNode, child: ir.SQLGlotIR
220220
) -> ir.SQLGlotIR:
221-
condition = scalar_compiler.compile_scalar_expression(node.predicate)
221+
condition = scalar_compiler.scalar_op_compiler.compile_expression(
222+
node.predicate
223+
)
222224
return child.filter(tuple([condition]))
223225

224226
@_compile_node.register
@@ -228,10 +230,12 @@ def compile_join(
228230
conditions = tuple(
229231
(
230232
typed_expr.TypedExpr(
231-
scalar_compiler.compile_scalar_expression(left), left.output_type
233+
scalar_compiler.scalar_op_compiler.compile_expression(left),
234+
left.output_type,
232235
),
233236
typed_expr.TypedExpr(
234-
scalar_compiler.compile_scalar_expression(right), right.output_type
237+
scalar_compiler.scalar_op_compiler.compile_expression(right),
238+
right.output_type,
235239
),
236240
)
237241
for left, right in node.conditions
@@ -250,11 +254,11 @@ def compile_isin_join(
250254
) -> ir.SQLGlotIR:
251255
conditions = (
252256
typed_expr.TypedExpr(
253-
scalar_compiler.compile_scalar_expression(node.left_col),
257+
scalar_compiler.scalar_op_compiler.compile_expression(node.left_col),
254258
node.left_col.output_type,
255259
),
256260
typed_expr.TypedExpr(
257-
scalar_compiler.compile_scalar_expression(node.right_col),
261+
scalar_compiler.scalar_op_compiler.compile_expression(node.right_col),
258262
node.right_col.output_type,
259263
),
260264
)
@@ -308,7 +312,7 @@ def compile_aggregate(
308312
for agg, id in node.aggregations
309313
)
310314
by_cols: tuple[sge.Expression, ...] = tuple(
311-
scalar_compiler.compile_scalar_expression(by_col)
315+
scalar_compiler.scalar_op_compiler.compile_expression(by_col)
312316
for by_col in node.by_column_ids
313317
)
314318

@@ -332,7 +336,9 @@ def compile_window(
332336
window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
333337

334338
inputs: tuple[sge.Expression, ...] = tuple(
335-
scalar_compiler.compile_scalar_expression(expression.DerefOp(column))
339+
scalar_compiler.scalar_op_compiler.compile_expression(
340+
expression.DerefOp(column)
341+
)
336342
for column in node.expression.column_references
337343
)
338344

bigframes/core/compile/sqlglot/expressions/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
"""Expression implementations for the SQLGlot-based compiler.
16+
17+
This directory structure should reflect the same layout as the
18+
`bigframes/operations` directory where the expressions are defined.
19+
20+
Prefer a few ops per file to keep file sizes manageable for text editors and LLMs.
21+
"""

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 31 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,16 @@
2020
from bigframes import dtypes
2121
from bigframes import operations as ops
2222
import bigframes.core.compile.sqlglot.expressions.constants as constants
23-
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2423
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
24+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2525

26-
BINARY_OP_REGISTRATION = OpRegistration()
26+
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op
2727

28-
29-
def compile(op: ops.BinaryOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
30-
return BINARY_OP_REGISTRATION[op](op, left, right)
28+
# TODO: add parenthesize for operators
3129

3230

33-
# TODO: add parenthesize for operators
34-
@BINARY_OP_REGISTRATION.register(ops.add_op)
35-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
31+
@register_binary_op(ops.add_op)
32+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
3633
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
3734
# String addition
3835
return sge.Concat(expressions=[left.expr, right.expr])
@@ -66,15 +63,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
6663
)
6764

6865

69-
@BINARY_OP_REGISTRATION.register(ops.eq_op)
70-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
66+
@register_binary_op(ops.eq_op)
67+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
7168
left_expr = _coerce_bool_to_int(left)
7269
right_expr = _coerce_bool_to_int(right)
7370
return sge.EQ(this=left_expr, expression=right_expr)
7471

7572

76-
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
77-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
73+
@register_binary_op(ops.eq_null_match_op)
74+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
7875
left_expr = left.expr
7976
if right.dtype != dtypes.BOOL_DTYPE:
8077
left_expr = _coerce_bool_to_int(left)
@@ -93,8 +90,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
9390
return sge.EQ(this=left_coalesce, expression=right_coalesce)
9491

9592

96-
@BINARY_OP_REGISTRATION.register(ops.div_op)
97-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
93+
@register_binary_op(ops.div_op)
94+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
9895
left_expr = _coerce_bool_to_int(left)
9996
right_expr = _coerce_bool_to_int(right)
10097

@@ -105,8 +102,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
105102
return result
106103

107104

108-
@BINARY_OP_REGISTRATION.register(ops.floordiv_op)
109-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
105+
@register_binary_op(ops.floordiv_op)
106+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
110107
left_expr = _coerce_bool_to_int(left)
111108
right_expr = _coerce_bool_to_int(right)
112109

@@ -138,41 +135,41 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
138135
return result
139136

140137

141-
@BINARY_OP_REGISTRATION.register(ops.ge_op)
142-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
138+
@register_binary_op(ops.ge_op)
139+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
143140
left_expr = _coerce_bool_to_int(left)
144141
right_expr = _coerce_bool_to_int(right)
145142
return sge.GTE(this=left_expr, expression=right_expr)
146143

147144

148-
@BINARY_OP_REGISTRATION.register(ops.gt_op)
149-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
145+
@register_binary_op(ops.gt_op)
146+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
150147
left_expr = _coerce_bool_to_int(left)
151148
right_expr = _coerce_bool_to_int(right)
152149
return sge.GT(this=left_expr, expression=right_expr)
153150

154151

155-
@BINARY_OP_REGISTRATION.register(ops.JSONSet)
156-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
152+
@register_binary_op(ops.JSONSet, pass_op=True)
153+
def _(left: TypedExpr, right: TypedExpr, op) -> sge.Expression:
157154
return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr)
158155

159156

160-
@BINARY_OP_REGISTRATION.register(ops.lt_op)
161-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
157+
@register_binary_op(ops.lt_op)
158+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
162159
left_expr = _coerce_bool_to_int(left)
163160
right_expr = _coerce_bool_to_int(right)
164161
return sge.LT(this=left_expr, expression=right_expr)
165162

166163

167-
@BINARY_OP_REGISTRATION.register(ops.le_op)
168-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
164+
@register_binary_op(ops.le_op)
165+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
169166
left_expr = _coerce_bool_to_int(left)
170167
right_expr = _coerce_bool_to_int(right)
171168
return sge.LTE(this=left_expr, expression=right_expr)
172169

173170

174-
@BINARY_OP_REGISTRATION.register(ops.mul_op)
175-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
171+
@register_binary_op(ops.mul_op)
172+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
176173
left_expr = _coerce_bool_to_int(left)
177174
right_expr = _coerce_bool_to_int(right)
178175

@@ -186,20 +183,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
186183
return result
187184

188185

189-
@BINARY_OP_REGISTRATION.register(ops.ne_op)
190-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
186+
@register_binary_op(ops.ne_op)
187+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
191188
left_expr = _coerce_bool_to_int(left)
192189
right_expr = _coerce_bool_to_int(right)
193190
return sge.NEQ(this=left_expr, expression=right_expr)
194191

195192

196-
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
197-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
193+
@register_binary_op(ops.obj_make_ref_op)
194+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
198195
return sge.func("OBJ.MAKE_REF", left.expr, right.expr)
199196

200197

201-
@BINARY_OP_REGISTRATION.register(ops.sub_op)
202-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
198+
@register_binary_op(ops.sub_op)
199+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
203200
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
204201
left_expr = _coerce_bool_to_int(left)
205202
right_expr = _coerce_bool_to_int(right)

bigframes/core/compile/sqlglot/expressions/nary_compiler.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

bigframes/core/compile/sqlglot/expressions/op_registration.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)