Skip to content

Commit c88a825

Browse files
authored
chore: add array operators to SQLGlot compiler (#1852)
* [WIP] Add array operators. Need to finish tests * add tests * fix lint * fix typos * Use sge.Bracket() for safe_offset
1 parent bc885bd commit c88a825

File tree

18 files changed

+287
-42
lines changed

18 files changed

+287
-42
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,22 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import dtypes
2220
from bigframes import operations as ops
2321
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2422
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2523

26-
BinaryOpCompiler = typing.Callable[[ops.BinaryOp, TypedExpr, TypedExpr], sge.Expression]
27-
28-
BINARY_OP_REIGSTRATION = OpRegistration[BinaryOpCompiler]()
24+
BINARY_OP_REGISTRATION = OpRegistration()
2925

3026

3127
def compile(op: ops.BinaryOp, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile a binary operator applied to two typed operands.

    The compilation function registered for ``op`` in
    ``BINARY_OP_REGISTRATION`` is looked up and called with the operator
    itself followed by both operands.
    """
    compile_fn = BINARY_OP_REGISTRATION[op]
    return compile_fn(op, left, right)
3329

3430

3531
# TODO: parenthesize operator expressions
36-
@BINARY_OP_REIGSTRATION.register(ops.add_op)
32+
@BINARY_OP_REGISTRATION.register(ops.add_op)
3733
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3834
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
3935
# String addition
@@ -43,7 +39,6 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
4339
return sge.Add(this=left.expr, expression=right.expr)
4440

4541

46-
@BINARY_OP_REIGSTRATION.register(ops.ge_op)
47-
def compile_ge(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
48-
42+
@BINARY_OP_REGISTRATION.register(ops.ge_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.ge_op`` into a ``sge.GTE`` node over the two operands."""
    lhs = left.expr
    rhs = right.expr
    return sge.GTE(this=lhs, expression=rhs)

bigframes/core/compile/sqlglot/expressions/nary_compiler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,14 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import operations as ops
2220
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2321
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2422

25-
# No simpler way to specify that the compilation function expects varargs.
26-
NaryOpCompiler = typing.Callable[..., sge.Expression]
27-
28-
NARY_OP_REIGSTRATION = OpRegistration[NaryOpCompiler]()
23+
NARY_OP_REGISTRATION = OpRegistration()
2924

3025

3126
def compile(op: ops.NaryOp, *args: TypedExpr) -> sge.Expression:
    """Compile an n-ary operator applied to any number of typed operands.

    The compilation function registered for ``op`` in
    ``NARY_OP_REGISTRATION`` is looked up and called with the operator
    itself followed by all operands, in order.
    """
    compile_fn = NARY_OP_REGISTRATION[op]
    return compile_fn(op, *args)

bigframes/core/compile/sqlglot/expressions/op_registration.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,40 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Generic, TypeVar
18+
19+
from sqlglot import expressions as sge
1920

2021
from bigframes import operations as ops
2122

22-
T = TypeVar("T")
23+
# We should've been more specific about input types. Unfortunately,
24+
# MyPy doesn't support more rigorous checks.
25+
CompilationFunc = typing.Callable[..., sge.Expression]
2326

2427

25-
class OpRegistration(Generic[T]):
26-
_registered_ops: dict[str, T] = {}
28+
class OpRegistration:
    """Registry that maps scalar operator names to compilation functions.

    Compilation functions are added with the :meth:`register` decorator and
    retrieved by indexing the registry with an operator instance or an
    operator name.
    """

    def __init__(self) -> None:
        # Per-instance mapping from operator name to its compilation function.
        self._registered_ops: dict[str, CompilationFunc] = {}

    def register(
        self, op: ops.ScalarOp | type[ops.ScalarOp]
    ) -> typing.Callable[[CompilationFunc], CompilationFunc]:
        """Return a decorator that registers a compilation function for ``op``.

        The decorated function is wrapped so that every call verifies that
        its first argument is a ``ops.ScalarOp`` instance. The same wrapper is
        both stored in the registry and returned from the decorator.

        Args:
            op: The operator (instance or class) whose ``name`` is used as
                the registry key.

        Returns:
            A decorator taking the compilation function and returning the
            validating wrapper around it.

        Raises:
            ValueError: At decoration time, if a function is already
                registered under the same operator name. At call time, if
                the first argument is missing or is not a ``ops.ScalarOp``.
        """

        def decorator(item: CompilationFunc) -> CompilationFunc:
            def arg_checker(*args, **kwargs):
                # The operator normally comes first positionally; also accept
                # it as the ``op`` keyword so keyword-style calls keep working.
                first = args[0] if args else kwargs.get("op")
                if not isinstance(first, ops.ScalarOp):
                    raise ValueError(
                        f"The first parameter must be an operator. Got {type(first)}"
                    )
                return item(*args, **kwargs)

            key = typing.cast(str, op.name)
            if key in self._registered_ops:
                raise ValueError(f"{key} is already registered")
            # Store the validating wrapper, not the raw function: lookups via
            # __getitem__ are how compilers are invoked, so storing ``item``
            # would make the argument check unreachable.
            self._registered_ops[key] = arg_checker
            return arg_checker

        return decorator

    def __getitem__(self, key: str | ops.ScalarOp) -> CompilationFunc:
        """Return the compilation function registered for ``key``.

        Args:
            key: An operator instance (its ``name`` is used) or an operator
                name.

        Raises:
            KeyError: If nothing is registered under that name.
        """
        if isinstance(key, ops.ScalarOp):
            return self._registered_ops[key.name]
        return self._registered_ops[key]

bigframes/core/compile/sqlglot/expressions/ternary_compiler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,16 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
1917
import sqlglot.expressions as sge
2018

2119
from bigframes import operations as ops
2220
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2321
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2422

25-
TernaryOpCompiler = typing.Callable[
26-
[ops.TernaryOp, TypedExpr, TypedExpr, TypedExpr], sge.Expression
27-
]
28-
29-
TERNATRY_OP_REIGSTRATION = OpRegistration[TernaryOpCompiler]()
23+
TERNATRY_OP_REGISTRATION = OpRegistration()
3024

3125

3226
def compile(
    op: ops.TernaryOp, expr1: TypedExpr, expr2: TypedExpr, expr3: TypedExpr
) -> sge.Expression:
    """Compile a ternary operator applied to three typed operands.

    The compilation function registered for ``op`` in
    ``TERNATRY_OP_REGISTRATION`` is looked up and called with the operator
    itself followed by the three operands, in order.
    """
    compile_fn = TERNATRY_OP_REGISTRATION[op]
    return compile_fn(op, expr1, expr2, expr3)

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,57 @@
1616

1717
import typing
1818

19+
import sqlglot
1920
import sqlglot.expressions as sge
2021

2122
from bigframes import operations as ops
2223
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2324
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2425

25-
UnaryOpCompiler = typing.Callable[[ops.UnaryOp, TypedExpr], sge.Expression]
26-
27-
UNARY_OP_REIGSTRATION = OpRegistration[UnaryOpCompiler]()
26+
UNARY_OP_REGISTRATION = OpRegistration()
2827

2928

3029
def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile a unary operator applied to a single typed operand.

    The compilation function registered for ``op`` in
    ``UNARY_OP_REGISTRATION`` is looked up and called with the operator
    itself followed by the operand.
    """
    compile_fn = UNARY_OP_REGISTRATION[op]
    return compile_fn(op, expr)
31+
32+
33+
@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp)
def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.ArrayToStringOp`` into a ``sge.ArrayToString`` node.

    Args:
        op: The operator; its ``delimiter`` is used as the join separator.
        expr: The typed array expression to join.

    Returns:
        A ``sge.ArrayToString`` expression over ``expr`` and the delimiter.
    """
    # Pass the delimiter as a string literal node so the SQL generator does
    # the quoting and escaping. Splicing the raw text between quotes breaks
    # the query when the delimiter contains a quote or a backslash.
    delimiter = sge.Literal.string(op.delimiter)
    return sge.ArrayToString(this=expr.expr, expression=delimiter)
36+
37+
38+
@UNARY_OP_REGISTRATION.register(ops.ArrayIndexOp)
def _(op: ops.ArrayIndexOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.ArrayIndexOp`` into a safe bracket access on the array.

    The element position is taken from ``op.index`` and emitted as a numeric
    literal inside a ``sge.Bracket`` built with ``safe=True`` and
    ``offset=False``.
    """
    index_literal = sge.Literal.number(op.index)
    bracket_args = {
        "this": expr.expr,
        "expressions": [index_literal],
        "safe": True,
        "offset": False,
    }
    return sge.Bracket(**bracket_args)
46+
47+
48+
@UNARY_OP_REGISTRATION.register(ops.ArraySliceOp)
def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.ArraySliceOp`` into an array built from a filtered UNNEST.

    The input array is unnested with an offset column, rows are kept when the
    offset is at least ``op.start`` and (if ``op.stop`` is not ``None``) below
    ``op.stop``, and the surviving elements are wrapped back into an array.
    """
    # Offset column produced by UNNEST; used to filter by position.
    position = sqlglot.to_identifier("slice_idx")
    # Local name for each element of the array.
    element = sqlglot.to_identifier("el")

    bounds: typing.List[sge.Predicate] = [position >= op.start]
    if op.stop is not None:
        bounds.append(position < op.stop)

    unnested = sge.Unnest(
        expressions=[expr.expr],
        alias=sge.TableAlias(columns=[element]),
        offset=position,
    )
    kept_elements = sge.select(element).from_(unnested).where(*bounds)

    return sge.array(kept_elements)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`rowindex` AS `bfcol_0`,
4+
`string_list_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`repeated_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`bfcol_1`[SAFE_OFFSET(1)] AS `bfcol_4`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_0` AS `rowindex`,
14+
`bfcol_4` AS `string_list_col`
15+
FROM `bfcte_1`

0 commit comments

Comments
 (0)