Skip to content

Commit 4e3923d

Browse files
committed
refactor: fix some string ops in the sqlglot compiler (part 2)
1 parent f1ff345 commit 4e3923d

File tree

8 files changed

+126
-56
lines changed

8 files changed

+126
-56
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@
3030

3131
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
3232
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
33+
if expr.dtype == dtypes.STRING_DTYPE:
34+
return _string_index(expr, op)
35+
3336
return sge.Bracket(
3437
this=expr.expr,
35-
expressions=[sge.Literal.number(op.index)],
38+
expressions=[sge.convert(op.index)],
3639
safe=True,
3740
offset=False,
3841
)
@@ -115,3 +118,16 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
115118
if typed_expr.dtype == dtypes.BOOL_DTYPE:
116119
return sge.Cast(this=typed_expr.expr, to="INT64")
117120
return typed_expr.expr
121+
122+
123+
def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
124+
sub_str = sge.Substring(
125+
this=expr.expr,
126+
start=sge.convert(op.index + 1),
127+
length=sge.convert(1),
128+
)
129+
return sge.If(
130+
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
131+
true=sub_str,
132+
false=sge.Null(),
133+
)

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 99 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
@register_unary_op(ops.capitalize_op)
3131
def _(expr: TypedExpr) -> sge.Expression:
32-
return sge.Initcap(this=expr.expr)
32+
return sge.Initcap(this=expr.expr, expression=sge.convert(""))
3333

3434

3535
@register_unary_op(ops.StrContainsOp, pass_op=True)
@@ -44,9 +44,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:
4444

