Skip to content

Commit b9d0122

Browse files
authored
Merge branch 'main' into udf-add-tests1
2 parents f3627b7 + 6bf06a7 commit b9d0122

File tree

19 files changed

+423
-21
lines changed

19 files changed

+423
-21
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import functools
1718
import typing
1819

1920
import pandas as pd
@@ -292,6 +293,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
292293
return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr)
293294

294295

296+
@UNARY_OP_REGISTRATION.register(ops.EndsWithOp)
297+
def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression:
298+
if not op.pat:
299+
return sge.false()
300+
301+
def to_endswith(pat: str) -> sge.Expression:
302+
return sge.func("ENDS_WITH", expr.expr, sge.convert(pat))
303+
304+
conditions = [to_endswith(pat) for pat in op.pat]
305+
return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
306+
307+
295308
@UNARY_OP_REGISTRATION.register(ops.exp_op)
296309
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
297310
return sge.Case(
@@ -344,6 +357,27 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
344357
return sge.func("ST_BOUNDARY", expr.expr)
345358

346359

360+
@UNARY_OP_REGISTRATION.register(ops.GeoStBufferOp)
361+
def _(op: ops.GeoStBufferOp, expr: TypedExpr) -> sge.Expression:
362+
return sge.func(
363+
"ST_BUFFER",
364+
expr.expr,
365+
sge.convert(op.buffer_radius),
366+
sge.convert(op.num_seg_quarter_circle),
367+
sge.convert(op.use_spheroid),
368+
)
369+
370+
371+
@UNARY_OP_REGISTRATION.register(ops.geo_st_centroid_op)
372+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
373+
return sge.func("ST_CENTROID", expr.expr)
374+
375+
376+
@UNARY_OP_REGISTRATION.register(ops.geo_st_convexhull_op)
377+
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
378+
return sge.func("ST_CONVEXHULL", expr.expr)
379+
380+
347381
@UNARY_OP_REGISTRATION.register(ops.geo_st_geogfromtext_op)
348382
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
349383
return sge.func("SAFE.ST_GEOGFROMTEXT", expr.expr)
@@ -516,6 +550,17 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
516550
return sge.Lower(this=expr.expr)
517551

518552

553+
@UNARY_OP_REGISTRATION.register(ops.MapOp)
554+
def _(op: ops.MapOp, expr: TypedExpr) -> sge.Expression:
555+
return sge.Case(
556+
this=expr.expr,
557+
ifs=[
558+
sge.If(this=sge.convert(key), true=sge.convert(value))
559+
for key, value in op.mappings
560+
],
561+
)
562+
563+
519564
@UNARY_OP_REGISTRATION.register(ops.minute_op)
520565
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
521566
return sge.Extract(this=sge.Identifier(this="MINUTE"), expression=expr.expr)
@@ -601,6 +646,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
601646
)
602647

603648

649+
@UNARY_OP_REGISTRATION.register(ops.StartsWithOp)
650+
def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression:
651+
if not op.pat:
652+
return sge.false()
653+
654+
def to_startswith(pat: str) -> sge.Expression:
655+
return sge.func("STARTS_WITH", expr.expr, sge.convert(pat))
656+
657+
conditions = [to_startswith(pat) for pat in op.pat]
658+
return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
659+
660+
604661
@UNARY_OP_REGISTRATION.register(ops.StrStripOp)
605662
def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression:
606663
return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr)
@@ -624,6 +681,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
624681
)
625682

626683

684+
@UNARY_OP_REGISTRATION.register(ops.StringSplitOp)
685+
def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression:
686+
return sge.Split(this=expr.expr, expression=sge.convert(op.pat))
687+
688+
627689
@UNARY_OP_REGISTRATION.register(ops.StrGetOp)
628690
def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression:
629691
return sge.Substring(
@@ -776,3 +838,31 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
776838
@UNARY_OP_REGISTRATION.register(ops.year_op)
777839
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
778840
return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr)
841+
842+
843+
@UNARY_OP_REGISTRATION.register(ops.ZfillOp)
844+
def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
845+
return sge.Case(
846+
ifs=[
847+
sge.If(
848+
this=sge.EQ(
849+
this=sge.Substring(
850+
this=expr.expr, start=sge.convert(1), length=sge.convert(1)
851+
),
852+
expression=sge.convert("-"),
853+
),
854+
true=sge.Concat(
855+
expressions=[
856+
sge.convert("-"),
857+
sge.func(
858+
"LPAD",
859+
sge.Substring(this=expr.expr, start=sge.convert(1)),
860+
sge.convert(op.width - 1),
861+
sge.convert("0"),
862+
),
863+
]
864+
),
865+
)
866+
],
867+
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
868+
)

