1616
1717from bigframes .ml import forecasting
1818
19+ ARIMA_EVALUATE_OUTPUT_COL = [
20+ "non_seasonal_p" ,
21+ "non_seasonal_d" ,
22+ "non_seasonal_q" ,
23+ "log_likelihood" ,
24+ "AIC" ,
25+ "variance" ,
26+ "seasonal_periods" ,
27+ "has_holiday_effect" ,
28+ "has_spikes_and_dips" ,
29+ "has_step_changes" ,
30+ "error_message" ,
31+ ]
32+
1933
2034def test_arima_plus_model_fit_score (
2135 time_series_df_default_index , dataset_id , new_time_series_df
@@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score(
4256 pd .testing .assert_frame_equal (result , expected , check_exact = False , rtol = 0.1 )
4357
4458 # save, load to ensure configuration was kept
45- reloaded_model = model .to_gbq (f"{ dataset_id } .temp_configured_model" , replace = True )
59+ reloaded_model = model .to_gbq (f"{ dataset_id } .temp_arima_plus_model" , replace = True )
60+ assert (
61+ f"{ dataset_id } .temp_arima_plus_model" in reloaded_model ._bqml_model .model_name
62+ )
63+
64+
65+ def test_arima_plus_model_fit_summary (time_series_df_default_index , dataset_id ):
66+ model = forecasting .ARIMAPlus ()
67+ X_train = time_series_df_default_index [["parsed_date" ]]
68+ y_train = time_series_df_default_index [["total_visits" ]]
69+ model .fit (X_train , y_train )
70+
71+ result = model .summary ()
72+ assert result .shape == (1 , 12 )
73+ assert all (column in result .columns for column in ARIMA_EVALUATE_OUTPUT_COL )
74+
75+ # save, load to ensure configuration was kept
76+ reloaded_model = model .to_gbq (f"{ dataset_id } .temp_arima_plus_model" , replace = True )
4677 assert (
47- f"{ dataset_id } .temp_configured_model " in reloaded_model ._bqml_model .model_name
78+ f"{ dataset_id } .temp_arima_plus_model " in reloaded_model ._bqml_model .model_name
4879 )
0 commit comments