Skip to content

Commit 73e997b

Browse files
authored
feat: add ARIMA_EVAULATE options in forecasting models (#336)
* feat: add ARIMA_EVAULATE options in forecasting models * feat: add summary method * fix minor errors * fix failed tests * address comments
1 parent 47c3285 commit 73e997b

File tree

6 files changed

+124
-2
lines changed

6 files changed

+124
-2
lines changed

bigframes/ml/core.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,13 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
136136

137137
return self._session.read_gbq(sql)
138138

139+
def arima_evaluate(self, show_all_candidate_models: bool = False):
140+
sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
141+
show_all_candidate_models
142+
)
143+
144+
return self._session.read_gbq(sql)
145+
139146
def centroids(self) -> bpd.DataFrame:
140147
assert self._model.model_type == "KMEANS"
141148

bigframes/ml/forecasting.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,31 @@ def score(
151151
input_data = X.join(y, how="outer")
152152
return self._bqml_model.evaluate(input_data)
153153

154+
def summary(
155+
self,
156+
show_all_candidate_models: bool = False,
157+
) -> bpd.DataFrame:
158+
"""Summary of the evaluation metrics of the time series model.
159+
160+
.. note::
161+
162+
Output matches that of the BigQuery ML.ARIMA_EVALUATE function.
163+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-evaluate
164+
for the outputs relevant to this model type.
165+
166+
Args:
167+
show_all_candidate_models (bool, default to False):
168+
Whether to show evaluation metrics or an error message for either
169+
all candidate models or for only the best model with the lowest
170+
AIC. Default to False.
171+
172+
Returns:
173+
bigframes.dataframe.DataFrame: A DataFrame as evaluation result.
174+
"""
175+
if not self._bqml_model:
176+
raise RuntimeError("A model must be fitted before score")
177+
return self._bqml_model.arima_evaluate(show_all_candidate_models)
178+
154179
def to_gbq(self, model_name: str, replace: bool = False) -> ARIMAPlus:
155180
"""Save the model to BigQuery.
156181

bigframes/ml/sql.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
260260
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
261261
({source_sql}))"""
262262

263+
# ML evaluation TVFs
264+
def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
265+
"""Encode ML.ARMIA_EVALUATE for BQML"""
266+
return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`,
267+
STRUCT({show_all_candidate_models} AS show_all_candidate_models))"""
268+
263269
def ml_centroids(self) -> str:
264270
"""Encode ML.CENTROIDS for BQML"""
265271
return f"""SELECT * FROM ML.CENTROIDS(MODEL `{self._model_name}`)"""

tests/system/large/ml/test_forecasting.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616

1717
from bigframes.ml import forecasting
1818

19+
ARIMA_EVALUATE_OUTPUT_COL = [
20+
"non_seasonal_p",
21+
"non_seasonal_d",
22+
"non_seasonal_q",
23+
"log_likelihood",
24+
"AIC",
25+
"variance",
26+
"seasonal_periods",
27+
"has_holiday_effect",
28+
"has_spikes_and_dips",
29+
"has_step_changes",
30+
"error_message",
31+
]
32+
1933

2034
def test_arima_plus_model_fit_score(
2135
time_series_df_default_index, dataset_id, new_time_series_df
@@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score(
4256
pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1)
4357

4458
# save, load to ensure configuration was kept
45-
reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True)
59+
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
60+
assert (
61+
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
62+
)
63+
64+
65+
def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
66+
model = forecasting.ARIMAPlus()
67+
X_train = time_series_df_default_index[["parsed_date"]]
68+
y_train = time_series_df_default_index[["total_visits"]]
69+
model.fit(X_train, y_train)
70+
71+
result = model.summary()
72+
assert result.shape == (1, 12)
73+
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)
74+
75+
# save, load to ensure configuration was kept
76+
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
4677
assert (
47-
f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name
78+
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
4879
)

tests/system/small/ml/test_forecasting.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@
2020

2121
from bigframes.ml import forecasting
2222

23+
ARIMA_EVALUATE_OUTPUT_COL = [
24+
"non_seasonal_p",
25+
"non_seasonal_d",
26+
"non_seasonal_q",
27+
"log_likelihood",
28+
"AIC",
29+
"variance",
30+
"seasonal_periods",
31+
"has_holiday_effect",
32+
"has_spikes_and_dips",
33+
"has_step_changes",
34+
"error_message",
35+
]
36+
2337

2438
def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
2539
utc = pytz.utc
@@ -104,6 +118,24 @@ def test_model_score(
104118
)
105119

106120

121+
def test_model_summary(
122+
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
123+
):
124+
result = time_series_arima_plus_model.summary()
125+
assert result.shape == (1, 12)
126+
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)
127+
128+
129+
def test_model_summary_show_all_candidates(
130+
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
131+
):
132+
result = time_series_arima_plus_model.summary(
133+
show_all_candidate_models=True,
134+
)
135+
assert result.shape[0] > 1
136+
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)
137+
138+
107139
def test_model_score_series(
108140
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
109141
):
@@ -126,3 +158,11 @@ def test_model_score_series(
126158
rtol=0.1,
127159
check_index_type=False,
128160
)
161+
162+
163+
def test_model_summary_series(
164+
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
165+
):
166+
result = time_series_arima_plus_model.summary()
167+
assert result.shape == (1, 12)
168+
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)

tests/unit/ml/test_sql.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,19 @@ def test_ml_evaluate_produces_correct_sql(
273273
)
274274

275275

276+
def test_ml_arima_evaluate_produces_correct_sql(
277+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
278+
):
279+
sql = model_manipulation_sql_generator.ml_arima_evaluate(
280+
show_all_candidate_models=True
281+
)
282+
assert (
283+
sql
284+
== """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
285+
STRUCT(True AS show_all_candidate_models))"""
286+
)
287+
288+
276289
def test_ml_evaluate_no_source_produces_correct_sql(
277290
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
278291
):

0 commit comments

Comments
 (0)