Skip to content

Commit 9af7130

Browse files
authored
refactor: Consolidate all _utils unit tests of bigframes function (#2006)
* refactor: Consolidate all _utils unit tests of bigframes function * fix assert format
1 parent 7d89d76 commit 9af7130

File tree

2 files changed

+172
-175
lines changed

2 files changed

+172
-175
lines changed

tests/unit/functions/test_remote_function_utils.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,180 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
16+
from unittest.mock import patch
17+
1518
import bigframes_vendored.constants as constants
1619
import pytest
1720

1821
from bigframes.functions import _utils, function_typing
1922

2023

24+
def test_get_updated_package_requirements_no_extra_package():
25+
"""Tests with no extra package."""
26+
result = _utils.get_updated_package_requirements(capture_references=False)
27+
28+
assert result is None
29+
30+
initial_packages = ["xgboost"]
31+
result = _utils.get_updated_package_requirements(
32+
initial_packages, capture_references=False
33+
)
34+
35+
assert result == initial_packages
36+
37+
38+
@patch("bigframes.functions._utils.numpy.__version__", "1.24.4")
39+
@patch("bigframes.functions._utils.pyarrow.__version__", "14.0.1")
40+
@patch("bigframes.functions._utils.pandas.__version__", "2.0.3")
41+
@patch("bigframes.functions._utils.cloudpickle.__version__", "2.2.1")
42+
def test_get_updated_package_requirements_is_row_processor_with_versions():
43+
"""Tests with is_row_processor=True and specific versions."""
44+
expected = [
45+
"cloudpickle==2.2.1",
46+
"numpy==1.24.4",
47+
"pandas==2.0.3",
48+
"pyarrow==14.0.1",
49+
]
50+
result = _utils.get_updated_package_requirements(is_row_processor=True)
51+
52+
assert result == expected
53+
54+
55+
@patch("bigframes.functions._utils.warnings.warn")
56+
@patch("bigframes.functions._utils.cloudpickle.__version__", "2.2.1")
57+
def test_get_updated_package_requirements_ignore_version(mock_warn):
58+
"""
59+
Tests with is_row_processor=True and ignore_package_version=True.
60+
Should add packages without versions and raise a warning.
61+
"""
62+
expected = ["cloudpickle==2.2.1", "numpy", "pandas", "pyarrow"]
63+
result = _utils.get_updated_package_requirements(
64+
is_row_processor=True, ignore_package_version=True
65+
)
66+
67+
assert result == expected
68+
# Verify that a warning was issued.
69+
mock_warn.assert_called_once()
70+
71+
72+
@patch("bigframes.functions._utils.numpy.__version__", "1.24.4")
73+
@patch("bigframes.functions._utils.pyarrow.__version__", "14.0.1")
74+
@patch("bigframes.functions._utils.pandas.__version__", "2.0.3")
75+
def test_get_updated_package_requirements_capture_references_false():
76+
"""
77+
Tests with capture_references=False.
78+
Should not add cloudpickle but should add others if requested.
79+
"""
80+
# Case 1: Only capture_references=False.
81+
result_1 = _utils.get_updated_package_requirements(capture_references=False)
82+
83+
assert result_1 is None
84+
85+
# Case 2: capture_references=False but is_row_processor=True.
86+
expected_2 = ["numpy==1.24.4", "pandas==2.0.3", "pyarrow==14.0.1"]
87+
result_2 = _utils.get_updated_package_requirements(
88+
is_row_processor=True, capture_references=False
89+
)
90+
91+
assert result_2 == expected_2
92+
93+
94+
@patch("bigframes.functions._utils.numpy.__version__", "1.24.4")
95+
@patch("bigframes.functions._utils.pyarrow.__version__", "14.0.1")
96+
@patch("bigframes.functions._utils.pandas.__version__", "2.0.3")
97+
@patch("bigframes.functions._utils.cloudpickle.__version__", "2.2.1")
98+
def test_get_updated_package_requirements_non_overlapping_packages():
99+
"""Tests providing an initial list of packages that do not overlap."""
100+
initial_packages = ["scikit-learn==1.3.0", "xgboost"]
101+
expected = [
102+
"cloudpickle==2.2.1",
103+
"numpy==1.24.4",
104+
"pandas==2.0.3",
105+
"pyarrow==14.0.1",
106+
"scikit-learn==1.3.0",
107+
"xgboost",
108+
]
109+
result = _utils.get_updated_package_requirements(
110+
package_requirements=initial_packages, is_row_processor=True
111+
)
112+
113+
assert result == expected
114+
115+
116+
@patch("bigframes.functions._utils.numpy.__version__", "1.24.4")
117+
@patch("bigframes.functions._utils.pyarrow.__version__", "14.0.1")
118+
@patch("bigframes.functions._utils.pandas.__version__", "2.0.3")
119+
@patch("bigframes.functions._utils.cloudpickle.__version__", "2.2.1")
120+
def test_get_updated_package_requirements_overlapping_packages():
121+
"""Tests that packages are not added if they already exist."""
122+
# The function should respect the pre-existing pandas version.
123+
initial_packages = ["pandas==1.5.3", "numpy"]
124+
expected = [
125+
"cloudpickle==2.2.1",
126+
"numpy",
127+
"pandas==1.5.3",
128+
"pyarrow==14.0.1",
129+
]
130+
result = _utils.get_updated_package_requirements(
131+
package_requirements=initial_packages, is_row_processor=True
132+
)
133+
134+
assert result == expected
135+
136+
137+
@patch("bigframes.functions._utils.cloudpickle.__version__", "2.2.1")
138+
def test_get_updated_package_requirements_with_existing_cloudpickle():
139+
"""Tests that cloudpickle is not added if it already exists."""
140+
initial_packages = ["cloudpickle==2.0.0"]
141+
expected = ["cloudpickle==2.0.0"]
142+
result = _utils.get_updated_package_requirements(
143+
package_requirements=initial_packages
144+
)
145+
146+
assert result == expected
147+
148+
149+
def test_package_existed_helper():
150+
"""Tests the _package_existed helper function directly."""
151+
reqs = ["pandas==1.0", "numpy", "scikit-learn>=1.2.0"]
152+
153+
# Exact match
154+
assert _utils._package_existed(reqs, "pandas==1.0")
155+
# Different version
156+
assert _utils._package_existed(reqs, "pandas==2.0")
157+
# No version specified
158+
assert _utils._package_existed(reqs, "numpy")
159+
# Not in list
160+
assert not _utils._package_existed(reqs, "xgboost")
161+
# Empty list
162+
assert not _utils._package_existed([], "pandas")
163+
164+
165+
def test_has_conflict_output_type_no_conflict():
166+
"""Tests has_conflict_output_type with type annotation."""
167+
# Helper functions with type annotation for has_conflict_output_type.
168+
def _func_with_return_type(x: int) -> int:
169+
return x
170+
171+
signature = inspect.signature(_func_with_return_type)
172+
173+
assert _utils.has_conflict_output_type(signature, output_type=float)
174+
assert not _utils.has_conflict_output_type(signature, output_type=int)
175+
176+
177+
def test_has_conflict_output_type_no_annotation():
178+
"""Tests has_conflict_output_type without type annotation."""
179+
# Helper functions without type annotation for has_conflict_output_type.
180+
def _func_without_return_type(x):
181+
return x
182+
183+
signature = inspect.signature(_func_without_return_type)
184+
185+
assert not _utils.has_conflict_output_type(signature, output_type=int)
186+
assert not _utils.has_conflict_output_type(signature, output_type=float)
187+
188+
21189
@pytest.mark.parametrize(
22190
["metadata_options", "metadata_string"],
23191
(
@@ -54,6 +222,7 @@
54222
),
55223
)
56224
def test_get_bigframes_metadata(metadata_options, metadata_string):
225+
57226
assert _utils.get_bigframes_metadata(**metadata_options) == metadata_string
58227

59228

@@ -72,6 +241,7 @@ def test_get_bigframes_metadata(metadata_options, metadata_string):
72241
def test_get_bigframes_metadata_array_type_not_serializable(output_type):
73242
with pytest.raises(ValueError) as context:
74243
_utils.get_bigframes_metadata(python_output_type=output_type)
244+
75245
assert str(context.value) == (
76246
f"python_output_type {output_type} is not serializable. {constants.FEEDBACK_LINK}"
77247
)
@@ -125,6 +295,7 @@ def test_get_bigframes_metadata_array_type_not_serializable(output_type):
125295
def test_get_python_output_type_from_bigframes_metadata(
126296
metadata_string, python_output_type
127297
):
298+
128299
assert (
129300
_utils.get_python_output_type_from_bigframes_metadata(metadata_string)
130301
== python_output_type
@@ -135,4 +306,5 @@ def test_metadata_roundtrip_supported_array_types():
135306
for array_of in function_typing.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES:
136307
ser = _utils.get_bigframes_metadata(python_output_type=list[array_of]) # type: ignore
137308
deser = _utils.get_python_output_type_from_bigframes_metadata(ser)
309+
138310
assert deser == list[array_of] # type: ignore

tests/unit/functions/test_utils.py

Lines changed: 0 additions & 175 deletions
This file was deleted.

0 commit comments

Comments
 (0)