1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import inspect
16+ from unittest .mock import patch
17+
1518import bigframes_vendored .constants as constants
1619import pytest
1720
1821from bigframes .functions import _utils , function_typing
1922
2023
24+ def test_get_updated_package_requirements_no_extra_package ():
25+ """Tests with no extra package."""
26+ result = _utils .get_updated_package_requirements (capture_references = False )
27+
28+ assert result is None
29+
30+ initial_packages = ["xgboost" ]
31+ result = _utils .get_updated_package_requirements (
32+ initial_packages , capture_references = False
33+ )
34+
35+ assert result == initial_packages
36+
37+
38+ @patch ("bigframes.functions._utils.numpy.__version__" , "1.24.4" )
39+ @patch ("bigframes.functions._utils.pyarrow.__version__" , "14.0.1" )
40+ @patch ("bigframes.functions._utils.pandas.__version__" , "2.0.3" )
41+ @patch ("bigframes.functions._utils.cloudpickle.__version__" , "2.2.1" )
42+ def test_get_updated_package_requirements_is_row_processor_with_versions ():
43+ """Tests with is_row_processor=True and specific versions."""
44+ expected = [
45+ "cloudpickle==2.2.1" ,
46+ "numpy==1.24.4" ,
47+ "pandas==2.0.3" ,
48+ "pyarrow==14.0.1" ,
49+ ]
50+ result = _utils .get_updated_package_requirements (is_row_processor = True )
51+
52+ assert result == expected
53+
54+
55+ @patch ("bigframes.functions._utils.warnings.warn" )
56+ @patch ("bigframes.functions._utils.cloudpickle.__version__" , "2.2.1" )
57+ def test_get_updated_package_requirements_ignore_version (mock_warn ):
58+ """
59+ Tests with is_row_processor=True and ignore_package_version=True.
60+ Should add packages without versions and raise a warning.
61+ """
62+ expected = ["cloudpickle==2.2.1" , "numpy" , "pandas" , "pyarrow" ]
63+ result = _utils .get_updated_package_requirements (
64+ is_row_processor = True , ignore_package_version = True
65+ )
66+
67+ assert result == expected
68+ # Verify that a warning was issued.
69+ mock_warn .assert_called_once ()
70+
71+
72+ @patch ("bigframes.functions._utils.numpy.__version__" , "1.24.4" )
73+ @patch ("bigframes.functions._utils.pyarrow.__version__" , "14.0.1" )
74+ @patch ("bigframes.functions._utils.pandas.__version__" , "2.0.3" )
75+ def test_get_updated_package_requirements_capture_references_false ():
76+ """
77+ Tests with capture_references=False.
78+ Should not add cloudpickle but should add others if requested.
79+ """
80+ # Case 1: Only capture_references=False.
81+ result_1 = _utils .get_updated_package_requirements (capture_references = False )
82+
83+ assert result_1 is None
84+
85+ # Case 2: capture_references=False but is_row_processor=True.
86+ expected_2 = ["numpy==1.24.4" , "pandas==2.0.3" , "pyarrow==14.0.1" ]
87+ result_2 = _utils .get_updated_package_requirements (
88+ is_row_processor = True , capture_references = False
89+ )
90+
91+ assert result_2 == expected_2
92+
93+
94+ @patch ("bigframes.functions._utils.numpy.__version__" , "1.24.4" )
95+ @patch ("bigframes.functions._utils.pyarrow.__version__" , "14.0.1" )
96+ @patch ("bigframes.functions._utils.pandas.__version__" , "2.0.3" )
97+ @patch ("bigframes.functions._utils.cloudpickle.__version__" , "2.2.1" )
98+ def test_get_updated_package_requirements_non_overlapping_packages ():
99+ """Tests providing an initial list of packages that do not overlap."""
100+ initial_packages = ["scikit-learn==1.3.0" , "xgboost" ]
101+ expected = [
102+ "cloudpickle==2.2.1" ,
103+ "numpy==1.24.4" ,
104+ "pandas==2.0.3" ,
105+ "pyarrow==14.0.1" ,
106+ "scikit-learn==1.3.0" ,
107+ "xgboost" ,
108+ ]
109+ result = _utils .get_updated_package_requirements (
110+ package_requirements = initial_packages , is_row_processor = True
111+ )
112+
113+ assert result == expected
114+
115+
116+ @patch ("bigframes.functions._utils.numpy.__version__" , "1.24.4" )
117+ @patch ("bigframes.functions._utils.pyarrow.__version__" , "14.0.1" )
118+ @patch ("bigframes.functions._utils.pandas.__version__" , "2.0.3" )
119+ @patch ("bigframes.functions._utils.cloudpickle.__version__" , "2.2.1" )
120+ def test_get_updated_package_requirements_overlapping_packages ():
121+ """Tests that packages are not added if they already exist."""
122+ # The function should respect the pre-existing pandas version.
123+ initial_packages = ["pandas==1.5.3" , "numpy" ]
124+ expected = [
125+ "cloudpickle==2.2.1" ,
126+ "numpy" ,
127+ "pandas==1.5.3" ,
128+ "pyarrow==14.0.1" ,
129+ ]
130+ result = _utils .get_updated_package_requirements (
131+ package_requirements = initial_packages , is_row_processor = True
132+ )
133+
134+ assert result == expected
135+
136+
137+ @patch ("bigframes.functions._utils.cloudpickle.__version__" , "2.2.1" )
138+ def test_get_updated_package_requirements_with_existing_cloudpickle ():
139+ """Tests that cloudpickle is not added if it already exists."""
140+ initial_packages = ["cloudpickle==2.0.0" ]
141+ expected = ["cloudpickle==2.0.0" ]
142+ result = _utils .get_updated_package_requirements (
143+ package_requirements = initial_packages
144+ )
145+
146+ assert result == expected
147+
148+
149+ def test_package_existed_helper ():
150+ """Tests the _package_existed helper function directly."""
151+ reqs = ["pandas==1.0" , "numpy" , "scikit-learn>=1.2.0" ]
152+
153+ # Exact match
154+ assert _utils ._package_existed (reqs , "pandas==1.0" )
155+ # Different version
156+ assert _utils ._package_existed (reqs , "pandas==2.0" )
157+ # No version specified
158+ assert _utils ._package_existed (reqs , "numpy" )
159+ # Not in list
160+ assert not _utils ._package_existed (reqs , "xgboost" )
161+ # Empty list
162+ assert not _utils ._package_existed ([], "pandas" )
163+
164+
165+ def test_has_conflict_output_type_no_conflict ():
166+ """Tests has_conflict_output_type with type annotation."""
167+ # Helper functions with type annotation for has_conflict_output_type.
168+ def _func_with_return_type (x : int ) -> int :
169+ return x
170+
171+ signature = inspect .signature (_func_with_return_type )
172+
173+ assert _utils .has_conflict_output_type (signature , output_type = float )
174+ assert not _utils .has_conflict_output_type (signature , output_type = int )
175+
176+
177+ def test_has_conflict_output_type_no_annotation ():
178+ """Tests has_conflict_output_type without type annotation."""
179+ # Helper functions without type annotation for has_conflict_output_type.
180+ def _func_without_return_type (x ):
181+ return x
182+
183+ signature = inspect .signature (_func_without_return_type )
184+
185+ assert not _utils .has_conflict_output_type (signature , output_type = int )
186+ assert not _utils .has_conflict_output_type (signature , output_type = float )
187+
188+
21189@pytest .mark .parametrize (
22190 ["metadata_options" , "metadata_string" ],
23191 (
54222 ),
55223)
56224def test_get_bigframes_metadata (metadata_options , metadata_string ):
225+
57226 assert _utils .get_bigframes_metadata (** metadata_options ) == metadata_string
58227
59228
@@ -72,6 +241,7 @@ def test_get_bigframes_metadata(metadata_options, metadata_string):
72241def test_get_bigframes_metadata_array_type_not_serializable (output_type ):
73242 with pytest .raises (ValueError ) as context :
74243 _utils .get_bigframes_metadata (python_output_type = output_type )
244+
75245 assert str (context .value ) == (
76246 f"python_output_type { output_type } is not serializable. { constants .FEEDBACK_LINK } "
77247 )
@@ -125,6 +295,7 @@ def test_get_bigframes_metadata_array_type_not_serializable(output_type):
125295def test_get_python_output_type_from_bigframes_metadata (
126296 metadata_string , python_output_type
127297):
298+
128299 assert (
129300 _utils .get_python_output_type_from_bigframes_metadata (metadata_string )
130301 == python_output_type
@@ -135,4 +306,5 @@ def test_metadata_roundtrip_supported_array_types():
135306 for array_of in function_typing .RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES :
136307 ser = _utils .get_bigframes_metadata (python_output_type = list [array_of ]) # type: ignore
137308 deser = _utils .get_python_output_type_from_bigframes_metadata (ser )
309+
138310 assert deser == list [array_of ] # type: ignore
0 commit comments