@@ -114,6 +114,18 @@ def pymc_mod(ss_mod):
114114 return pymc_mod
115115
116116
117+ @pytest .fixture (scope = "session" )
118+ def ss_mod_no_exog (rng ):
119+ ll = st .LevelTrendComponent (order = 2 , innovations_order = 1 )
120+ return ll .build ()
121+
122+
123+ @pytest .fixture (scope = "session" )
124+ def ss_mod_no_exog_dt (rng ):
125+ ll = st .LevelTrendComponent (order = 2 , innovations_order = 1 )
126+ return ll .build ()
127+
128+
117129@pytest .fixture (scope = "session" )
118130def exog_ss_mod (rng ):
119131 ll = st .LevelTrendComponent ()
@@ -143,6 +155,42 @@ def exog_pymc_mod(exog_ss_mod, rng):
143155 return m
144156
145157
158+ @pytest .fixture (scope = "session" )
159+ def pymc_mod_no_exog (ss_mod_no_exog , rng ):
160+ y = pd .DataFrame (rng .normal (size = (100 , 1 )).astype (floatX ), columns = ["y" ])
161+
162+ with pm .Model (coords = ss_mod_no_exog .coords ) as m :
163+ initial_trend = pm .Normal ("initial_trend" , dims = ["trend_state" ])
164+ P0_sigma = pm .Exponential ("P0_sigma" , 1 )
165+ P0 = pm .Deterministic (
166+ "P0" , pt .eye (ss_mod_no_exog .k_states ) * P0_sigma , dims = ["state" , "state_aux" ]
167+ )
168+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
169+ ss_mod_no_exog .build_statespace_graph (y )
170+
171+ return m
172+
173+
174+ @pytest .fixture (scope = "session" )
175+ def pymc_mod_no_exog_dt (ss_mod_no_exog_dt , rng ):
176+ y = pd .DataFrame (
177+ rng .normal (size = (100 , 1 )).astype (floatX ),
178+ columns = ["y" ],
179+ index = pd .date_range ("2020-01-01" , periods = 100 , freq = "D" ),
180+ )
181+
182+ with pm .Model (coords = ss_mod_no_exog_dt .coords ) as m :
183+ initial_trend = pm .Normal ("initial_trend" , dims = ["trend_state" ])
184+ P0_sigma = pm .Exponential ("P0_sigma" , 1 )
185+ P0 = pm .Deterministic (
186+ "P0" , pt .eye (ss_mod_no_exog_dt .k_states ) * P0_sigma , dims = ["state" , "state_aux" ]
187+ )
188+ sigma_trend = pm .Exponential ("sigma_trend" , 1 , dims = ["trend_shock" ])
189+ ss_mod_no_exog_dt .build_statespace_graph (y )
190+
191+ return m
192+
193+
146194@pytest .fixture (scope = "session" )
147195def idata (pymc_mod , rng ):
148196 with pymc_mod :
@@ -162,6 +210,24 @@ def idata_exog(exog_pymc_mod, rng):
162210 return idata
163211
164212
213+ @pytest .fixture (scope = "session" )
214+ def idata_no_exog (pymc_mod_no_exog , rng ):
215+ with pymc_mod_no_exog :
216+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
217+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
218+ idata .extend (idata_prior )
219+ return idata
220+
221+
222+ @pytest .fixture (scope = "session" )
223+ def idata_no_exog_dt (pymc_mod_no_exog_dt , rng ):
224+ with pymc_mod_no_exog_dt :
225+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
226+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
227+ idata .extend (idata_prior )
228+ return idata
229+
230+
165231def test_invalid_filter_name_raises ():
166232 msg = "The following are valid filter types: " + ", " .join (list (FILTER_FACTORY .keys ()))
167233 with pytest .raises (NotImplementedError , match = msg ):
@@ -664,28 +730,75 @@ def test_invalid_scenarios():
664730 ss_mod ._validate_scenario_data (scenario )
665731
666732
733+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
667734@pytest .mark .parametrize ("filter_output" , ["predicted" , "filtered" , "smoothed" ])
668- def test_forecast (filter_output , ss_mod , idata , rng ):
669- time_idx = idata .posterior .coords ["time" ].values
670- forecast_idata = ss_mod .forecast (
671- idata , start = time_idx [- 1 ], periods = 10 , filter_output = filter_output , random_seed = rng
735+ @pytest .mark .parametrize (
736+ "mod_name, idata_name, start, end, periods" ,
737+ [
738+ ("ss_mod_no_exog" , "idata_no_exog" , None , None , 10 ),
739+ ("ss_mod_no_exog" , "idata_no_exog" , - 1 , None , 10 ),
740+ ("ss_mod_no_exog" , "idata_no_exog" , 10 , None , 10 ),
741+ ("ss_mod_no_exog" , "idata_no_exog" , 10 , 21 , None ),
742+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , None , None , 10 ),
743+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , - 1 , None , 10 ),
744+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , 10 , None , 10 ),
745+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , 10 , "2020-01-21" , None ),
746+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , "2020-03-01" , "2020-03-11" , None ),
747+ ("ss_mod_no_exog_dt" , "idata_no_exog_dt" , "2020-03-01" , None , 10 ),
748+ ],
749+ ids = [
750+ "range_default" ,
751+ "range_negative" ,
752+ "range_int" ,
753+ "range_end" ,
754+ "datetime_default" ,
755+ "datetime_negative" ,
756+ "datetime_int" ,
757+ "datetime_int_end" ,
758+ "datetime_datetime_end" ,
759+ "datetime_datetime" ,
760+ ],
761+ )
762+ def test_forecast (filter_output , mod_name , idata_name , start , end , periods , rng , request ):
763+ mod = request .getfixturevalue (mod_name )
764+ idata = request .getfixturevalue (idata_name )
765+ time_idx = mod ._get_fit_time_index ()
766+ is_datetime = isinstance (time_idx , pd .DatetimeIndex )
767+
768+ if isinstance (start , str ):
769+ t0 = pd .Timestamp (start )
770+ elif isinstance (start , int ):
771+ t0 = time_idx [start ]
772+ else :
773+ t0 = time_idx [- 1 ]
774+
775+ delta = time_idx .freq if is_datetime else 1
776+
777+ forecast_idata = mod .forecast (
778+ idata , start = start , end = end , periods = periods , filter_output = filter_output , random_seed = rng
672779 )
673780
674- assert forecast_idata .coords ["time" ].values .shape == (10 ,)
781+ forecast_idx = forecast_idata .coords ["time" ].values
782+ forecast_idx = pd .DatetimeIndex (forecast_idx ) if is_datetime else pd .Index (forecast_idx )
783+
784+ assert forecast_idx .shape == (10 ,)
675785 assert forecast_idata .forecast_latent .dims == ("chain" , "draw" , "time" , "state" )
676786 assert forecast_idata .forecast_observed .dims == ("chain" , "draw" , "time" , "observed_state" )
677787
678788 assert not np .any (np .isnan (forecast_idata .forecast_latent .values ))
679789 assert not np .any (np .isnan (forecast_idata .forecast_observed .values ))
680790
791+ assert forecast_idx [0 ] == (t0 + delta )
792+
681793
682794@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
683- def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog ):
795+ @pytest .mark .parametrize ("start" , [None , - 1 , 10 ])
796+ def test_forecast_with_exog_data (rng , exog_ss_mod , idata_exog , start ):
684797 scenario = pd .DataFrame (np .zeros ((10 , 3 )), columns = ["a" , "b" , "c" ])
685798 scenario .iloc [5 , 0 ] = 1e9
686799
687800 forecast_idata = exog_ss_mod .forecast (
688- idata_exog , periods = 10 , random_seed = rng , scenario = scenario
801+ idata_exog , start = start , periods = 10 , random_seed = rng , scenario = scenario
689802 )
690803
691804 components = exog_ss_mod .extract_components_from_idata (forecast_idata )
0 commit comments