Skip to content

Commit 9026349

Browse files
committed
chore: implement ReplaceStrOp and RegexReplaceStrOp
1 parent 6b1c4bc commit 9026349

File tree

4 files changed

+100
-49
lines changed

4 files changed

+100
-49
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 62 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
177177
)
178178

179179

180+
@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
181+
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
182+
return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%"))
183+
184+
180185
@UNARY_OP_REGISTRATION.register(ops.StrContainsRegexOp)
181186
def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression:
182187
return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat))
@@ -213,15 +218,57 @@ def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression:
213218
) - sge.convert(1)
214219

215220

216-
@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
217-
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
218-
return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%"))
221+
@UNARY_OP_REGISTRATION.register(ops.StrLstripOp)
222+
def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
223+
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")
224+
225+
226+
@UNARY_OP_REGISTRATION.register(ops.StrPadOp)
227+
def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression:
228+
pad_length = sge.func(
229+
"GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
230+
)
231+
if op.side == "left":
232+
return sge.func(
233+
"LPAD",
234+
expr.expr,
235+
pad_length,
236+
sge.convert(op.fillchar),
237+
)
238+
elif op.side == "right":
239+
return sge.func(
240+
"RPAD",
241+
expr.expr,
242+
pad_length,
243+
sge.convert(op.fillchar),
244+
)
245+
else: # side == both
246+
lpad_amount = sge.Cast(
247+
this=sge.func(
248+
"SAFE_DIVIDE",
249+
sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
250+
sge.convert(2),
251+
),
252+
to="INT64",
253+
) + sge.Length(this=expr.expr)
254+
return sge.func(
255+
"RPAD",
256+
sge.func(
257+
"LPAD",
258+
expr.expr,
259+
lpad_amount,
260+
sge.convert(op.fillchar),
261+
),
262+
pad_length,
263+
sge.convert(op.fillchar),
264+
)
219265

220266

221267
@UNARY_OP_REGISTRATION.register(ops.StrRepeatOp)
222268
def _(op: ops.StrRepeatOp, expr: TypedExpr) -> sge.Expression:
223269
return sge.Repeat(this=expr.expr, times=sge.convert(op.repeats))
224270

271+
225272
@UNARY_OP_REGISTRATION.register(ops.date_op)
226273
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
227274
return sge.Date(this=expr.expr)
@@ -479,52 +526,6 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
479526
return sge.Extract(this=sge.Identifier(this="MONTH"), expression=expr.expr)
480527

481528

482-
@UNARY_OP_REGISTRATION.register(ops.StrLstripOp)
483-
def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
484-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")
485-
486-
487-
@UNARY_OP_REGISTRATION.register(ops.StrPadOp)
488-
def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression:
489-
pad_length = sge.func(
490-
"GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
491-
)
492-
if op.side == "left":
493-
return sge.func(
494-
"LPAD",
495-
expr.expr,
496-
pad_length,
497-
sge.convert(op.fillchar),
498-
)
499-
elif op.side == "right":
500-
return sge.func(
501-
"RPAD",
502-
expr.expr,
503-
pad_length,
504-
sge.convert(op.fillchar),
505-
)
506-
else: # side == both
507-
lpad_amount = sge.Cast(
508-
this=sge.func(
509-
"SAFE_DIVIDE",
510-
sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
511-
sge.convert(2),
512-
),
513-
to="INT64",
514-
) + sge.Length(this=expr.expr)
515-
return sge.func(
516-
"RPAD",
517-
sge.func(
518-
"LPAD",
519-
expr.expr,
520-
lpad_amount,
521-
sge.convert(op.fillchar),
522-
),
523-
pad_length,
524-
sge.convert(op.fillchar),
525-
)
526-
527-
528529
@UNARY_OP_REGISTRATION.register(ops.neg_op)
529530
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
530531
return sge.Neg(this=expr.expr)
@@ -560,6 +561,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
560561
return sge.Extract(this=sge.Identifier(this="QUARTER"), expression=expr.expr)
561562

562563

564+
@UNARY_OP_REGISTRATION.register(ops.ReplaceStrOp)
565+
def _(op: ops.ReplaceStrOp, expr: TypedExpr) -> sge.Expression:
566+
return sge.func("REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl))
567+
568+
569+
@UNARY_OP_REGISTRATION.register(ops.RegexReplaceStrOp)
570+
def _(op: ops.RegexReplaceStrOp, expr: TypedExpr) -> sge.Expression:
571+
return sge.func(
572+
"REGEXP_REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl)
573+
)
574+
575+
563576
@UNARY_OP_REGISTRATION.register(ops.reverse_op)
564577
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
565578
return sge.func("REVERSE", expr.expr)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
REGEXP_REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -431,6 +431,18 @@ def test_quarter(scalar_types_df: bpd.DataFrame, snapshot):
431431
snapshot.assert_match(sql, "out.sql")
432432

433433

434+
def test_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
435+
bf_df = scalar_types_df[["string_col"]]
436+
sql = _apply_unary_op(bf_df, ops.ReplaceStrOp("e", "a"), "string_col")
437+
snapshot.assert_match(sql, "out.sql")
438+
439+
440+
def test_regex_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
441+
bf_df = scalar_types_df[["string_col"]]
442+
sql = _apply_unary_op(bf_df, ops.RegexReplaceStrOp(r"e", "a"), "string_col")
443+
snapshot.assert_match(sql, "out.sql")
444+
445+
434446
def test_reverse(scalar_types_df: bpd.DataFrame, snapshot):
435447
bf_df = scalar_types_df[["string_col"]]
436448
sql = _apply_unary_op(bf_df, ops.reverse_op, "string_col")

0 commit comments

Comments
 (0)