Skip to content

Commit 32a6c74

Browse files
committed
chore: implement StrPadOp and StrFindOp in compilers
1 parent 26df6e6 commit 32a6c74

File tree

9 files changed

+197
-0
lines changed

9 files changed

+197
-0
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,30 @@ def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression:
182182
return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat))
183183

184184

185+
@UNARY_OP_REGISTRATION.register(ops.StrFindOp)
def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``StrFindOp`` into a 0-based BigQuery ``INSTR`` lookup.

    Args:
        op: The find operation. Reads ``op.substr`` (the text to search for)
            and the optional 0-based bounds ``op.start`` / ``op.end``.
        expr: The compiled string operand.

    Returns:
        An expression evaluating to the 0-based position of the first match
        in the *original* string, or -1 when there is no match (``INSTR``
        yields 0 in that case and 1 is subtracted from its result).
    """
    # INSTR is 1-based, so we need to adjust the start position.
    start = sge.convert(op.start + 1) if op.start is not None else sge.convert(1)

    haystack = expr.expr
    if op.end is not None:
        # BigQuery's INSTR doesn't support `end`, so cut the string off at
        # `end` first. Only the tail is removed: slicing from `start` instead
        # would make INSTR report positions relative to the slice rather than
        # to the original string (and `end < start` would produce a negative
        # SUBSTRING length).
        haystack = sge.Substring(
            this=expr.expr,
            start=sge.convert(1),
            length=sge.convert(op.end),
        )

    # NOTE(review): negative `start`/`end` values are not translated here;
    # confirm they are rejected or normalized before reaching the compiler.
    return sge.func(
        "INSTR",
        haystack,
        sge.convert(op.substr),
        start,
    ) - sge.convert(1)
207+
208+
185209
@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``StrContainsOp`` into ``<operand> LIKE '%<pat>%'``."""
    # NOTE(review): `op.pat` is interpolated verbatim, so any `%` or `_` it
    # contains is not escaped -- confirm that is acceptable for callers.
    like_pattern = sge.convert(f"%{op.pat}%")
    return sge.Like(this=expr.expr, expression=like_pattern)
@@ -449,6 +473,47 @@ def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
449473
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")
450474

451475

476+
@UNARY_OP_REGISTRATION.register(ops.StrPadOp)
def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``StrPadOp`` into BigQuery ``LPAD`` / ``RPAD`` calls.

    Args:
        op: The pad operation. Reads ``op.length`` (target width),
            ``op.fillchar`` (padding text) and ``op.side``.
        expr: The compiled string operand.

    Returns:
        ``LPAD(...)`` for ``side == "left"``, ``RPAD(...)`` for
        ``side == "right"``, and ``RPAD(LPAD(...), ...)`` for any other side
        value (treated as "both").
    """
    # Never ask for a width below the current length: LPAD/RPAD are given
    # GREATEST(LENGTH(value), op.length) so long values are left intact.
    pad_length = sge.func(
        "GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
    )
    if op.side == "left":
        return sge.func(
            "LPAD",
            expr.expr,
            pad_length,
            sge.convert(op.fillchar),
        )
    elif op.side == "right":
        return sge.func(
            "RPAD",
            expr.expr,
            pad_length,
            sge.convert(op.fillchar),
        )
    else:  # side == both
        # Half of the missing width goes on the left. SAFE_DIVIDE yields a
        # FLOAT64 and a bare CAST to INT64 rounds halves away from zero
        # (2.5 -> 3), so FLOOR is applied first to get integer-division
        # behavior; with an odd amount of padding the extra fill character
        # therefore ends up on the right.
        lpad_amount = sge.Cast(
            this=sge.Floor(
                this=sge.func(
                    "SAFE_DIVIDE",
                    sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
                    sge.convert(2),
                )
            ),
            to="INT64",
        ) + sge.Length(this=expr.expr)
        # Pad on the left up to the intermediate width, then fill the rest of
        # the target width on the right.
        return sge.func(
            "RPAD",
            sge.func(
                "LPAD",
                expr.expr,
                lpad_amount,
                sge.convert(op.fillchar),
            ),
            pad_length,
            sge.convert(op.fillchar),
        )
515+
516+
452517
@UNARY_OP_REGISTRATION.register(ops.neg_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile the negation op by wrapping the operand in a ``Neg`` node."""
    operand = expr.expr
    return sge.Neg(this=operand)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
RPAD(
9+
LPAD(
10+
`bfcol_0`,
11+
CAST(SAFE_DIVIDE(GREATEST(LENGTH(`bfcol_0`), 10) - LENGTH(`bfcol_0`), 2) AS INT64) + LENGTH(`bfcol_0`),
12+
'-'
13+
),
14+
GREATEST(LENGTH(`bfcol_0`), 10),
15+
'-'
16+
) AS `bfcol_1`
17+
FROM `bfcte_0`
18+
)
19+
SELECT
20+
`bfcol_1` AS `string_col`
21+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,24 @@ def test_str_get(scalar_types_df: bpd.DataFrame, snapshot):
466466
snapshot.assert_match(sql, "out.sql")
467467

468468

469+
def test_str_pad(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StrPadOp`` on each padding side."""
    bf_df = scalar_types_df[["string_col"]]

    # (side, snapshot file) pairs, checked in this order.
    cases = (
        ("left", "left.sql"),
        ("right", "right.sql"),
        ("both", "both.sql"),
    )
    for side, snapshot_name in cases:
        pad_op = ops.StrPadOp(length=10, fillchar="-", side=side)
        sql = _apply_unary_op(bf_df, pad_op, "string_col")
        snapshot.assert_match(sql, snapshot_name)
485+
486+
469487
def test_str_slice(scalar_types_df: bpd.DataFrame, snapshot):
470488
bf_df = scalar_types_df[["string_col"]]
471489
sql = _apply_unary_op(bf_df, ops.StrSliceOp(1, 3), "string_col")
@@ -506,6 +524,21 @@ def test_str_contains_regex(scalar_types_df: bpd.DataFrame, snapshot):
506524
snapshot.assert_match(sql, "out.sql")
507525

508526

527+
def test_str_find(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StrFindOp`` with each bound combination."""
    bf_df = scalar_types_df[["string_col"]]

    # (start, end, snapshot file) triples, checked in this order.
    cases = (
        (None, None, "out.sql"),
        (2, None, "out_with_start.sql"),
        (None, 5, "out_with_end.sql"),
        (2, 5, "out_with_start_and_end.sql"),
    )
    for start, end, snapshot_name in cases:
        find_op = ops.StrFindOp("e", start=start, end=end)
        sql = _apply_unary_op(bf_df, find_op, "string_col")
        snapshot.assert_match(sql, snapshot_name)
540+
541+
509542
def test_strip(scalar_types_df: bpd.DataFrame, snapshot):
510543
bf_df = scalar_types_df[["string_col"]]
511544
sql = _apply_unary_op(bf_df, ops.StrStripOp(" "), "string_col")

0 commit comments

Comments
 (0)