Skip to content

Commit c0b54f0

Browse files
feat: Support string matching in local executor (#2032)
1 parent ba0d23b commit c0b54f0

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,34 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
301301
assert isinstance(op, string_ops.StrConcatOp)
302302
return pl.concat_str(l_input, r_input)
303303

304+
@compile_op.register(string_ops.StrContainsOp)
305+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
306+
assert isinstance(op, string_ops.StrContainsOp)
307+
return input.str.contains(pattern=op.pat, literal=True)
308+
309+
@compile_op.register(string_ops.StrContainsRegexOp)
310+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
311+
assert isinstance(op, string_ops.StrContainsRegexOp)
312+
return input.str.contains(pattern=op.pat, literal=False)
313+
314+
@compile_op.register(string_ops.StartsWithOp)
315+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
316+
assert isinstance(op, string_ops.StartsWithOp)
317+
if len(op.pat) == 1:
318+
return input.str.starts_with(op.pat[0])
319+
else:
320+
return pl.any_horizontal(
321+
*(input.str.starts_with(pat) for pat in op.pat)
322+
)
323+
324+
@compile_op.register(string_ops.EndsWithOp)
325+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
326+
assert isinstance(op, string_ops.EndsWithOp)
327+
if len(op.pat) == 1:
328+
return input.str.ends_with(op.pat[0])
329+
else:
330+
return pl.any_horizontal(*(input.str.ends_with(pat) for pat in op.pat))
331+
304332
@compile_op.register(dt_ops.StrftimeOp)
305333
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
306334
assert isinstance(op, dt_ops.StrftimeOp)

bigframes/session/polars_executor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
2222
import bigframes.operations
2323
from bigframes.operations import aggregations as agg_ops
24-
from bigframes.operations import bool_ops, comparison_ops, generic_ops, numeric_ops
24+
from bigframes.operations import (
25+
bool_ops,
26+
comparison_ops,
27+
generic_ops,
28+
numeric_ops,
29+
string_ops,
30+
)
2531
from bigframes.session import executor, semi_executor
2632

2733
if TYPE_CHECKING:
@@ -69,6 +75,10 @@
6975
generic_ops.IsInOp,
7076
generic_ops.IsNullOp,
7177
generic_ops.NotNullOp,
78+
string_ops.StartsWithOp,
79+
string_ops.EndsWithOp,
80+
string_ops.StrContainsOp,
81+
string_ops.StrContainsRegexOp,
7282
)
7383
_COMPATIBLE_AGG_OPS = (
7484
agg_ops.SizeOp,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from bigframes.core import array_value
18+
import bigframes.operations as ops
19+
from bigframes.session import polars_executor
20+
from bigframes.testing.engine_utils import assert_equivalence_execution
21+
22+
pytest.importorskip("polars")
23+
24+
# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree.
25+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
26+
27+
28+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
29+
def test_engines_str_contains(scalars_array_value: array_value.ArrayValue, engine):
30+
arr, _ = scalars_array_value.compute_values(
31+
[
32+
ops.StrContainsOp("(?i)hEllo").as_expr("string_col"),
33+
ops.StrContainsOp("Hello").as_expr("string_col"),
34+
ops.StrContainsOp("T").as_expr("string_col"),
35+
ops.StrContainsOp(".*").as_expr("string_col"),
36+
]
37+
)
38+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
39+
40+
41+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
42+
def test_engines_str_contains_regex(
43+
scalars_array_value: array_value.ArrayValue, engine
44+
):
45+
arr, _ = scalars_array_value.compute_values(
46+
[
47+
ops.StrContainsRegexOp("(?i)hEllo").as_expr("string_col"),
48+
ops.StrContainsRegexOp("Hello").as_expr("string_col"),
49+
ops.StrContainsRegexOp("T").as_expr("string_col"),
50+
ops.StrContainsRegexOp(".*").as_expr("string_col"),
51+
]
52+
)
53+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
54+
55+
56+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
57+
def test_engines_str_startswith(scalars_array_value: array_value.ArrayValue, engine):
58+
arr, _ = scalars_array_value.compute_values(
59+
[
60+
ops.StartsWithOp("He").as_expr("string_col"),
61+
ops.StartsWithOp("llo").as_expr("string_col"),
62+
ops.StartsWithOp(("He", "T", "ca")).as_expr("string_col"),
63+
]
64+
)
65+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
66+
67+
68+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
69+
def test_engines_str_endswith(scalars_array_value: array_value.ArrayValue, engine):
70+
arr, _ = scalars_array_value.compute_values(
71+
[
72+
ops.EndsWithOp("!").as_expr("string_col"),
73+
ops.EndsWithOp("llo").as_expr("string_col"),
74+
ops.EndsWithOp(("He", "T", "ca")).as_expr("string_col"),
75+
]
76+
)
77+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)