Skip to content

Commit 5746c1c

Browse files
Merge branch 'main' into window_kurt
2 parents 92fa353 + 7d152d3 commit 5746c1c

File tree

20 files changed

+403
-158
lines changed

20 files changed

+403
-158
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
23+
from bigframes.core.compile.sqlglot.expressions.string_ops import (
24+
string_index,
25+
string_slice,
26+
)
2327
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2428
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2529
import bigframes.dtypes as dtypes
@@ -30,9 +34,12 @@
3034

3135
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
    """Compile ``ops.ArrayIndexOp`` for string and non-string operands.

    String-typed operands are delegated to ``string_index``; every other
    operand becomes a ``sge.Bracket`` lookup built with ``safe=True`` and
    ``offset=False``.
    """
    is_string = expr.dtype == dtypes.STRING_DTYPE
    if is_string:
        return string_index(expr, op.index)

    position = sge.convert(op.index)
    return sge.Bracket(this=expr.expr, expressions=[position], safe=True, offset=False)
@@ -68,15 +75,45 @@ def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
6875

6976
@register_unary_op(ops.ArraySliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
    """Compile ``ops.ArraySliceOp``, dispatching on the operand dtype.

    Non-string operands go through ``_array_slice``; string operands are
    sliced with ``string_slice`` using ``op.start`` and ``op.stop``.
    """
    if expr.dtype != dtypes.STRING_DTYPE:
        return _array_slice(expr, op)
    return string_slice(expr, op.start, op.stop)
7282

73-
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
7483

75-
if op.stop is not None:
76-
conditions.append(slice_idx < op.stop)
84+
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
    """Compile ``ops.ArrayToStringOp`` into an ``ARRAY_TO_STRING`` expression.

    Args:
        expr: The array operand.
        op: The operation; ``op.delimiter`` is the separator placed between
            elements.

    Returns:
        A ``sge.ArrayToString`` node joining ``expr`` with the delimiter.
    """
    # Build the delimiter as a string-literal node rather than splicing it
    # between hand-written quotes, so that quotes or backslashes inside the
    # delimiter are escaped by the SQL generator instead of breaking the query.
    delimiter = sge.convert(op.delimiter)
    return sge.ArrayToString(this=expr.expr, expression=delimiter)
7787

88+
89+
@register_nary_op(ops.ToArrayOp)
def _(*exprs: TypedExpr) -> sge.Expression:
    """Compile ``ops.ToArrayOp`` into an array literal of the operands.

    When at least one operand has a numeric (non-boolean) dtype, every operand
    is passed through ``_coerce_bool_to_int`` before being placed in the array;
    otherwise the operands are used unchanged.
    """
    mixes_numeric = any(
        dtypes.is_numeric(item.dtype, include_bool=False) for item in exprs
    )
    elements = [
        _coerce_bool_to_int(item) if mixes_numeric else item.expr for item in exprs
    ]
    return sge.Array(expressions=elements)
99+
100+
101+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the expression cast to INT64 if it is boolean, else unchanged."""
    if typed_expr.dtype != dtypes.BOOL_DTYPE:
        return typed_expr.expr
    return sge.Cast(this=typed_expr.expr, to="INT64")
106+
107+
108+
def _string_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
78109
# local name for each element in the array
79110
el = sg.to_identifier("el")
111+
# local name for the index in the array
112+
slice_idx = sg.to_identifier("slice_idx")
113+
114+
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
115+
if op.stop is not None:
116+
conditions.append(slice_idx < op.stop)
80117

81118
selected_elements = (
82119
sge.select(el)
@@ -93,25 +130,26 @@ def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
93130
return sge.array(selected_elements)
94131

95132

96-
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
97-
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
98-
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
133+
def _array_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
    """Build an array containing the elements of ``expr`` selected by a slice.

    The array is unnested with an offset column, rows whose offset is at least
    ``op.start`` (and below ``op.stop`` when it is set) are kept, and the
    surviving elements are wrapped back into an array.
    """
    # Local aliases for the unnested element and its offset in the array.
    element = sg.to_identifier("el")
    position = sg.to_identifier("slice_idx")

    lower_bound = position >= op.start
    bounds: typing.List[sge.Predicate] = (
        [lower_bound] if op.stop is None else [lower_bound, position < op.stop]
    )

    unnested = sge.Unnest(
        expressions=[expr.expr],
        alias=sge.TableAlias(columns=[element]),
        offset=position,
    )
    query = sge.select(element).from_(unnested).where(*bounds)

    return sge.array(query)

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 123 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import functools
18+
import typing
1819

1920
import sqlglot.expressions as sge
2021

@@ -29,7 +30,7 @@
2930

3031
@register_unary_op(ops.capitalize_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.capitalize_op`` into an ``sge.Initcap`` expression."""
    # An empty-string literal is passed as Initcap's second argument.
    # NOTE(review): presumably this means no character acts as a word
    # delimiter -- confirm against the BigQuery INITCAP documentation.
    second_arg = sge.convert("")
    return sge.Initcap(this=expr.expr, expression=second_arg)
3334

3435

3536
@register_unary_op(ops.StrContainsOp, pass_op=True)
@@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:
4445

4546
@register_unary_op(ops.StrExtractOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
    """Compile ``ops.StrExtractOp``: extract capture group ``op.n`` of ``op.pat``.

    Args:
        expr: The string operand.
        op: The operation; ``op.pat`` is the regular expression and ``op.n``
            the capture group to return (``0`` means the whole match).

    Returns:
        An ``IF`` expression yielding the requested group when the pattern is
        found in the operand, and NULL otherwise.
    """
    # Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
    # capturing group.
    pat_expr = sge.convert(op.pat)
    if op.n != 0:
        pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*"))
        # Reference the group the caller asked for, not always the first one.
        group = op.n
    else:
        # Group 0 is the whole match: wrap the pattern in its own group, which
        # then becomes group 1 of the combined pattern.
        pat_expr = sge.func("CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*"))
        group = 1

    rex_replace = sge.func(
        "REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(f"\\{group}")
    )
    # REGEXP_REPLACE leaves non-matching input untouched, so gate it on an
    # explicit containment check to return NULL when there is no match.
    rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
    return sge.If(this=rex_contains, true=rex_replace, false=sge.null())
5059

5160

5261
@register_unary_op(ops.StrFindOp, pass_op=True)
@@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:
7584

7685
@register_unary_op(ops.StrLstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression:
    """Compile ``ops.StrLstripOp`` into ``LTRIM(expr, op.to_strip)``."""
    chars_to_strip = sge.convert(op.to_strip)
    return sge.func("LTRIM", expr.expr, chars_to_strip)
88+
89+
90+
@register_unary_op(ops.StrRstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
    """Compile ``ops.StrRstripOp`` into ``RTRIM(expr, op.to_strip)``."""
    chars_to_strip = sge.convert(op.to_strip)
    return sge.func("RTRIM", expr.expr, chars_to_strip)
7993

8094

8195
@register_unary_op(ops.StrPadOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression:
    """Compile ``ops.StrPadOp`` into LPAD / RPAD calls.

    The target width is ``GREATEST(LENGTH(expr), op.length)``. ``op.side``
    selects left padding, right padding, or (any other value) padding on both
    sides, where the left side receives the floored half of the extra width.
    """
    expr_length = sge.Length(this=expr.expr)
    fillchar = sge.convert(op.fillchar)
    pad_length = sge.func("GREATEST", expr_length, sge.convert(op.length))

    if op.side in ("left", "right"):
        pad_func = "LPAD" if op.side == "left" else "RPAD"
        return sge.func(pad_func, expr.expr, pad_length, fillchar)

    # side == both: first left-pad up to the original length plus half of the
    # missing width, then right-pad the result out to the full width.
    missing_width = sge.Sub(this=pad_length, expression=expr_length)
    half_missing = sge.Floor(
        this=sge.func("SAFE_DIVIDE", missing_width, sge.convert(2))
    )
    lpad_amount = sge.Cast(this=half_missing, to="INT64") + expr_length
    left_padded = sge.func("LPAD", expr.expr, lpad_amount, fillchar)
    return sge.func("RPAD", left_padded, pad_length, fillchar)
120125

121126

@@ -148,12 +153,15 @@ def _(expr: TypedExpr) -> sge.Expression:
148153

149154
@register_unary_op(ops.isdecimal_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.isdecimal_op`` as a regex match on ``\\p{Nd}`` characters."""
    decimal_pattern = sge.convert(r"^(\p{Nd})+$")
    return sge.RegexpLike(this=expr.expr, expression=decimal_pattern)
152157

153158

154159
@register_unary_op(ops.isdigit_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.isdigit_op`` as a regex match.

    The character class accepts ``\\p{Nd}`` plus the explicitly listed code
    points U+00B9, U+00B2, U+00B3, U+2070, U+2074-U+2079 and U+2080-U+2089.
    """
    digit_pattern = sge.convert(
        r"^[\p{Nd}\x{00B9}\x{00B2}\x{00B3}\x{2070}\x{2074}-\x{2079}\x{2080}-\x{2089}]+$"
    )
    return sge.RegexpLike(this=expr.expr, expression=digit_pattern)
157165

158166

159167
@register_unary_op(ops.islower_op)
@@ -224,11 +232,6 @@ def _(expr: TypedExpr) -> sge.Expression:
224232
return sge.func("REVERSE", expr.expr)
225233

226234

227-
@register_unary_op(ops.StrRstripOp, pass_op=True)
228-
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
229-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT")
230-
231-
232235
@register_unary_op(ops.StartsWithOp, pass_op=True)
233236
def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression:
234237
if not op.pat:
@@ -253,27 +256,12 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253256

254257
@register_unary_op(ops.StrGetOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
    """Compile ``ops.StrGetOp`` by delegating to ``string_index`` with ``op.i``."""
    position = op.i
    return string_index(expr, position)
261260

262261

263262
@register_unary_op(ops.StrSliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
    """Compile ``ops.StrSliceOp`` by delegating to ``string_slice``.

    ``op.start`` and ``op.end`` are forwarded as the slice bounds.
    """
    lower, upper = op.start, op.end
    return string_slice(expr, lower, upper)
277265

278266

279267
@register_unary_op(ops.upper_op)
@@ -314,3 +302,79 @@ def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
314302
],
315303
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
316304
)
305+
306+
307+
def string_index(expr: TypedExpr, index: int) -> sge.Expression:
    """Build an expression for the single character of ``expr`` at ``index``.

    Args:
        expr: The string operand.
        index: Zero-based character position. A negative value counts from the
            end of the string (``-1`` is the last character).

    Returns:
        An ``IF`` expression yielding the one-character substring, or NULL when
        the position falls outside the string.
    """
    # SUBSTR positions are 1-based, while negative positions already count
    # from the end (-1 is the last character), so only shift non-negative ones.
    position = index + 1 if index >= 0 else index
    sub_str = sge.Substring(
        this=expr.expr,
        start=sge.convert(position),
        length=sge.convert(1),
    )
    # Reading past the right end yields an empty string; map that to NULL.
    in_bounds: sge.Expression = sge.NEQ(this=sub_str, expression=sge.convert(""))
    if index < 0:
        # A negative position beyond the left end is clamped to the first
        # character by BigQuery's SUBSTR, so require the string to be long
        # enough for the requested position to exist.
        in_bounds = sge.And(
            this=sge.GTE(
                this=sge.Length(this=expr.expr), expression=sge.convert(-index)
            ),
            expression=in_bounds,
        )
    return sge.If(
        this=in_bounds,
        true=sub_str,
        false=sge.Null(),
    )
318+
319+
320+
def string_slice(
    expr: TypedExpr, op_start: typing.Optional[int], op_end: typing.Optional[int]
) -> sge.Expression:
    """Build a ``SUBSTRING`` expression equivalent to ``expr[op_start:op_end]``.

    Args:
        expr: The string operand.
        op_start: Zero-based inclusive start of the slice; ``None`` means the
            beginning of the string. Negative values count from the end.
        op_end: Zero-based exclusive end of the slice; ``None`` means the end
            of the string. Negative values count from the end.

    Returns:
        A ``sge.Substring`` node whose start and length are derived from the
        bounds. The length is never negative, so an empty slice (for example
        ``[5:2]``) yields an empty string instead of a query error.
    """
    column_length = sge.Length(this=expr.expr)
    start = 0 if op_start is None else op_start

    # SUBSTR positions are 1-based; a negative position is passed through
    # unchanged when there is no end bound.
    start_expr: sge.Expression = (
        sge.convert(start) if start < 0 else sge.convert(start + 1)
    )
    length_expr: typing.Optional[sge.Expression]
    if op_end is None:
        length_expr = None
    elif start < 0:
        # Resolve the negative start against the string length, clamping to
        # the first character when it points before the beginning.
        start_expr = sge.Greatest(
            expressions=[
                sge.convert(1),
                column_length + sge.convert(start + 1),
            ]
        )
        # Zero-based offset of the first selected character.
        start_offset = sge.Greatest(
            expressions=[
                sge.convert(0),
                column_length + sge.convert(start),
            ]
        )
        if op_end < 0:
            if op_end <= start:
                # The end never lies after the start, whatever the string
                # length: the slice is always empty.
                length_expr = sge.convert(0)
            else:
                length_expr = (
                    sge.Greatest(
                        expressions=[
                            sge.convert(0),
                            column_length + sge.convert(op_end),
                        ]
                    )
                    - start_offset
                )
        else:  # op_end >= 0
            # The resolved start may lie at or after op_end for long strings;
            # clamp because SUBSTR rejects a negative length.
            length_expr = sge.Greatest(
                expressions=[
                    sge.convert(0),
                    sge.convert(op_end) - start_offset,
                ]
            )
    elif op_end < 0:
        length_expr = sge.Greatest(
            expressions=[
                sge.convert(0),
                column_length + sge.convert(op_end - start),
            ]
        )
    else:  # start >= 0 and op_end >= 0
        # Both bounds are constants; an end before the start is an empty slice.
        length_expr = sge.convert(max(0, op_end - start))

    return sge.Substring(
        this=expr.expr,
        start=start_expr,
        length=length_expr,
    )

0 commit comments

Comments
 (0)