Skip to content

Commit 05f8b4d

Browse files
rey-espchelsea-lintswastgcf-owl-bot[bot]
authored
feat: add ARIMAPlus.predict_explain() to generate forecasts with explanation columns (#1177)
* feat: create arima_plus_predict_attribution method * tmp: debug notes for time_series_arima_plus_model.predict_attribution * update test_arima_plus_predict_explain_default test and create test_arima_plus_predict_explain_params test * Merge branch 'ml-predict-explain' of github.com:googleapis/python-bigquery-dataframes into ml-predict-explain * update test_arima_plus_predict_explain_params test * Revert "tmp: debug notes for time_series_arima_plus_model.predict_attribution" This reverts commit f6dd455. * format and lint * Update bigframes/ml/forecasting.py Co-authored-by: Tim Sweña (Swast) <[email protected]> * update predict explain params test * update test * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * add unit test file - bare bones * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fixed tests * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * lint * lint * fix test: float -> int --------- Co-authored-by: Chelsea Lin <[email protected]> Co-authored-by: Tim Sweña (Swast) <[email protected]> Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 0d8a16b commit 05f8b4d

File tree

5 files changed

+174
-0
lines changed

5 files changed

+174
-0
lines changed

bigframes/ml/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,14 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
172172
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
173173
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()
174174

175+
def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
176+
sql = self._model_manipulation_sql_generator.ml_explain_forecast(
177+
struct_options=options
178+
)
179+
return self._session.read_gbq(
180+
sql, index_col="time_series_timestamp"
181+
).reset_index()
182+
175183
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
176184
sql = self._model_manipulation_sql_generator.ml_evaluate(
177185
input_data.sql if (input_data is not None) else None

bigframes/ml/forecasting.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,43 @@ def predict(
253253
options={"horizon": horizon, "confidence_level": confidence_level}
254254
)
255255

256+
def predict_explain(
257+
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
258+
) -> bpd.DataFrame:
259+
"""Explain Forecast time series at future horizon.
260+
261+
.. note::
262+
263+
Output matches that of the BigQuery ML.EXPLAIN_FORECAST function.
264+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-forecast
265+
266+
Args:
267+
X (default None):
268+
ignored, to be compatible with other APIs.
269+
horizon (int, default: 3):
270+
an int value that specifies the number of time points to forecast.
271+
The default value is 3, and the maximum value is 1000.
272+
confidence_level (float, default 0.95):
273+
A float value that specifies percentage of the future values that fall in the prediction interval.
274+
The valid input range is [0.0, 1.0).
275+
276+
Returns:
277+
bigframes.dataframe.DataFrame: The predicted DataFrames.
278+
"""
279+
if horizon < 1:
280+
raise ValueError(f"horizon must be at least 1, but is {horizon}.")
281+
if confidence_level < 0.0 or confidence_level >= 1.0:
282+
raise ValueError(
283+
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
284+
)
285+
286+
if not self._bqml_model:
287+
raise RuntimeError("A model must be fitted before predict")
288+
289+
return self._bqml_model.explain_forecast(
290+
options={"horizon": horizon, "confidence_level": confidence_level}
291+
)
292+
256293
@property
257294
def coef_(
258295
self,

bigframes/ml/sql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,14 @@ def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
310310
return f"""SELECT * FROM ML.FORECAST(MODEL {self._model_ref_sql()},
311311
{struct_options_sql})"""
312312

313+
def ml_explain_forecast(
314+
self, struct_options: Mapping[str, Union[int, float]]
315+
) -> str:
316+
"""Encode ML.EXPLAIN_FORECAST for BQML"""
317+
struct_options_sql = self.struct_options(**struct_options)
318+
return f"""SELECT * FROM ML.EXPLAIN_FORECAST(MODEL {self._model_ref_sql()},
319+
{struct_options_sql})"""
320+
313321
def ml_generate_text(
314322
self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
315323
) -> str:

tests/system/small/ml/test_forecasting.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,42 @@ def test_arima_plus_predict_default(
6565
)
6666

6767

68+
def test_arima_plus_predict_explain_default(
69+
time_series_arima_plus_model: forecasting.ARIMAPlus,
70+
):
71+
utc = pytz.utc
72+
predictions = time_series_arima_plus_model.predict_explain().to_pandas()
73+
assert predictions.shape[0] == 369
74+
predictions = predictions[
75+
predictions["time_series_type"] == "forecast"
76+
].reset_index(drop=True)
77+
assert predictions.shape[0] == 3
78+
result = predictions[["time_series_timestamp", "time_series_data"]]
79+
expected = pd.DataFrame(
80+
{
81+
"time_series_timestamp": [
82+
datetime(2017, 8, 2, tzinfo=utc),
83+
datetime(2017, 8, 3, tzinfo=utc),
84+
datetime(2017, 8, 4, tzinfo=utc),
85+
],
86+
"time_series_data": [2727.693349, 2595.290749, 2370.86767],
87+
}
88+
)
89+
expected["time_series_data"] = expected["time_series_data"].astype(
90+
pd.Float64Dtype()
91+
)
92+
expected["time_series_timestamp"] = expected["time_series_timestamp"].astype(
93+
pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
94+
)
95+
96+
pd.testing.assert_frame_equal(
97+
result,
98+
expected,
99+
rtol=0.1,
100+
check_index_type=False,
101+
)
102+
103+
68104
def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARIMAPlus):
69105
utc = pytz.utc
70106
predictions = time_series_arima_plus_model.predict(
@@ -96,6 +132,33 @@ def test_arima_plus_predict_params(time_series_arima_plus_model: forecasting.ARI
96132
)
97133

98134

135+
def test_arima_plus_predict_explain_params(
136+
time_series_arima_plus_model: forecasting.ARIMAPlus,
137+
):
138+
predictions = time_series_arima_plus_model.predict_explain(
139+
horizon=4, confidence_level=0.9
140+
).to_pandas()
141+
assert predictions.shape[0] >= 1
142+
prediction_columns = set(predictions.columns)
143+
expected_columns = {
144+
"time_series_timestamp",
145+
"time_series_type",
146+
"time_series_data",
147+
"time_series_adjusted_data",
148+
"standard_error",
149+
"confidence_level",
150+
"prediction_interval_lower_bound",
151+
"trend",
152+
"seasonal_period_yearly",
153+
"seasonal_period_quarterly",
154+
"seasonal_period_monthly",
155+
"seasonal_period_weekly",
156+
"seasonal_period_daily",
157+
"holiday_effect",
158+
}
159+
assert expected_columns <= prediction_columns
160+
161+
99162
def test_arima_plus_detect_anomalies(
100163
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
101164
):

tests/unit/ml/test_forecasting.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import re
16+
17+
import pytest
18+
19+
from bigframes.ml import forecasting
20+
21+
22+
def test_predict_explain_low_confidence_level():
23+
confidence_level = -0.5
24+
25+
model = forecasting.ARIMAPlus()
26+
27+
with pytest.raises(
28+
ValueError,
29+
match=re.escape(
30+
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
31+
),
32+
):
33+
model.predict_explain(horizon=4, confidence_level=confidence_level)
34+
35+
36+
def test_predict_high_explain_confidence_level():
37+
confidence_level = 2.1
38+
39+
model = forecasting.ARIMAPlus()
40+
41+
with pytest.raises(
42+
ValueError,
43+
match=re.escape(
44+
f"confidence_level must be [0.0, 1.0), but is {confidence_level}."
45+
),
46+
):
47+
model.predict_explain(horizon=4, confidence_level=confidence_level)
48+
49+
50+
def test_predict_explain_low_horizon():
51+
horizon = -1
52+
53+
model = forecasting.ARIMAPlus()
54+
55+
with pytest.raises(
56+
ValueError, match=f"horizon must be at least 1, but is {horizon}."
57+
):
58+
model.predict_explain(horizon=horizon, confidence_level=0.9)

0 commit comments

Comments
 (0)