bigframes/dataframe.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,6 +2828,19 @@ def itertuples(
28282828
for item in df.itertuples(index=index, name=name):
28292829
yield item
28302830

2831+
def _apply_callable(self, condition):
2832+
"""Executes the possible callable condition as needed."""
2833+
if callable(condition):
2834+
# When it's a bigframes function.
2835+
if hasattr(condition, "bigframes_bigquery_function"):
2836+
return self.apply(condition, axis=1)
2837+
2838+
# When it's a plain Python function.
2839+
return condition(self)
2840+
2841+
# When it's not a callable.
2842+
return condition
2843+
28312844
def where(self, cond, other=None):
28322845
if isinstance(other, bigframes.series.Series):
28332846
raise ValueError("Seires is not a supported replacement type!")
@@ -2839,16 +2852,8 @@ def where(self, cond, other=None):
28392852

28402853
# Execute it with the DataFrame when cond or/and other is callable.
28412854
# It can be either a plain python function or remote/managed function.
2842-
if callable(cond):
2843-
if hasattr(cond, "bigframes_bigquery_function"):
2844-
cond = self.apply(cond, axis=1)
2845-
else:
2846-
cond = cond(self)
2847-
if callable(other):
2848-
if hasattr(other, "bigframes_bigquery_function"):
2849-
other = self.apply(other, axis=1)
2850-
else:
2851-
other = other(self)
2855+
cond = self._apply_callable(cond)
2856+
other = self._apply_callable(other)
28522857

28532858
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28542859
# No left join is needed when 'other' is None or constant.
@@ -2899,7 +2904,7 @@ def where(self, cond, other=None):
28992904
return result
29002905

29012906
def mask(self, cond, other=None):
2902-
return self.where(~cond, other=other)
2907+
return self.where(~self._apply_callable(cond), other=other)
29032908

29042909
def dropna(
29052910
self,

tests/system/large/functions/test_managed_function.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def float_parser(row):
965965
)
966966

967967

968-
def test_managed_function_df_where(session, dataset_id, scalars_dfs):
968+
def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
969969
try:
970970

971971
# The return type has to be bool type for callable where condition.
@@ -987,15 +987,15 @@ def is_sum_positive(a, b):
987987
pd_int64_df = scalars_pandas_df[int64_cols]
988988
pd_int64_df_filtered = pd_int64_df.dropna()
989989

990-
# Use callable condition in dataframe.where method.
990+
# Test callable condition in dataframe.where method.
991991
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
992992
# Pandas doesn't support such case, use following as workaround.
993993
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)
994994

995995
# Ignore any dtype difference.
996996
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
997997

998-
# Make sure the read_gbq_function path works for this function.
998+
# Make sure the read_gbq_function path works for dataframe.where method.
999999
is_sum_positive_ref = session.read_gbq_function(
10001000
function_name=is_sum_positive_mf.bigframes_bigquery_function
10011001
)
@@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
10121012
bf_result_gbq, pd_result_gbq, check_dtype=False
10131013
)
10141014

1015+
# Test callable condition in dataframe.mask method.
1016+
bf_result_gbq = bf_int64_df_filtered.mask(
1017+
is_sum_positive_ref, -bf_int64_df_filtered
1018+
).to_pandas()
1019+
pd_result_gbq = pd_int64_df_filtered.mask(
1020+
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
1021+
)
1022+
1023+
# Ignore any dtype difference.
1024+
pandas.testing.assert_frame_equal(
1025+
bf_result_gbq, pd_result_gbq, check_dtype=False
1026+
)
1027+
10151028
finally:
10161029
# Clean up the gcp assets created for the managed function.
10171030
cleanup_function_assets(
10181031
is_sum_positive_mf, session.bqclient, ignore_failures=False
10191032
)
10201033

10211034

1022-
def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
1035+
def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs):
10231036
try:
10241037

10251038
# The return type has to be bool type for callable where condition.
@@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
10411054
pd_int64_df = scalars_pandas_df[int64_cols]
10421055
pd_int64_df_filtered = pd_int64_df.dropna()
10431056

1044-
# Use callable condition in dataframe.where method.
1057+
# Test callable condition in dataframe.where method.
10451058
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
10461059
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)
10471060

