Skip to content

Commit e13eca2

Browse files
authored
feat: add LinearRegression.predict_explain() to generate ML.EXPLAIN_PREDICT columns (#1190)
* feat: add LinearRegression.predict_explain to generate predict explain columns * add test cases * add test case * update predict_explain * update the test * Add sql and core tests * fix docs error * add TODO comment to support method paramaters * update the test parmametr of linear model * update test to fix failing checks
1 parent 14f24ca commit e13eca2

File tree

6 files changed

+174
-0
lines changed

6 files changed

+174
-0
lines changed

bigframes/ml/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
123123
self._model_manipulation_sql_generator.ml_predict,
124124
)
125125

126+
def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
127+
return self._apply_ml_tvf(
128+
input_data,
129+
self._model_manipulation_sql_generator.ml_explain_predict,
130+
)
131+
126132
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
127133
return self._apply_ml_tvf(
128134
input_data,

bigframes/ml/linear_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,34 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
160160

161161
return self._bqml_model.predict(X)
162162

163+
def predict_explain(
164+
self,
165+
X: utils.ArrayType,
166+
) -> bpd.DataFrame:
167+
"""
168+
Explain predictions for a linear regression model.
169+
170+
.. note::
171+
Output matches that of the BigQuery ML.EXPLAIN_PREDICT function.
172+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict
173+
174+
Args:
175+
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
176+
pandas.core.frame.DataFrame or pandas.core.series.Series):
177+
Series or a DataFrame to explain its predictions.
178+
179+
Returns:
180+
bigframes.pandas.DataFrame:
181+
The predicted DataFrames with explanation columns.
182+
"""
183+
# TODO(b/377366612): Add support for `top_k_features` parameter
184+
if not self._bqml_model:
185+
raise RuntimeError("A model must be fitted before predict")
186+
187+
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
188+
189+
return self._bqml_model.explain_predict(X)
190+
163191
def score(
164192
self,
165193
X: utils.ArrayType,

bigframes/ml/sql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ def ml_predict(self, source_sql: str) -> str:
304304
return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
305305
({source_sql}))"""
306306

307+
def ml_explain_predict(self, source_sql: str) -> str:
308+
"""Encode ML.EXPLAIN_PREDICT for BQML"""
309+
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
310+
({source_sql}))"""
311+
307312
def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
308313
"""Encode ML.FORECAST for BQML"""
309314
struct_options_sql = self.struct_options(**struct_options)

tests/system/small/ml/test_core.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,28 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
260260
)
261261

262262

263+
def test_model_predict_explain(
264+
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
265+
):
266+
predictions = penguins_bqml_linear_model.explain_predict(
267+
new_penguins_df
268+
).to_pandas()
269+
expected = pd.DataFrame(
270+
{
271+
"predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
272+
"approximation_error": [0.0, 0.0, 0.0],
273+
},
274+
dtype="Float64",
275+
index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"),
276+
)
277+
pd.testing.assert_frame_equal(
278+
predictions[["predicted_body_mass_g", "approximation_error"]].sort_index(),
279+
expected,
280+
check_exact=False,
281+
rtol=0.1,
282+
)
283+
284+
263285
def test_model_predict_with_unnamed_index(
264286
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
265287
):
@@ -288,6 +310,39 @@ def test_model_predict_with_unnamed_index(
288310
)
289311

290312

313+
def test_model_predict_explain_with_unnamed_index(
314+
penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
315+
):
316+
# This will result in an index that lacks a name, which the ML library will
317+
# need to persist through the call to ML.PREDICT
318+
new_penguins_df = new_penguins_df.reset_index()
319+
320+
# remove the middle tag number to ensure we're really keeping the unnamed index
321+
new_penguins_df = typing.cast(
322+
bigframes.dataframe.DataFrame,
323+
new_penguins_df[new_penguins_df.tag_number != 1672],
324+
)
325+
326+
predictions = penguins_bqml_linear_model.explain_predict(
327+
new_penguins_df
328+
).to_pandas()
329+
330+
expected = pd.DataFrame(
331+
{
332+
"predicted_body_mass_g": [4030.1, 3177.9],
333+
"approximation_error": [0.0, 0.0],
334+
},
335+
dtype="Float64",
336+
index=pd.Index([0, 2], dtype="Int64"),
337+
)
338+
pd.testing.assert_frame_equal(
339+
predictions[["predicted_body_mass_g", "approximation_error"]].sort_index(),
340+
expected,
341+
check_exact=False,
342+
rtol=0.1,
343+
)
344+
345+
291346
def test_model_detect_anomalies(
292347
penguins_bqml_pca_model: core.BqmlModel, new_penguins_df
293348
):

