Skip to content

Commit cfa4b2a

Browse files
authored
chore: implement StartsWithOp, EndsWithOp, StringSplitOp and ZfillOp for sqlglot compilers (#2027)
1 parent 9ed0078 commit cfa4b2a

File tree

10 files changed

+202
-0
lines changed

10 files changed

+202
-0
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import functools
1718
import typing
1819

1920
import pandas as pd
@@ -292,6 +293,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
292293
return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr)
293294

294295

296+
@UNARY_OP_REGISTRATION.register(ops.EndsWithOp)
297+
def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression:
298+
if not op.pat:
299+
return sge.false()
300+
301+
def to_endswith(pat: str) -> sge.Expression:
302+
return sge.func("ENDS_WITH", expr.expr, sge.convert(pat))
303+
304+
conditions = [to_endswith(pat) for pat in op.pat]
305+
return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
306+
307+
295308
@UNARY_OP_REGISTRATION.register(ops.exp_op)
296309
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
297310
return sge.Case(
@@ -633,6 +646,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
633646
)
634647

635648

649+
@UNARY_OP_REGISTRATION.register(ops.StartsWithOp)
650+
def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression:
651+
if not op.pat:
652+
return sge.false()
653+
654+
def to_startswith(pat: str) -> sge.Expression:
655+
return sge.func("STARTS_WITH", expr.expr, sge.convert(pat))
656+
657+
conditions = [to_startswith(pat) for pat in op.pat]
658+
return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
659+
660+
636661
@UNARY_OP_REGISTRATION.register(ops.StrStripOp)
637662
def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression:
638663
return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr)
@@ -656,6 +681,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
656681
)
657682

658683

684+
@UNARY_OP_REGISTRATION.register(ops.StringSplitOp)
685+
def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression:
686+
return sge.Split(this=expr.expr, expression=sge.convert(op.pat))
687+
688+
659689
@UNARY_OP_REGISTRATION.register(ops.StrGetOp)
660690
def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression:
661691
return sge.Substring(
@@ -808,3 +838,31 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
808838
@UNARY_OP_REGISTRATION.register(ops.year_op)
809839
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
810840
return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr)
841+
842+
843+
@UNARY_OP_REGISTRATION.register(ops.ZfillOp)
844+
def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
845+
return sge.Case(
846+
ifs=[
847+
sge.If(
848+
this=sge.EQ(
849+
this=sge.Substring(
850+
this=expr.expr, start=sge.convert(1), length=sge.convert(1)
851+
),
852+
expression=sge.convert("-"),
853+
),
854+
true=sge.Concat(
855+
expressions=[
856+
sge.convert("-"),
857+
sge.func(
858+
"LPAD",
859+
sge.Substring(this=expr.expr, start=sge.convert(1)),
860+
sge.convert(op.width - 1),
861+
sge.convert("0"),
862+
),
863+
]
864+
),
865+
)
866+
],
867+
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
868+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FALSE AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ENDS_WITH(`bfcol_0`, 'ab') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
STARTS_WITH(`bfcol_0`, 'ab') OR STARTS_WITH(`bfcol_0`, 'cd') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FALSE AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
STARTS_WITH(`bfcol_0`, 'ab') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
SPLIT(`bfcol_0`, ',') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN SUBSTRING(`bfcol_0`, 1, 1) = '-'
10+
THEN CONCAT('-', LPAD(SUBSTRING(`bfcol_0`, 1), 9, '0'))
11+
ELSE LPAD(`bfcol_0`, 10, '0')
12+
END AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `string_col`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
125125
snapshot.assert_match(sql, "out.sql")
126126

127127

128+
def test_endswith(scalar_types_df: bpd.DataFrame, snapshot):
129+
bf_df = scalar_types_df[["string_col"]]
130+
sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab",)), "string_col")
131+
snapshot.assert_match(sql, "single_pattern.sql")
132+
133+
sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab", "cd")), "string_col")
134+
snapshot.assert_match(sql, "multiple_patterns.sql")
135+
136+
sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=()), "string_col")
137+
snapshot.assert_match(sql, "no_pattern.sql")
138+
139+
128140
def test_exp(scalar_types_df: bpd.DataFrame, snapshot):
129141
bf_df = scalar_types_df[["float64_col"]]
130142
sql = _apply_unary_op(bf_df, ops.exp_op, "float64_col")
@@ -501,6 +513,18 @@ def test_sqrt(scalar_types_df: bpd.DataFrame, snapshot):
501513
snapshot.assert_match(sql, "out.sql")
502514

503515

516+
def test_startswith(scalar_types_df: bpd.DataFrame, snapshot):
517+
bf_df = scalar_types_df[["string_col"]]
518+
sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab",)), "string_col")
519+
snapshot.assert_match(sql, "single_pattern.sql")
520+
521+
sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab", "cd")), "string_col")
522+
snapshot.assert_match(sql, "multiple_patterns.sql")
523+
524+
sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=()), "string_col")
525+
snapshot.assert_match(sql, "no_pattern.sql")
526+
527+
504528
def test_str_get(scalar_types_df: bpd.DataFrame, snapshot):
505529
bf_df = scalar_types_df[["string_col"]]
506530
sql = _apply_unary_op(bf_df, ops.StrGetOp(1), "string_col")
@@ -650,6 +674,12 @@ def test_sinh(scalar_types_df: bpd.DataFrame, snapshot):
650674
snapshot.assert_match(sql, "out.sql")
651675

652676

677+
def test_string_split(scalar_types_df: bpd.DataFrame, snapshot):
678+
bf_df = scalar_types_df[["string_col"]]
679+
sql = _apply_unary_op(bf_df, ops.StringSplitOp(pat=","), "string_col")
680+
snapshot.assert_match(sql, "out.sql")
681+
682+
653683
def test_tan(scalar_types_df: bpd.DataFrame, snapshot):
654684
bf_df = scalar_types_df[["float64_col"]]
655685
sql = _apply_unary_op(bf_df, ops.tan_op, "float64_col")
@@ -790,3 +820,9 @@ def test_year(scalar_types_df: bpd.DataFrame, snapshot):
790820
sql = _apply_unary_op(bf_df, ops.year_op, "timestamp_col")
791821

792822
snapshot.assert_match(sql, "out.sql")
823+
824+
825+
def test_zfill(scalar_types_df: bpd.DataFrame, snapshot):
826+
bf_df = scalar_types_df[["string_col"]]
827+
sql = _apply_unary_op(bf_df, ops.ZfillOp(width=10), "string_col")
828+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)