Skip to content

Commit bcbc732

Browse files
authored
feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns (#1222)
* feat: add LogisticRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns * update tests
1 parent 684b2a6 commit bcbc732

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

bigframes/ml/linear_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,34 @@ def predict(
353353

354354
return self._bqml_model.predict(X)
355355

356+
def predict_explain(
357+
self,
358+
X: utils.ArrayType,
359+
) -> bpd.DataFrame:
360+
"""
361+
Explain predictions for a logistic regression model.
362+
363+
.. note::
364+
Output matches that of the BigQuery ML.EXPLAIN_PREDICT function.
365+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict
366+
367+
Args:
368+
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
369+
pandas.core.frame.DataFrame or pandas.core.series.Series):
370+
Series or a DataFrame to explain its predictions.
371+
372+
Returns:
373+
bigframes.pandas.DataFrame:
374+
The predicted DataFrames with explanation columns.
375+
"""
376+
# TODO(b/377366612): Add support for `top_k_features` parameter
377+
if not self._bqml_model:
378+
raise RuntimeError("A model must be fitted before predict")
379+
380+
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
381+
382+
return self._bqml_model.explain_predict(X)
383+
356384
def score(
357385
self,
358386
X: utils.ArrayType,

tests/system/small/ml/test_linear_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,50 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df):
307307
)
308308

309309

310+
def test_logistic_model_predict_params(
311+
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
312+
):
313+
predictions = penguins_logistic_model.predict(new_penguins_df).to_pandas()
314+
assert predictions.shape[0] >= 1
315+
prediction_columns = set(predictions.columns)
316+
expected_columns = {
317+
"predicted_sex",
318+
"predicted_sex_probs",
319+
"species",
320+
"island",
321+
"culmen_length_mm",
322+
"culmen_depth_mm",
323+
"flipper_length_mm",
324+
"body_mass_g",
325+
"sex",
326+
}
327+
assert expected_columns <= prediction_columns
328+
329+
330+
def test_logistic_model_predict_explain_params(
331+
penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
332+
):
333+
predictions = penguins_logistic_model.predict_explain(new_penguins_df).to_pandas()
334+
assert predictions.shape[0] >= 1
335+
prediction_columns = set(predictions.columns)
336+
expected_columns = {
337+
"predicted_sex",
338+
"probability",
339+
"top_feature_attributions",
340+
"baseline_prediction_value",
341+
"prediction_value",
342+
"approximation_error",
343+
"species",
344+
"island",
345+
"culmen_length_mm",
346+
"culmen_depth_mm",
347+
"flipper_length_mm",
348+
"body_mass_g",
349+
"sex",
350+
}
351+
assert expected_columns <= prediction_columns
352+
353+
310354
def test_logistic_model_to_gbq_saved_score(
311355
penguins_logistic_model, table_id_unique, penguins_df_default_index
312356
):

0 commit comments

Comments
 (0)