10481061
# Ignore any dtype difference.
10491062
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
10501063

1051-
# Make sure the read_gbq_function path works for this function.
1064+
# Make sure the read_gbq_function path works for dataframe.where method.
10521065
is_sum_positive_series_ref = session.read_gbq_function(
10531066
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
10541067
is_row_processor=True,
@@ -1070,6 +1083,19 @@ def func_for_other(x):
10701083
bf_result_gbq, pd_result_gbq, check_dtype=False
10711084
)
10721085

1086+
# Test callable condition in dataframe.mask method.
1087+
bf_result_gbq = bf_int64_df_filtered.mask(
1088+
is_sum_positive_series_ref, func_for_other
1089+
).to_pandas()
1090+
pd_result_gbq = pd_int64_df_filtered.mask(
1091+
is_sum_positive_series, func_for_other
1092+
)
1093+
1094+
# Ignore any dtype difference.
1095+
pandas.testing.assert_frame_equal(
1096+
bf_result_gbq, pd_result_gbq, check_dtype=False
1097+
)
1098+
10731099
finally:
10741100
# Clean up the gcp assets created for the managed function.
10751101
cleanup_function_assets(

tests/system/large/functions/test_remote_function.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,7 +2850,7 @@ def foo(x: int) -> int:
28502850

28512851

28522852
@pytest.mark.flaky(retries=2, delay=120)
2853-
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
2853+
def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs):
28542854
try:
28552855

28562856
# The return type has to be bool type for callable where condition.
@@ -2873,14 +2873,22 @@ def is_sum_positive(a, b):
28732873
pd_int64_df = scalars_pandas_df[int64_cols]
28742874
pd_int64_df_filtered = pd_int64_df.dropna()
28752875

2876-
# Use callable condition in dataframe.where method.
2876+
# Test callable condition in dataframe.where method.
28772877
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
28782878
# Pandas doesn't support such case, use following as workaround.
28792879
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)
28802880

28812881
# Ignore any dtype difference.
28822882
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
28832883

2884+
# Test callable condition in dataframe.mask method.
2885+
bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas()
2886+
# Pandas doesn't support such case, use following as workaround.
2887+
pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0)
2888+
2889+
# Ignore any dtype difference.
2890+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2891+
28842892
finally:
28852893
# Clean up the gcp assets created for the remote function.
28862894
cleanup_function_assets(
@@ -2889,7 +2897,7 @@ def is_sum_positive(a, b):
28892897

28902898

28912899
@pytest.mark.flaky(retries=2, delay=120)
2892-
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
2900+
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
28932901
try:
28942902

28952903
# The return type has to be bool type for callable where condition.
@@ -2916,7 +2924,7 @@ def is_sum_positive_series(s):
29162924
def func_for_other(x):
29172925
return -x
29182926

2919-
# Use callable condition in dataframe.where method.
2927+
# Test callable condition in dataframe.where method.
29202928
bf_result = bf_int64_df_filtered.where(
29212929
is_sum_positive_series, func_for_other
29222930
).to_pandas()
@@ -2925,6 +2933,15 @@ def func_for_other(x):
29252933
# Ignore any dtype difference.
29262934
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
29272935

2936+
# Test callable condition in dataframe.mask method.
2937+
bf_result = bf_int64_df_filtered.mask(
2938+
is_sum_positive_series_mf, func_for_other
2939+
).to_pandas()
2940+
pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other)
2941+
2942+
# Ignore any dtype difference.
2943+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2944+
29282945
finally:
29292946
# Clean up the gcp assets created for the remote function.
29302947
cleanup_function_assets(

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
406406
pandas.testing.assert_frame_equal(bf_result, pd_result)
407407

408408

409+
def test_mask_callable(scalars_df_index, scalars_pandas_df_index):
410+
def is_positive(x):
411+
return x > 0
412+
413+
bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
414+
pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]]
415+
bf_result = bf_df.mask(cond=is_positive, other=lambda x: x + 1).to_pandas()
416+
pd_result = pd_df.mask(cond=is_positive, other=lambda x: x + 1)
417+
418+
pandas.testing.assert_frame_equal(bf_result, pd_result)
419+
420+
409421
def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
410422
# Test when a dataframe has multi-columns.
411423
columns = ["int64_col", "float64_col"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FALSE AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `string_col`
13+
FROM `bfcte_1`

0 commit comments

Comments
 (0)