Skip to content

Commit 0709f17

Browse files
authored
refactor: provide infrastructure for SQLGlot scalar compiler (#1850)
* refactor: provide infrastructure for SQLGlot scalar compiler * remove redundant code * remove redundant code * add TODO back
1 parent c706759 commit 0709f17

File tree

8 files changed

+244
-41
lines changed

8 files changed

+244
-41
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import dtypes
22+
from bigframes import operations as ops
23+
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
24+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
25+
26+
# Signature of a binary-op compiler: (op, left operand, right operand) -> SQLGlot expression.
BinaryOpCompiler = typing.Callable[[ops.BinaryOp, TypedExpr, TypedExpr], sge.Expression]

# Registry of the compilers for binary operators.
BINARY_OP_REGISTRATION = OpRegistration[BinaryOpCompiler]()
# Backward-compatible alias: the registry was first published under this
# misspelled name, and existing decorators and callers still reference it.
BINARY_OP_REIGSTRATION = BINARY_OP_REGISTRATION
29+
30+
31+
def compile(op: ops.BinaryOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile a binary operator applied to two typed operands.

    Args:
        op: The binary operator to compile.
        left: The already-compiled left operand.
        right: The already-compiled right operand.

    Returns:
        The SQLGlot expression produced by the compiler registered for ``op``.
    """
    # Resolve the registered compiler first, then hand it the operator and operands.
    compiler = BINARY_OP_REIGSTRATION[op]
    return compiler(op, left, right)
33+
34+
35+
# TODO: add parenthesization for operators
36+
@BINARY_OP_REIGSTRATION.register(ops.add_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``add_op``: concatenation for two strings, ``Add`` otherwise.

    Args:
        op: The operator being compiled (unused here).
        left: The already-compiled left operand.
        right: The already-compiled right operand.

    Returns:
        ``sge.Concat`` when both operands have the string dtype, else ``sge.Add``.
    """
    is_string_concat = (
        left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE
    )
    if not is_string_concat:
        # Every other dtype combination is emitted as a plain addition.
        return sge.Add(this=left.expr, expression=right.expr)

    # String + string is expressed as a concatenation.
    return sge.Concat(expressions=[left.expr, right.expr])
44+
45+
46+
@BINARY_OP_REIGSTRATION.register(ops.ge_op)
def compile_ge(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ge_op`` into a SQLGlot ``GTE`` node.

    Args:
        op: The operator being compiled (unused here).
        left: The already-compiled left operand.
        right: The already-compiled right operand.

    Returns:
        An ``sge.GTE`` expression comparing the two operand expressions.
    """
    lhs, rhs = left.expr, right.expr
    return sge.GTE(this=lhs, expression=rhs)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import operations as ops
22+
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
23+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
24+
25+
# No simpler way to specify that the compilation function expects varargs.
NaryOpCompiler = typing.Callable[..., sge.Expression]

# Registry of the compilers for n-ary operators.
NARY_OP_REGISTRATION = OpRegistration[NaryOpCompiler]()
# Backward-compatible alias: the registry was first published under this
# misspelled name, and existing callers still reference it.
NARY_OP_REIGSTRATION = NARY_OP_REGISTRATION
29+
30+
31+
def compile(op: ops.NaryOp, *args: TypedExpr) -> sge.Expression:
    """Compile an n-ary operator applied to any number of typed operands.

    Args:
        op: The n-ary operator to compile.
        *args: The already-compiled operands, in order.

    Returns:
        The SQLGlot expression produced by the compiler registered for ``op``.
    """
    # Resolve the registered compiler first, then forward the operands unchanged.
    compiler = NARY_OP_REIGSTRATION[op]
    return compiler(op, *args)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
from typing import Generic, TypeVar
19+
20+
from bigframes import operations as ops
21+
22+
T = TypeVar("T")


class OpRegistration(Generic[T]):
    """Registry mapping scalar operator names to items of type ``T``.

    Every instance owns its own mapping, so an operator registered on one
    registry is never visible through, and never conflicts with, another.
    """

    def __init__(self) -> None:
        # Per-instance storage. A class-level dict would be shared by every
        # registry, leaking entries between them and raising spurious
        # "already registered" errors when two registries use the same name.
        self._registered_ops: dict[str, T] = {}

    def register(
        self, op: ops.ScalarOp | type[ops.ScalarOp]
    ) -> typing.Callable[[T], T]:
        """Return a decorator that registers its argument under ``op.name``.

        Args:
            op: The operator instance or operator class whose ``name`` is
                used as the registry key.

        Returns:
            A decorator that stores the decorated item in this registry and
            returns it unchanged. The decorator raises ``ValueError`` if the
            name is already registered on this registry.
        """
        key = typing.cast(str, op.name)

        def decorator(item: T) -> T:
            if key in self._registered_ops:
                raise ValueError(f"{key} is already registered")
            self._registered_ops[key] = item
            return item

        return decorator

    def __getitem__(self, key: str | ops.ScalarOp) -> T:
        """Look up the item registered for an operator name or operator.

        Args:
            key: Either the operator name, or an object exposing the name via
                its ``name`` attribute (an operator instance or class).

        Returns:
            The item previously registered under that name.

        Raises:
            KeyError: If nothing is registered under the name.
        """
        # Anything that is not already a name is resolved through `.name`,
        # mirroring how `register` derives its key.
        name = key if isinstance(key, str) else key.name
        return self._registered_ops[name]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import operations as ops
22+
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
23+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
24+
25+
# Signature of a ternary-op compiler: (op, three operands) -> SQLGlot expression.
TernaryOpCompiler = typing.Callable[
    [ops.TernaryOp, TypedExpr, TypedExpr, TypedExpr], sge.Expression
]

# Registry of the compilers for ternary operators.
TERNARY_OP_REGISTRATION = OpRegistration[TernaryOpCompiler]()
# Backward-compatible alias: the registry was first published under this
# misspelled name, and existing callers still reference it.
TERNATRY_OP_REIGSTRATION = TERNARY_OP_REGISTRATION
30+
31+
32+
def compile(
    op: ops.TernaryOp, expr1: TypedExpr, expr2: TypedExpr, expr3: TypedExpr
) -> sge.Expression:
    """Compile a ternary operator applied to three typed operands.

    Args:
        op: The ternary operator to compile.
        expr1: The first already-compiled operand.
        expr2: The second already-compiled operand.
        expr3: The third already-compiled operand.

    Returns:
        The SQLGlot expression produced by the compiler registered for ``op``.
    """
    # Resolve the registered compiler first, then pass the operands in order.
    compiler = TERNATRY_OP_REIGSTRATION[op]
    return compiler(op, expr1, expr2, expr3)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import dataclasses
16+
17+
import sqlglot.expressions as sge
18+
19+
from bigframes import dtypes
20+
21+
22+
@dataclasses.dataclass(frozen=True)
class TypedExpr:
    """SQLGlot expression with type."""

    # The SQLGlot expression node.
    expr: sge.Expression
    # The type associated with `expr`.
    dtype: dtypes.ExpressionType
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
19+
import sqlglot.expressions as sge
20+
21+
from bigframes import operations as ops
22+
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
23+
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
24+
25+
# Signature of a unary-op compiler: (op, operand) -> SQLGlot expression.
UnaryOpCompiler = typing.Callable[[ops.UnaryOp, TypedExpr], sge.Expression]

# Registry of the compilers for unary operators.
UNARY_OP_REGISTRATION = OpRegistration[UnaryOpCompiler]()
# Backward-compatible alias: the registry was first published under this
# misspelled name, and existing callers still reference it.
UNARY_OP_REIGSTRATION = UNARY_OP_REGISTRATION
28+
29+
30+
def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile a unary operator applied to a single typed operand.

    Args:
        op: The unary operator to compile.
        expr: The already-compiled operand.

    Returns:
        The SQLGlot expression produced by the compiler registered for ``op``.
    """
    # Resolve the registered compiler first, then apply it to the operand.
    compiler = UNARY_OP_REIGSTRATION[op]
    return compiler(op, expr)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,22 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
import dataclasses
1716
import functools
1817

1918
import sqlglot.expressions as sge
2019

21-
from bigframes import dtypes
2220
from bigframes.core import expression
21+
from bigframes.core.compile.sqlglot.expressions import (
22+
binary_compiler,
23+
nary_compiler,
24+
ternary_compiler,
25+
typed_expr,
26+
unary_compiler,
27+
)
2328
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2429
import bigframes.operations as ops
2530

2631

27-
@dataclasses.dataclass(frozen=True)
28-
class TypedExpr:
29-
"""SQLGlot expression with type."""
30-
31-
expr: sge.Expression
32-
dtype: dtypes.ExpressionType
33-
34-
3532
@functools.singledispatch
3633
def compile_scalar_expression(
3734
expression: expression.Expression,
@@ -63,46 +60,21 @@ def compile_constant_expression(
6360
def compile_op_expression(expr: expression.OpExpression) -> sge.Expression:
6461
# Non-recursively compiles the children scalar expressions.
6562
args = tuple(
66-
TypedExpr(compile_scalar_expression(input), input.output_type)
63+
typed_expr.TypedExpr(compile_scalar_expression(input), input.output_type)
6764
for input in expr.inputs
6865
)
6966

7067
op = expr.op
71-
op_name = expr.op.__class__.__name__
72-
method_name = f"compile_{op_name.lower()}"
73-
method = globals().get(method_name, None)
74-
if method is None:
75-
raise ValueError(
76-
f"Compilation method '{method_name}' not found for operator '{op_name}'."
77-
)
78-
7968
if isinstance(op, ops.UnaryOp):
80-
return method(op, args[0])
69+
return unary_compiler.compile(op, args[0])
8170
elif isinstance(op, ops.BinaryOp):
82-
return method(op, args[0], args[1])
71+
return binary_compiler.compile(op, args[0], args[1])
8372
elif isinstance(op, ops.TernaryOp):
84-
return method(op, args[0], args[1], args[2])
73+
return ternary_compiler.compile(op, args[0], args[1], args[2])
8574
elif isinstance(op, ops.NaryOp):
86-
return method(op, *args)
75+
return nary_compiler.compile(op, *args)
8776
else:
8877
raise TypeError(
89-
f"Operator '{op_name}' has an unrecognized arity or type "
78+
f"Operator '{op.name}' has an unrecognized arity or type "
9079
"and cannot be compiled."
9180
)
92-
93-
94-
# TODO: add parenthesize for operators
95-
def compile_addop(op: ops.AddOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
96-
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
97-
# String addition
98-
return sge.Concat(expressions=[left.expr, right.expr])
99-
100-
# Numerical addition
101-
return sge.Add(this=left.expr, expression=right.expr)
102-
103-
104-
def compile_ge(
105-
op: ops.ge_op, left: TypedExpr, right: TypedExpr # type: ignore[valid-type]
106-
) -> sge.Expression:
107-
108-
return sge.GTE(this=left.expr, expression=right.expr)

0 commit comments

Comments
 (0)