tests/system/small/ml/test_linear_model.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import pandas
1717
import pytest
1818

19+
from bigframes.ml import linear_model
20+
1921

2022
def test_linear_reg_model_score(penguins_linear_model, penguins_df_default_index):
2123
df = penguins_df_default_index.dropna()
@@ -106,6 +108,72 @@ def test_linear_reg_model_predict(penguins_linear_model, new_penguins_df):
106108
)
107109

108110

111+
def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df):
112+
predictions = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()
113+
assert predictions.shape == (3, 12)
114+
result = predictions[["predicted_body_mass_g", "approximation_error"]]
115+
expected = pandas.DataFrame(
116+
{
117+
"predicted_body_mass_g": [4030.1, 3280.8, 3177.9],
118+
"approximation_error": [
119+
0.0,
120+
0.0,
121+
0.0,
122+
],
123+
},
124+
dtype="Float64",
125+
index=pandas.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"),
126+
)
127+
pandas.testing.assert_frame_equal(
128+
result.sort_index(),
129+
expected,
130+
check_exact=False,
131+
rtol=0.1,
132+
)
133+
134+
135+
def test_linear_reg_model_predict_params(
136+
penguins_linear_model: linear_model.LinearRegression, new_penguins_df
137+
):
138+
predictions = penguins_linear_model.predict(new_penguins_df).to_pandas()
139+
assert predictions.shape[0] >= 1
140+
prediction_columns = set(predictions.columns)
141+
expected_columns = {
142+
"predicted_body_mass_g",
143+
"species",
144+
"island",
145+
"culmen_length_mm",
146+
"culmen_depth_mm",
147+
"flipper_length_mm",
148+
"body_mass_g",
149+
"sex",
150+
}
151+
assert expected_columns <= prediction_columns
152+
153+
154+
def test_linear_reg_model_predict_explain_params(
155+
penguins_linear_model: linear_model.LinearRegression, new_penguins_df
156+
):
157+
predictions = penguins_linear_model.predict_explain(new_penguins_df).to_pandas()
158+
assert predictions.shape[0] >= 1
159+
prediction_columns = set(predictions.columns)
160+
expected_columns = {
161+
"predicted_body_mass_g",
162+
"top_feature_attributions",
163+
"baseline_prediction_value",
164+
"prediction_value",
165+
"approximation_error",
166+
"species",
167+
"island",
168+
"culmen_length_mm",
169+
"culmen_depth_mm",
170+
"flipper_length_mm",
171+
"body_mass_g",
172+
"sex",
173+
}
174+
assert expected_columns <= prediction_columns
175+
176+
109177
def test_to_gbq_saved_linear_reg_model_scores(
110178
penguins_linear_model, table_id_unique, penguins_df_default_index
111179
):

tests/unit/ml/test_sql.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,18 @@ def test_ml_predict_correct(
342342
)
343343

344344

345+
def test_ml_explain_predict_correct(
346+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
347+
mock_df: bpd.DataFrame,
348+
):
349+
sql = model_manipulation_sql_generator.ml_explain_predict(source_sql=mock_df.sql)
350+
assert (
351+
sql
352+
== """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
353+
(input_X_y_sql))"""
354+
)
355+
356+
345357
def test_ml_llm_evaluate_correct(
346358
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
347359
mock_df: bpd.DataFrame,

0 commit comments

Comments
 (0)