Skip to content

Commit d38e42c

Browse files
fix: add warnings for duplicated or conflicting type hints in bigfram… (#1956)
* fix: add warnings for duplicated or conflicting type hints in bigframes function * only warn on conflict * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 4ea0e90 commit d38e42c

File tree

5 files changed

+131
-25
lines changed

5 files changed

+131
-25
lines changed

bigframes/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
103103
"""Remote Function and Managed UDF with axis=1 preview."""
104104

105105

106+
class FunctionConflictTypeHintWarning(UserWarning):
    """Conflicting type hints in a BigFrames function.

    Used as the warning category when the type hints in a function's
    signature conflict with the ``input_types`` or ``output_type`` passed to
    the decorator; in that case the types from the decorator are used.
    """
108+
109+
106110
class FunctionPackageVersionWarning(PreviewWarning):
107111
"""
108112
Managed UDF package versions for Numpy, Pandas, and Pyarrow may not

bigframes/functions/_function_session.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,13 +536,23 @@ def wrapper(func):
536536
if input_types is not None:
537537
if not isinstance(input_types, collections.abc.Sequence):
538538
input_types = [input_types]
539+
if _utils.has_conflict_input_type(py_sig, input_types):
540+
msg = bfe.format_message(
541+
"Conflicting input types detected, using the one from the decorator."
542+
)
543+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
539544
py_sig = py_sig.replace(
540545
parameters=[
541546
par.replace(annotation=itype)
542547
for par, itype in zip(py_sig.parameters.values(), input_types)
543548
]
544549
)
545550
if output_type:
551+
if _utils.has_conflict_output_type(py_sig, output_type):
552+
msg = bfe.format_message(
553+
"Conflicting return type detected, using the one from the decorator."
554+
)
555+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
546556
py_sig = py_sig.replace(return_annotation=output_type)
547557

548558
# Try to get input types via type annotations.
@@ -838,13 +848,23 @@ def wrapper(func):
838848
if input_types is not None:
839849
if not isinstance(input_types, collections.abc.Sequence):
840850
input_types = [input_types]
851+
if _utils.has_conflict_input_type(py_sig, input_types):
852+
msg = bfe.format_message(
853+
"Conflicting input types detected, using the one from the decorator."
854+
)
855+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
841856
py_sig = py_sig.replace(
842857
parameters=[
843858
par.replace(annotation=itype)
844859
for par, itype in zip(py_sig.parameters.values(), input_types)
845860
]
846861
)
847862
if output_type:
863+
if _utils.has_conflict_output_type(py_sig, output_type):
864+
msg = bfe.format_message(
865+
"Conflicting return type detected, using the one from the decorator."
866+
)
867+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
848868
py_sig = py_sig.replace(return_annotation=output_type)
849869

850870
# The function will actually be receiving a pandas Series, but allow

bigframes/functions/_utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414

1515

1616
import hashlib
17+
import inspect
1718
import json
1819
import sys
1920
import typing
20-
from typing import cast, Optional, Set
21+
from typing import Any, cast, Optional, Sequence, Set
2122
import warnings
2223

2324
import cloudpickle
@@ -290,3 +291,36 @@ def post_process(input):
290291
return bbq.json_extract_string_array(input, value_dtype=result_dtype)
291292

292293
return post_process
294+
295+
296+
def has_conflict_input_type(
    signature: inspect.Signature,
    input_types: Sequence[Any],
) -> bool:
    """Check whether the signature's parameter hints conflict with input_types.

    Args:
        signature: Signature of the function whose parameters are checked.
        input_types: Explicitly supplied types, matched to the parameters
            by position.

    Returns:
        True if the number of parameters differs from the number of
        ``input_types``, or if any annotated parameter has an annotation
        that is not equal to the entry of ``input_types`` at the same
        position. Parameters without an annotation never conflict.
    """
    params = list(signature.parameters.values())

    # A count mismatch is reported as a conflict without looking at the
    # individual annotations.
    if len(params) != len(input_types):
        return True

    # Only annotated parameters can conflict; unannotated ones are skipped.
    return any(
        param.annotation is not inspect.Parameter.empty
        and param.annotation != input_type
        for param, input_type in zip(params, input_types)
    )
314+
315+
316+
def has_conflict_output_type(
    signature: inspect.Signature,
    output_type: Any,
) -> bool:
    """Check whether the return type annotation conflicts with output_type.

    Args:
        signature: Signature of the function whose return annotation is
            checked.
        output_type: Explicitly supplied output type to compare against.

    Returns:
        False if the signature has no return annotation; otherwise the
        result of comparing the return annotation with ``output_type`` for
        inequality.
    """
    return_annotation = signature.return_annotation

    # `Signature.empty` is the documented marker for a missing return
    # annotation; without a hint there is nothing to conflict with.
    if return_annotation is inspect.Signature.empty:
        return False

    return return_annotation != output_type

tests/system/large/functions/test_managed_function.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
1517
import google.api_core.exceptions
1618
import pandas
1719
import pyarrow
@@ -31,12 +33,22 @@
3133
def test_managed_function_array_output(session, scalars_dfs, dataset_id):
3234
try:
3335

34-
@session.udf(
35-
dataset=dataset_id,
36-
name=prefixer.create_prefix(),
36+
with warnings.catch_warnings(record=True) as record:
37+
38+
@session.udf(
39+
dataset=dataset_id,
40+
name=prefixer.create_prefix(),
41+
)
42+
def featurize(x: int) -> list[float]:
43+
return [float(i) for i in [x, x + 1, x + 2]]
44+
45+
# No conflict warnings are expected when there are no redundant type hints.
46+
input_type_warning = "Conflicting input types detected"
47+
return_type_warning = "Conflicting return type detected"
48+
assert not any(input_type_warning in str(warning.message) for warning in record)
49+
assert not any(
50+
return_type_warning in str(warning.message) for warning in record
3751
)
38-
def featurize(x: int) -> list[float]:
39-
return [float(i) for i in [x, x + 1, x + 2]]
4052

4153
scalars_df, scalars_pandas_df = scalars_dfs
4254

@@ -222,7 +234,10 @@ def add(x: int, y: int) -> int:
222234
def test_managed_function_series_combine_array_output(session, dataset_id, scalars_dfs):
223235
try:
224236

225-
def add_list(x: int, y: int) -> list[int]:
237+
# The type hints in this function's signature have conflicts. The
238+
# `input_types` and `output_type` arguments from udf decorator take
239+
# precedence and will be used instead.
240+
def add_list(x, y: bool) -> list[bool]:
226241
return [x, y]
227242

228243
scalars_df, scalars_pandas_df = scalars_dfs
@@ -234,9 +249,18 @@ def add_list(x: int, y: int) -> list[int]:
234249
# Make sure there are NA values in the test column.
235250
assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]])
236251

237-
add_list_managed_func = session.udf(
238-
dataset=dataset_id, name=prefixer.create_prefix()
239-
)(add_list)
252+
with warnings.catch_warnings(record=True) as record:
253+
add_list_managed_func = session.udf(
254+
input_types=[int, int],
255+
output_type=list[int],
256+
dataset=dataset_id,
257+
name=prefixer.create_prefix(),
258+
)(add_list)
259+
260+
input_type_warning = "Conflicting input types detected"
261+
assert any(input_type_warning in str(warning.message) for warning in record)
262+
return_type_warning = "Conflicting return type detected"
263+
assert any(return_type_warning in str(warning.message) for warning in record)
240264

241265
# After filtering out nulls the managed function application should work
242266
# similar to pandas.

tests/system/large/functions/test_remote_function.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -843,22 +843,31 @@ def test_remote_function_with_external_package_dependencies(
843843
):
844844
try:
845845

846-
def pd_np_foo(x):
846+
# The return type hint in this function's signature has a conflict. The
847+
# `output_type` argument from remote_function decorator takes precedence
848+
# and will be used instead.
849+
def pd_np_foo(x) -> None:
847850
import numpy as mynp
848851
import pandas as mypd
849852

850853
return mypd.Series([x, mynp.sqrt(mynp.abs(x))]).sum()
851854

852-
# Create the remote function with the name provided explicitly
853-
pd_np_foo_remote = session.remote_function(
854-
input_types=[int],
855-
output_type=float,
856-
dataset=dataset_id,
857-
bigquery_connection=bq_cf_connection,
858-
reuse=False,
859-
packages=["numpy", "pandas >= 2.0.0"],
860-
cloud_function_service_account="default",
861-
)(pd_np_foo)
855+
with warnings.catch_warnings(record=True) as record:
856+
# Create the remote function with the name provided explicitly
857+
pd_np_foo_remote = session.remote_function(
858+
input_types=[int],
859+
output_type=float,
860+
dataset=dataset_id,
861+
bigquery_connection=bq_cf_connection,
862+
reuse=False,
863+
packages=["numpy", "pandas >= 2.0.0"],
864+
cloud_function_service_account="default",
865+
)(pd_np_foo)
866+
867+
input_type_warning = "Conflicting input types detected"
868+
assert not any(input_type_warning in str(warning.message) for warning in record)
869+
return_type_warning = "Conflicting return type detected"
870+
assert any(return_type_warning in str(warning.message) for warning in record)
862871

863872
# The behavior of the created remote function should be as expected
864873
scalars_df, scalars_pandas_df = scalars_dfs
@@ -1999,10 +2008,25 @@ def test_remote_function_unnamed_removed_w_session_cleanup():
19992008
# create a clean session
20002009
session = bigframes.connect()
20012010

2002-
# create an unnamed remote function in the session
2003-
@session.remote_function(reuse=False, cloud_function_service_account="default")
2004-
def foo(x: int) -> int:
2005-
return x + 1
2011+
with warnings.catch_warnings(record=True) as record:
2012+
# create an unnamed remote function in the session.
2013+
# The type hints in this function's signature are redundant. The
2014+
# `input_types` and `output_type` arguments from remote_function
2015+
# decorator take precedence and will be used instead.
2016+
@session.remote_function(
2017+
input_types=[int],
2018+
output_type=int,
2019+
reuse=False,
2020+
cloud_function_service_account="default",
2021+
)
2022+
def foo(x: int) -> int:
2023+
return x + 1
2024+
2025+
# No conflict warnings are expected with only redundant type hints (no conflict).
2026+
input_type_warning = "Conflicting input types detected"
2027+
assert not any(input_type_warning in str(warning.message) for warning in record)
2028+
return_type_warning = "Conflicting return type detected"
2029+
assert not any(return_type_warning in str(warning.message) for warning in record)
20062030

20072031
# ensure that remote function artifacts are created
20082032
assert foo.bigframes_remote_function is not None

0 commit comments

Comments
 (0)