@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
    """Compile a literal substring test to ``expr LIKE '%<pat>%'``."""
    # NOTE(review): `op.pat` is interpolated into the LIKE pattern unescaped, so
    # `%`, `_` or `\` inside it would presumably act as wildcards/escapes --
    # confirm whether callers can pass such patterns.
    return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%"))


@UNARY_OP_REGISTRATION.register(ops.StrContainsRegexOp)
def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression:
    """Compile a regular-expression containment test to a ``RegexpLike`` node."""
    return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat))


@UNARY_OP_REGISTRATION.register(ops.StrExtractOp)
def _(op: ops.StrExtractOp, expr: TypedExpr) -> sge.Expression:
    """Compile regex extraction to ``RegexpExtract`` with ``op.n`` as the group."""
    return sge.RegexpExtract(
        this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n)
    )


@UNARY_OP_REGISTRATION.register(ops.StrFindOp)
def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``str.find`` to ``INSTR(...) - 1``.

    The compiled expression is a 0-based index into the *original* string for
    every combination of ``op.start`` / ``op.end``, and -1 when the substring
    is not found.

    Args:
        op: Carries ``substr`` plus the optional 0-based ``start`` and ``end``
            bounds (``None`` means unbounded).
        expr: The string expression to search.

    Returns:
        A ``Sub`` expression subtracting 1 from the ``INSTR`` call.
    """
    # INSTR positions are 1-based, whereas `op.start` is 0-based.
    # NOTE(review): negative `start`/`end` are passed through unchanged;
    # presumably callers only supply non-negative bounds -- confirm.
    start = sge.convert(op.start + 1 if op.start is not None else 1)

    haystack = expr.expr
    if op.end is not None:
        # INSTR has no `end` argument, so cut the string off after `end`
        # characters. The cut starts at the first character (not at `start`)
        # so that the position INSTR reports is still relative to the original
        # string; slicing from `start` would shift every hit by `start`.
        haystack = sge.Substring(
            this=expr.expr,
            start=sge.convert(1),
            length=sge.convert(op.end),
        )

    # INSTR reports 0 for "no match", which the trailing `- 1` maps to -1.
    return sge.func(
        "INSTR",
        haystack,
        sge.convert(op.substr),
        start,
    ) - sge.convert(1)


@UNARY_OP_REGISTRATION.register(ops.StrLstripOp)
def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
    """Compile a left strip of ``op.to_strip`` to a ``Trim`` node (side LEFT)."""
    return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")


@UNARY_OP_REGISTRATION.register(ops.StrPadOp)
def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``str.pad`` to ``LPAD`` / ``RPAD`` calls.

    Args:
        op: Carries the target ``length``, the ``fillchar`` and the ``side``
            (``"left"``, ``"right"``; anything else is treated as both).
        expr: The string expression to pad.

    Returns:
        The padding expression for the requested side.
    """
    # Pad to at least the current length: GREATEST keeps the target from
    # dropping below LENGTH(expr), so the pad functions are never asked for a
    # width shorter than the input.
    pad_length = sge.func(
        "GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
    )
    if op.side == "left":
        return sge.func(
            "LPAD",
            expr.expr,
            pad_length,
            sge.convert(op.fillchar),
        )
    elif op.side == "right":
        return sge.func(
            "RPAD",
            expr.expr,
            pad_length,
            sge.convert(op.fillchar),
        )
    else:  # side == both
        # Left padding gets floor(half) of the missing width. FLOOR is applied
        # before the cast because CAST(FLOAT64 AS INT64) rounds halfway values
        # away from zero, which would hand an odd leftover character to the
        # left side instead.
        # NOTE(review): assumes the odd leftover character belongs on the
        # right (integer floor division) -- confirm against the reference
        # implementation.
        lpad_amount = sge.Cast(
            this=sge.Floor(
                this=sge.func(
                    "SAFE_DIVIDE",
                    sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
                    sge.convert(2),
                )
            ),
            to="INT64",
        ) + sge.Length(this=expr.expr)
        # Pad the left side first, then fill the remainder on the right.
        return sge.func(
            "RPAD",
            sge.func(
                "LPAD",
                expr.expr,
                lpad_amount,
                sge.convert(op.fillchar),
            ),
            pad_length,
            sge.convert(op.fillchar),
        )


