Skip to content

Commit 84f4427

Browse files
committed
add less mocking
1 parent 1d88694 commit 84f4427

File tree

1 file changed

+29
-31
lines changed

1 file changed

+29
-31
lines changed

tests/unit/bigquery/test_ml.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ def test_to_sql_with_pandas_dataframe(read_pandas_mock):
6565
read_pandas_mock.assert_called_once()
6666

6767

68+
@mock.patch("bigframes.bigquery._operations.ml._get_model_metadata")
6869
@mock.patch("bigframes.pandas.read_pandas")
69-
@mock.patch("bigframes.core.sql.ml.create_model_ddl")
7070
def test_create_model_with_pandas_dataframe(
71-
create_model_ddl_mock, read_pandas_mock, mock_session
71+
read_pandas_mock, _get_model_metadata_mock, mock_session
7272
):
7373
df = pd.DataFrame({"col1": [1, 2, 3]})
7474
read_pandas_mock.return_value._to_sql_query.return_value = (
@@ -78,72 +78,70 @@ def test_create_model_with_pandas_dataframe(
7878
)
7979
ml_ops.create_model("model_name", training_data=df, session=mock_session)
8080
read_pandas_mock.assert_called_once()
81-
create_model_ddl_mock.assert_called_once()
81+
mock_session.read_gbq_query.assert_called_once()
82+
generated_sql = mock_session.read_gbq_query.call_args[0][0]
83+
assert "CREATE MODEL `model_name`" in generated_sql
84+
assert "AS SELECT * FROM `pandas_df`" in generated_sql
8285

8386

8487
@mock.patch("bigframes.pandas.read_gbq_query")
8588
@mock.patch("bigframes.pandas.read_pandas")
86-
@mock.patch("bigframes.core.sql.ml.evaluate")
87-
def test_evaluate_with_pandas_dataframe(
88-
evaluate_mock, read_pandas_mock, read_gbq_query_mock
89-
):
89+
def test_evaluate_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
9090
df = pd.DataFrame({"col1": [1, 2, 3]})
9191
read_pandas_mock.return_value._to_sql_query.return_value = (
9292
"SELECT * FROM `pandas_df`",
9393
[],
9494
[],
9595
)
96-
evaluate_mock.return_value = "SELECT * FROM `pandas_df`"
9796
ml_ops.evaluate(MODEL_SERIES, input_=df)
9897
read_pandas_mock.assert_called_once()
99-
evaluate_mock.assert_called_once()
100-
read_gbq_query_mock.assert_called_once_with("SELECT * FROM `pandas_df`")
98+
read_gbq_query_mock.assert_called_once()
99+
generated_sql = read_gbq_query_mock.call_args[0][0]
100+
assert "ML.EVALUATE" in generated_sql
101+
assert f"MODEL `{MODEL_NAME}`" in generated_sql
102+
assert "(SELECT * FROM `pandas_df`)" in generated_sql
101103

102104

103105
@mock.patch("bigframes.pandas.read_gbq_query")
104106
@mock.patch("bigframes.pandas.read_pandas")
105-
@mock.patch("bigframes.core.sql.ml.predict")
106-
def test_predict_with_pandas_dataframe(
107-
predict_mock, read_pandas_mock, read_gbq_query_mock
108-
):
107+
def test_predict_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
109108
df = pd.DataFrame({"col1": [1, 2, 3]})
110109
read_pandas_mock.return_value._to_sql_query.return_value = (
111110
"SELECT * FROM `pandas_df`",
112111
[],
113112
[],
114113
)
115-
predict_mock.return_value = "SELECT * FROM `pandas_df`"
116114
ml_ops.predict(MODEL_SERIES, input_=df)
117115
read_pandas_mock.assert_called_once()
118-
predict_mock.assert_called_once()
119-
read_gbq_query_mock.assert_called_once_with("SELECT * FROM `pandas_df`")
116+
read_gbq_query_mock.assert_called_once()
117+
generated_sql = read_gbq_query_mock.call_args[0][0]
118+
assert "ML.PREDICT" in generated_sql
119+
assert f"MODEL `{MODEL_NAME}`" in generated_sql
120+
assert "(SELECT * FROM `pandas_df`)" in generated_sql
120121

121122

122123
@mock.patch("bigframes.pandas.read_gbq_query")
123124
@mock.patch("bigframes.pandas.read_pandas")
124-
@mock.patch("bigframes.core.sql.ml.explain_predict")
125-
def test_explain_predict_with_pandas_dataframe(
126-
explain_predict_mock, read_pandas_mock, read_gbq_query_mock
127-
):
125+
def test_explain_predict_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
128126
df = pd.DataFrame({"col1": [1, 2, 3]})
129127
read_pandas_mock.return_value._to_sql_query.return_value = (
130128
"SELECT * FROM `pandas_df`",
131129
[],
132130
[],
133131
)
134-
explain_predict_mock.return_value = "SELECT * FROM `pandas_df`"
135132
ml_ops.explain_predict(MODEL_SERIES, input_=df)
136133
read_pandas_mock.assert_called_once()
137-
explain_predict_mock.assert_called_once()
138-
read_gbq_query_mock.assert_called_once_with("SELECT * FROM `pandas_df`")
134+
read_gbq_query_mock.assert_called_once()
135+
generated_sql = read_gbq_query_mock.call_args[0][0]
136+
assert "ML.EXPLAIN_PREDICT" in generated_sql
137+
assert f"MODEL `{MODEL_NAME}`" in generated_sql
138+
assert "(SELECT * FROM `pandas_df`)" in generated_sql
139139

140140

141141
@mock.patch("bigframes.pandas.read_gbq_query")
142-
@mock.patch("bigframes.core.sql.ml.global_explain")
143-
def test_global_explain_with_pandas_series_model(
144-
global_explain_mock, read_gbq_query_mock
145-
):
146-
global_explain_mock.return_value = "SELECT * FROM `pandas_df`"
142+
def test_global_explain_with_pandas_series_model(read_gbq_query_mock):
147143
ml_ops.global_explain(MODEL_SERIES)
148-
global_explain_mock.assert_called_once()
149-
read_gbq_query_mock.assert_called_once_with("SELECT * FROM `pandas_df`")
144+
read_gbq_query_mock.assert_called_once()
145+
generated_sql = read_gbq_query_mock.call_args[0][0]
146+
assert "ML.GLOBAL_EXPLAIN" in generated_sql
147+
assert f"MODEL `{MODEL_NAME}`" in generated_sql

0 commit comments

Comments
 (0)