
Commit d3997ba

fix unit tests
1 parent 9362838

File tree

2 files changed: +84 -27 lines changed

bigframes/ml/core.py

Lines changed: 65 additions & 15 deletions
@@ -45,7 +45,11 @@ def ai_forecast(
         result_sql = self._sql_generator.ai_forecast(
             source_sql=input_data.sql, options=options
         )
-        return self._session.read_gbq(result_sql)
+
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(result_sql, allow_large_results=True)


 class BqmlModel(BaseBqml):
@@ -169,7 +173,10 @@ def explain_predict(
     def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_global_explain(struct_options=options)
         return (
-            self._session.read_gbq(sql)
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            self._session.read_gbq_query(sql, allow_large_results=True)
             .sort_values(by="attribution", ascending=False)
             .set_index("feature")
         )
@@ -244,26 +251,49 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_forecast(struct_options=options)
         timestamp_col_name = "forecast_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_explain_forecast(struct_options=options)
         timestamp_col_name = "time_series_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
         sql = self._sql_generator.ml_evaluate(
             input_data.sql if (input_data is not None) else None
         )

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def llm_evaluate(
         self,
@@ -272,42 +302,62 @@ def llm_evaluate(
     ):
         sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_evaluate(self, show_all_candidate_models: bool = False):
         sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_coefficients(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_arima_coefficients()

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def centroids(self) -> bpd.DataFrame:
         assert self._model.model_type == "KMEANS"

         sql = self._sql_generator.ml_centroids()

-        return self._session.read_gbq(
-            sql, index_col=["centroid_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=["centroid_id", "feature"], allow_large_results=True
         ).reset_index()

     def principal_components(self) -> bpd.DataFrame:
         assert self._model.model_type == "PCA"

         sql = self._sql_generator.ml_principal_components()

-        return self._session.read_gbq(
-            sql, index_col=["principal_component_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql,
+            index_col=["principal_component_id", "feature"],
+            allow_large_results=True,
         ).reset_index()

     def principal_component_info(self) -> bpd.DataFrame:
         assert self._model.model_type == "PCA"

         sql = self._sql_generator.ml_principal_component_info()

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
         job_config = self._session._prepare_copy_job_config()
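Every change in this file follows one pattern: `read_gbq(sql, ...)` becomes `read_gbq_query(sql, ..., allow_large_results=True)`, with the same TODO(b/395912450) note repeated at each call site. A minimal sketch of how that repeated pattern could be factored into a single helper; the function name `read_gbq_large` and the standalone form are illustrative, not part of this commit:

# Hypothetical refactor sketch (not in this commit): keep the TODO and the
# allow_large_results=True default in one place instead of at every call site.
from typing import Optional, Sequence


def read_gbq_large(session, sql: str, index_col: Optional[Sequence[str]] = None):
    # TODO(b/395912450): Once the limitations with local data are
    # resolved, consider setting allow_large_results only when expected
    # data size is large.
    if index_col is not None:
        return session.read_gbq_query(
            sql, index_col=index_col, allow_large_results=True
        )
    return session.read_gbq_query(sql, allow_large_results=True)

Each method body above would then reduce to a single call (plus `.reset_index()` where an index is set), and lifting the flag later would mean touching one function rather than every BQML entry point.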

tests/unit/ml/test_golden_sql.py

Lines changed: 19 additions & 12 deletions
@@ -143,9 +143,10 @@ def test_linear_regression_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)

-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )

@@ -154,8 +155,9 @@ def test_linear_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model._bqml_model = bqml_model
     model.score(mock_X, mock_y)

-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))",
+        allow_large_results=True,
     )

@@ -167,7 +169,7 @@ def test_logistic_regression_default_fit(
     model.fit(mock_X, mock_y)

     mock_session._start_query_ml_ddl.assert_called_once_with(
-        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
+        "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql",
     )

@@ -198,9 +200,10 @@ def test_logistic_regression_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)

-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )

@@ -209,8 +212,9 @@ def test_logistic_regression_score(mock_session, bqml_model, mock_X, mock_y):
     model._bqml_model = bqml_model
     model.score(mock_X, mock_y)

-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))",
+        allow_large_results=True,
     )

@@ -243,9 +247,10 @@ def test_decomposition_mf_predict(mock_session, bqml_model, mock_X):
     model._bqml_model = bqml_model
     model.predict(mock_X)

-    mock_session.read_gbq.assert_called_once_with(
+    mock_session.read_gbq_query.assert_called_once_with(
         "SELECT * FROM ML.RECOMMEND(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))",
         index_col=["index_column_id"],
+        allow_large_results=True,
     )

@@ -260,8 +265,9 @@ def test_decomposition_mf_score(mock_session, bqml_model):
     )
     model._bqml_model = bqml_model
     model.score()
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)",
+        allow_large_results=True,
     )

@@ -276,6 +282,7 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X):
     )
     model._bqml_model = bqml_model
     model.score(mock_X)
-    mock_session.read_gbq.assert_called_once_with(
-        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))"
+    mock_session.read_gbq_query.assert_called_once_with(
+        "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))",
+        allow_large_results=True,
    )
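The test updates mirror the production change one-for-one: each golden-SQL assertion now targets `read_gbq_query` and expects the added `allow_large_results=True` keyword. A self-contained sketch of that assertion style using only `unittest.mock`; the real `mock_session` and `bqml_model` fixtures live in this test module, so everything below is illustrative:

# Illustrative sketch of the golden-SQL assertion pattern used above.
# A MagicMock session records every call, so the test can pin down the
# exact SQL string and keyword arguments the ML layer is expected to send.
from unittest import mock


def predict_via_session(session, sql: str):
    # Mirrors the call shape in bigframes/ml/core.py after this commit.
    return session.read_gbq_query(
        sql, index_col=["index_column_id"], allow_large_results=True
    )


def test_predict_issues_expected_read_gbq_query_call():
    mock_session = mock.MagicMock()
    sql = "SELECT * FROM ML.PREDICT(MODEL `m`,\n (input_X_sql))"

    predict_via_session(mock_session, sql)

    # Fails if the code under test still calls read_gbq, or drops
    # allow_large_results=True from the keyword arguments.
    mock_session.read_gbq_query.assert_called_once_with(
        sql, index_col=["index_column_id"], allow_large_results=True
    )

Because the mock records the keyword arguments exactly, these tests would have caught both halves of the regression: a lingering call to the old `read_gbq` method and a missing `allow_large_results=True` flag.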