4545
@register_unary_op(ops.StrExtractOp, pass_op=True)
4646
def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
47-
return sge.RegexpExtract(
48-
this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n)
49-
)
47+
# Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
48+
# capturing group.
49+
pat_expr = sge.convert(op.pat)
50+
if op.n != 0:
51+
pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*"))
52+
else:
53+
pat_expr = sge.func("CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*"))
54+
55+
rex_replace = sge.func("REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(r"\1"))
56+
rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
57+
return sge.If(this=rex_contains, true=rex_replace, false=sge.null())
5058

5159

5260
@register_unary_op(ops.StrFindOp, pass_op=True)
@@ -75,47 +83,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:
7583

7684
@register_unary_op(ops.StrLstripOp, pass_op=True)
7785
def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression:
78-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")
86+
return sge.func("LTRIM", expr.expr, sge.convert(op.to_strip))
87+
88+
89+
@register_unary_op(ops.StrRstripOp, pass_op=True)
90+
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
91+
return sge.func("RTRIM", expr.expr, sge.convert(op.to_strip))
7992

8093

8194
@register_unary_op(ops.StrPadOp, pass_op=True)
8295
def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression:
83-
pad_length = sge.func(
84-
"GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
85-
)
96+
expr_length = sge.Length(this=expr.expr)
97+
fillchar = sge.convert(op.fillchar)
98+
pad_length = sge.func("GREATEST", expr_length, sge.convert(op.length))
99+
86100
if op.side == "left":
87-
return sge.func(
88-
"LPAD",
89-
expr.expr,
90-
pad_length,
91-
sge.convert(op.fillchar),
92-
)
101+
return sge.func("LPAD", expr.expr, pad_length, fillchar)
93102
elif op.side == "right":
94-
return sge.func(
95-
"RPAD",
96-
expr.expr,
97-
pad_length,
98-
sge.convert(op.fillchar),
99-
)
103+
return sge.func("RPAD", expr.expr, pad_length, fillchar)
100104
else: # side == both
101-
lpad_amount = sge.Cast(
102-
this=sge.func(
103-
"SAFE_DIVIDE",
104-
sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
105-
sge.convert(2),
106-
),
107-
to="INT64",
108-
) + sge.Length(this=expr.expr)
105+
lpad_amount = (
106+
sge.Cast(
107+
this=sge.Floor(
108+
this=sge.func(
109+
"SAFE_DIVIDE",
110+
sge.Sub(this=pad_length, expression=expr_length),
111+
sge.convert(2),
112+
)
113+
),
114+
to="INT64",
115+
)
116+
+ expr_length
117+
)
109118
return sge.func(
110119
"RPAD",
111-
sge.func(
112-
"LPAD",
113-
expr.expr,
114-
lpad_amount,
115-
sge.convert(op.fillchar),
116-
),
120+
sge.func("LPAD", expr.expr, lpad_amount, fillchar),
117121
pad_length,
118-
sge.convert(op.fillchar),
122+
fillchar,
119123
)
120124

121125

@@ -224,11 +228,6 @@ def _(expr: TypedExpr) -> sge.Expression:
224228
return sge.func("REVERSE", expr.expr)
225229

226230

227-
@register_unary_op(ops.StrRstripOp, pass_op=True)
228-
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
229-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT")
230-
231-
232231
@register_unary_op(ops.StartsWithOp, pass_op=True)
233232
def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression:
234233
if not op.pat:
@@ -253,26 +252,77 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253252

254253
@register_unary_op(ops.StrGetOp, pass_op=True)
255254
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
256-
return sge.Substring(
255+
sub_str = sge.Substring(
257256
this=expr.expr,
258257
start=sge.convert(op.i + 1),
259258
length=sge.convert(1),
260259
)
261260

261+
return sge.If(
262+
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
263+
true=sub_str,
264+
false=sge.Null(),
265+
)
266+
262267

263268
@register_unary_op(ops.StrSliceOp, pass_op=True)
264269
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
265-
start = op.start + 1 if op.start is not None else None
266-
if op.end is None:
267-
length = None
268-
elif op.start is None:
269-
length = op.end
270+
column_length = sge.Length(this=expr.expr)
271+
if op.start is None:
272+
start = 0
270273
else:
271-
length = op.end - op.start
274+
start = op.start
275+
276+
start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1)
277+
if op.end is None:
278+
length_expr = None
279+
elif op.end < 0:
280+
if start < 0:
281+
start_expr = sge.Greatest(
282+
expressions=[
283+
sge.convert(1),
284+
column_length + sge.convert(start + 1),
285+
]
286+
)
287+
length_expr = sge.Greatest(
288+
expressions=[
289+
sge.convert(0),
290+
column_length + sge.convert(op.end),
291+
]
292+
) - sge.Greatest(
293+
expressions=[
294+
sge.convert(0),
295+
column_length + sge.convert(start),
296+
]
297+
)
298+
else:
299+
length_expr = sge.Greatest(
300+
expressions=[
301+
sge.convert(0),
302+
column_length + sge.convert(op.end - start),
303+
]
304+
)
305+
else: # op.end >= 0
306+
if start < 0:
307+
start_expr = sge.Greatest(
308+
expressions=[
309+
sge.convert(1),
310+
column_length + sge.convert(start + 1),
311+
]
312+
)
313+
length_expr = sge.convert(op.end) - sge.Greatest(
314+
expressions=[
315+
sge.convert(0),
316+
column_length + sge.convert(start),
317+
]
318+
)
319+
else:
320+
length_expr = sge.convert(op.end - start)
321+
272322
return sge.Substring(
273323
this=expr.expr,
274-
start=sge.convert(start) if start is not None else None,
275-
length=sge.convert(length) if length is not None else None,
324+
start=start_expr,
325+
length=length_expr,
276326
)
277327

278328

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
INITCAP(`string_col`) AS `bfcol_1`
8+
INITCAP(`string_col`, '') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
TRIM(`string_col`, ' ') AS `bfcol_1`
8+
LTRIM(`string_col`, ' ') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
TRIM(`string_col`, ' ') AS `bfcol_1`
8+
RTRIM(`string_col`, ' ') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
REGEXP_EXTRACT(`string_col`, '([a-z]*)') AS `bfcol_1`
8+
IF(
9+
REGEXP_CONTAINS(`string_col`, '([a-z]*)'),
10+
REGEXP_REPLACE(`string_col`, CONCAT('.*?', '([a-z]*)', '.*'), '\\1'),
11+
NULL
12+
) AS `bfcol_1`
913
FROM `bfcte_0`
1014
)
1115
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
SUBSTRING(`string_col`, 2, 1) AS `bfcol_1`
8+
IF(SUBSTRING(`string_col`, 2, 1) <> '', SUBSTRING(`string_col`, 2, 1), NULL) AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
1010
RPAD(
1111
LPAD(
1212
`string_col`,
13-
CAST(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2) AS INT64) + LENGTH(`string_col`),
13+
CAST(FLOOR(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2)) AS INT64) + LENGTH(`string_col`),
1414
'-'
1515
),
1616
GREATEST(LENGTH(`string_col`), 10),

0 commit comments

Comments
 (0)