@UNARY_OP_REGISTRATION.register(ops.StrRepeatOp)
def _(op: ops.StrRepeatOp, expr: TypedExpr) -> sge.Expression:
    """Compile string repetition to a ``Repeat`` node with ``op.repeats`` times."""
    return sge.Repeat(this=expr.expr, times=sge.convert(op.repeats))
@UNARY_OP_REGISTRATION.register(ops.ReplaceStrOp)
def _(op: ops.ReplaceStrOp, expr: TypedExpr) -> sge.Expression:
    """Compile a literal find-and-replace to ``REPLACE(expr, pat, repl)``."""
    pattern = sge.convert(op.pat)
    replacement = sge.convert(op.repl)
    return sge.func("REPLACE", expr.expr, pattern, replacement)


@UNARY_OP_REGISTRATION.register(ops.RegexReplaceStrOp)
def _(op: ops.RegexReplaceStrOp, expr: TypedExpr) -> sge.Expression:
    """Compile a regex find-and-replace to ``REGEXP_REPLACE(expr, pat, repl)``."""
    pattern = sge.convert(op.pat)
    replacement = sge.convert(op.repl)
    return sge.func("REGEXP_REPLACE", expr.expr, pattern, replacement)
`bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql new file mode 100644 index 0000000000..a7fac093e2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REGEXP_EXTRACT(`bfcol_0`, '([a-z]*)') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql new file mode 100644 index 0000000000..dfc100e413 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql new file mode 100644 index 0000000000..78edf662b9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql @@ 
-0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql new file mode 100644 index 0000000000..d0dfc11a53 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql new file mode 100644 index 0000000000..a91ab32946 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql new file mode 
100644 index 0000000000..4701b0237a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql @@ -0,0 +1,21 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + RPAD( + LPAD( + `bfcol_0`, + CAST(SAFE_DIVIDE(GREATEST(LENGTH(`bfcol_0`), 10) - LENGTH(`bfcol_0`), 2) AS INT64) + LENGTH(`bfcol_0`), + '-' + ), + GREATEST(LENGTH(`bfcol_0`), 10), + '-' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql new file mode 100644 index 0000000000..ee95900b3e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql new file mode 100644 index 0000000000..17e59c553f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ 
def test_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ReplaceStrOp("e", "a")``."""
    frame = scalar_types_df[["string_col"]]
    compiled = _apply_unary_op(frame, ops.ReplaceStrOp("e", "a"), "string_col")
    snapshot.assert_match(compiled, "out.sql")


def test_regex_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``RegexReplaceStrOp("e", "a")``."""
    frame = scalar_types_df[["string_col"]]
    compiled = _apply_unary_op(
        frame, ops.RegexReplaceStrOp(r"e", "a"), "string_col"
    )
    snapshot.assert_match(compiled, "out.sql")


def test_str_pad(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StrPadOp`` on each supported side."""
    frame = scalar_types_df[["string_col"]]
    # (side, snapshot file) pairs, checked in the same order as before.
    cases = (
        ("left", "left.sql"),
        ("right", "right.sql"),
        ("both", "both.sql"),
    )
    for side, snapshot_name in cases:
        compiled = _apply_unary_op(
            frame, ops.StrPadOp(length=10, fillchar="-", side=side), "string_col"
        )
        snapshot.assert_match(compiled, snapshot_name)


def test_str_extract(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for extracting group 1 of ``([a-z]*)``."""
    frame = scalar_types_df[["string_col"]]
    extract_op = ops.StrExtractOp(r"([a-z]*)", 1)
    compiled = _apply_unary_op(frame, extract_op, "string_col")

    snapshot.assert_match(compiled, "out.sql")


def test_str_repeat(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StrRepeatOp(2)``."""
    frame = scalar_types_df[["string_col"]]
    compiled = _apply_unary_op(frame, ops.StrRepeatOp(2), "string_col")
    snapshot.assert_match(compiled, "out.sql")


def test_str_find(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``StrFindOp`` with each bound combination."""
    frame = scalar_types_df[["string_col"]]
    # (start, end, snapshot file) triples, checked in the same order as before.
    cases = (
        (None, None, "out.sql"),
        (2, None, "out_with_start.sql"),
        (None, 5, "out_with_end.sql"),
        (2, 5, "out_with_start_and_end.sql"),
    )
    for start, end, snapshot_name in cases:
        compiled = _apply_unary_op(
            frame, ops.StrFindOp("e", start=start, end=end), "string_col"
        )
        snapshot.assert_match(compiled, snapshot_name)