Skip to content

Commit 1d88694

Browse files
committed
add unit tests
1 parent 82d1aec commit 1d88694

File tree

2 files changed

+154
-2
lines changed

2 files changed

+154
-2
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,11 @@ def _get_model_name_and_session(
4949
*dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
5050
) -> tuple[str, Optional[bigframes.session.Session]]:
5151
if isinstance(model, pd.Series):
52-
model_ref = model["modelReference"]
53-
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
52+
try:
53+
model_ref = model["modelReference"]
54+
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
55+
except KeyError:
56+
raise ValueError("modelReference must be present in the pandas Series.")
5457
elif isinstance(model, str):
5558
model_name = model
5659
else:

tests/unit/bigquery/test_ml.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from unittest import mock
17+
18+
import pandas as pd
19+
import pytest
20+
21+
import bigframes.bigquery._operations.ml as ml_ops
22+
import bigframes.session
23+
24+
25+
@pytest.fixture
def mock_session():
    """Return an autospec mock created from ``bigframes.session.Session``."""
    session_spec = bigframes.session.Session
    return mock.create_autospec(spec=session_spec)
28+
29+
30+
# Shared test data: a pandas Series with a single "modelReference" entry, and
# the "project.dataset.model" string spelled from the same three ids.
MODEL_NAME = "test-project.test-dataset.test-model"

MODEL_SERIES = pd.Series(
    {
        "modelReference": dict(
            projectId="test-project",
            datasetId="test-dataset",
            modelId="test-model",
        )
    }
)
41+
42+
43+
def test_get_model_name_and_session_with_pandas_series_model_input():
    """A pandas Series model input yields the expected model name."""
    result = ml_ops._get_model_name_and_session(MODEL_SERIES)
    # Only the first element (the model name) is checked; the second
    # element of the returned pair is ignored by this test.
    resolved_name, _ = result
    assert resolved_name == MODEL_NAME
46+
47+
48+
def test_get_model_name_and_session_with_pandas_series_model_input_missing_model_reference():
    """A Series without a "modelReference" key is rejected with ValueError."""
    series_without_reference = pd.Series({"some_other_key": "value"})
    expected_message = "modelReference must be present in the pandas Series"

    with pytest.raises(ValueError, match=expected_message):
        ml_ops._get_model_name_and_session(series_without_reference)
54+
55+
56+
@mock.patch("bigframes.pandas.read_pandas")
def test_to_sql_with_pandas_dataframe(read_pandas_mock):
    """``_to_sql`` on a pandas DataFrame calls ``read_pandas`` exactly once."""
    # Stub the object returned by read_pandas so its _to_sql_query() returns
    # a three-element tuple whose first element is the SQL text.
    converted_frame = read_pandas_mock.return_value
    converted_frame._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )
    local_frame = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops._to_sql(local_frame)

    read_pandas_mock.assert_called_once()
66+
67+
68+
@mock.patch("bigframes.pandas.read_pandas")
@mock.patch("bigframes.core.sql.ml.create_model_ddl")
def test_create_model_with_pandas_dataframe(
    create_model_ddl_mock, read_pandas_mock, mock_session
):
    """``create_model`` with pandas training data calls each patched helper once."""
    # mock.patch decorators apply bottom-up, so the innermost patch
    # (create_model_ddl) is the first positional argument.
    converted_frame = read_pandas_mock.return_value
    converted_frame._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )
    training_frame = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops.create_model(
        "model_name",
        training_data=training_frame,
        session=mock_session,
    )

    read_pandas_mock.assert_called_once()
    create_model_ddl_mock.assert_called_once()
82+
83+
84+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
@mock.patch("bigframes.core.sql.ml.evaluate")
def test_evaluate_with_pandas_dataframe(
    evaluate_mock, read_pandas_mock, read_gbq_query_mock
):
    """``evaluate`` with a pandas input runs the SQL returned by the patched builder.

    Asserts that ``read_pandas`` and the patched ``evaluate`` SQL builder are
    each called once, and that ``read_gbq_query`` receives the builder's SQL.
    """
    stub_sql = "SELECT * FROM `pandas_df`"
    read_pandas_mock.return_value._to_sql_query.return_value = (stub_sql, [], [])
    evaluate_mock.return_value = stub_sql
    input_frame = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops.evaluate(MODEL_SERIES, input_=input_frame)

    read_pandas_mock.assert_called_once()
    evaluate_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once_with(stub_sql)
101+
102+
103+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
@mock.patch("bigframes.core.sql.ml.predict")
def test_predict_with_pandas_dataframe(
    predict_mock, read_pandas_mock, read_gbq_query_mock
):
    """``predict`` with a pandas input runs the SQL returned by the patched builder.

    Asserts that ``read_pandas`` and the patched ``predict`` SQL builder are
    each called once, and that ``read_gbq_query`` receives the builder's SQL.
    """
    stub_sql = "SELECT * FROM `pandas_df`"
    read_pandas_mock.return_value._to_sql_query.return_value = (stub_sql, [], [])
    predict_mock.return_value = stub_sql
    input_frame = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops.predict(MODEL_SERIES, input_=input_frame)

    read_pandas_mock.assert_called_once()
    predict_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once_with(stub_sql)
120+
121+
122+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
@mock.patch("bigframes.core.sql.ml.explain_predict")
def test_explain_predict_with_pandas_dataframe(
    explain_predict_mock, read_pandas_mock, read_gbq_query_mock
):
    """``explain_predict`` with a pandas input runs the patched builder's SQL.

    Asserts that ``read_pandas`` and the patched ``explain_predict`` SQL
    builder are each called once, and that ``read_gbq_query`` receives the
    builder's SQL.
    """
    stub_sql = "SELECT * FROM `pandas_df`"
    read_pandas_mock.return_value._to_sql_query.return_value = (stub_sql, [], [])
    explain_predict_mock.return_value = stub_sql
    input_frame = pd.DataFrame({"col1": [1, 2, 3]})

    ml_ops.explain_predict(MODEL_SERIES, input_=input_frame)

    read_pandas_mock.assert_called_once()
    explain_predict_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once_with(stub_sql)
139+
140+
141+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.core.sql.ml.global_explain")
def test_global_explain_with_pandas_series_model(
    global_explain_mock, read_gbq_query_mock
):
    """``global_explain`` with a Series model runs the patched builder's SQL.

    Asserts that the patched ``global_explain`` SQL builder is called once and
    that ``read_gbq_query`` receives the SQL it returned.
    """
    stub_sql = "SELECT * FROM `pandas_df`"
    global_explain_mock.return_value = stub_sql

    ml_ops.global_explain(MODEL_SERIES)

    global_explain_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once_with(stub_sql)

0 commit comments

Comments
 (0)