Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 97 additions & 8 deletions bigframes/core/compile/sqlglot/expressions/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,96 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
)


@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%"))


@UNARY_OP_REGISTRATION.register(ops.StrContainsRegexOp)
def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression:
return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat))


@UNARY_OP_REGISTRATION.register(ops.StrContainsOp)
def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression:
return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%"))
@UNARY_OP_REGISTRATION.register(ops.StrExtractOp)
def _(op: ops.StrExtractOp, expr: TypedExpr) -> sge.Expression:
return sge.RegexpExtract(
this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n)
)


@UNARY_OP_REGISTRATION.register(ops.StrFindOp)
def _(op: ops.StrFindOp, expr: TypedExpr) -> sge.Expression:
# INSTR is 1-based, so we need to adjust the start position.
start = sge.convert(op.start + 1) if op.start is not None else sge.convert(1)
if op.end is not None:
# BigQuery's INSTR doesn't support `end`, so we need to use SUBSTR.
return sge.func(
"INSTR",
sge.Substring(
this=expr.expr,
start=start,
length=sge.convert(op.end - (op.start or 0)),
),
sge.convert(op.substr),
) - sge.convert(1)
else:
return sge.func(
"INSTR",
expr.expr,
sge.convert(op.substr),
start,
) - sge.convert(1)


@UNARY_OP_REGISTRATION.register(ops.StrLstripOp)
def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")


@UNARY_OP_REGISTRATION.register(ops.StrPadOp)
def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression:
pad_length = sge.func(
"GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
)
if op.side == "left":
return sge.func(
"LPAD",
expr.expr,
pad_length,
sge.convert(op.fillchar),
)
elif op.side == "right":
return sge.func(
"RPAD",
expr.expr,
pad_length,
sge.convert(op.fillchar),
)
else: # side == both
lpad_amount = sge.Cast(
this=sge.func(
"SAFE_DIVIDE",
sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
sge.convert(2),
),
to="INT64",
) + sge.Length(this=expr.expr)
return sge.func(
"RPAD",
sge.func(
"LPAD",
expr.expr,
lpad_amount,
sge.convert(op.fillchar),
),
pad_length,
sge.convert(op.fillchar),
)


@UNARY_OP_REGISTRATION.register(ops.StrRepeatOp)
def _(op: ops.StrRepeatOp, expr: TypedExpr) -> sge.Expression:
return sge.Repeat(this=expr.expr, times=sge.convert(op.repeats))


@UNARY_OP_REGISTRATION.register(ops.date_op)
Expand Down Expand Up @@ -444,11 +526,6 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
return sge.Extract(this=sge.Identifier(this="MONTH"), expression=expr.expr)


@UNARY_OP_REGISTRATION.register(ops.StrLstripOp)
def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression:
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")


@UNARY_OP_REGISTRATION.register(ops.neg_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
return sge.Neg(this=expr.expr)
Expand Down Expand Up @@ -484,6 +561,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
return sge.Extract(this=sge.Identifier(this="QUARTER"), expression=expr.expr)


@UNARY_OP_REGISTRATION.register(ops.ReplaceStrOp)
def _(op: ops.ReplaceStrOp, expr: TypedExpr) -> sge.Expression:
return sge.func("REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl))


@UNARY_OP_REGISTRATION.register(ops.RegexReplaceStrOp)
def _(op: ops.RegexReplaceStrOp, expr: TypedExpr) -> sge.Expression:
return sge.func(
"REGEXP_REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl)
)


@UNARY_OP_REGISTRATION.register(ops.reverse_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
return sge.func("REVERSE", expr.expr)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
REGEXP_REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
REGEXP_EXTRACT(`bfcol_0`, '([a-z]*)') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
RPAD(
LPAD(
`bfcol_0`,
CAST(SAFE_DIVIDE(GREATEST(LENGTH(`bfcol_0`), 10) - LENGTH(`bfcol_0`), 2) AS INT64) + LENGTH(`bfcol_0`),
'-'
),
GREATEST(LENGTH(`bfcol_0`), 10),
'-'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
REPEAT(`bfcol_0`, 2) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,18 @@ def test_quarter(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.ReplaceStrOp("e", "a"), "string_col")
snapshot.assert_match(sql, "out.sql")


def test_regex_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.RegexReplaceStrOp(r"e", "a"), "string_col")
snapshot.assert_match(sql, "out.sql")


def test_reverse(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.reverse_op, "string_col")
Expand Down Expand Up @@ -466,6 +478,24 @@ def test_str_get(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_str_pad(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(
bf_df, ops.StrPadOp(length=10, fillchar="-", side="left"), "string_col"
)
snapshot.assert_match(sql, "left.sql")

sql = _apply_unary_op(
bf_df, ops.StrPadOp(length=10, fillchar="-", side="right"), "string_col"
)
snapshot.assert_match(sql, "right.sql")

sql = _apply_unary_op(
bf_df, ops.StrPadOp(length=10, fillchar="-", side="both"), "string_col"
)
snapshot.assert_match(sql, "both.sql")


def test_str_slice(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.StrSliceOp(1, 3), "string_col")
Expand Down Expand Up @@ -506,6 +536,34 @@ def test_str_contains_regex(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_str_extract(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.StrExtractOp(r"([a-z]*)", 1), "string_col")

snapshot.assert_match(sql, "out.sql")


def test_str_repeat(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.StrRepeatOp(2), "string_col")
snapshot.assert_match(sql, "out.sql")


def test_str_find(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=None), "string_col")
snapshot.assert_match(sql, "out.sql")

sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=None), "string_col")
snapshot.assert_match(sql, "out_with_start.sql")

sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=5), "string_col")
snapshot.assert_match(sql, "out_with_end.sql")

sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=5), "string_col")
snapshot.assert_match(sql, "out_with_start_and_end.sql")


def test_strip(scalar_types_df: bpd.DataFrame, snapshot):
bf_df = scalar_types_df[["string_col"]]
sql = _apply_unary_op(bf_df, ops.StrStripOp(" "), "string_col")
Expand Down