@@ -154,6 +154,23 @@ def exog_data(rng):
154154 return df .set_index ("date" )
155155
156156
157+ @pytest .fixture (scope = "session" )
158+ def exog_data_mv (rng ):
159+ # simulate data
160+ df = pd .DataFrame (
161+ {
162+ "date" : pd .date_range (start = "2023-05-01" , end = "2023-05-10" , freq = "D" ),
163+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
164+ "y1" : rng .normal (size = (10 ,)),
165+ "y2" : rng .normal (size = (10 ,)),
166+ }
167+ )
168+
169+ df .loc [[1 , 3 , 9 ], ["y1" ]] = np .nan
170+ df .loc [[3 , 5 , 7 ], ["y2" ]] = np .nan
171+ return df .set_index ("date" )
172+
173+
157174@pytest .fixture (scope = "session" )
158175def exog_ss_mod (exog_data ):
159176 level_trend = st .LevelTrendComponent (name = "trend" , order = 1 , innovations_order = [0 ])
@@ -168,6 +185,23 @@ def exog_ss_mod(exog_data):
168185 return combined_model .build ()
169186
170187
188+ @pytest .fixture (scope = "session" )
189+ def exog_ss_mod_mv (exog_data_mv ):
190+ level_trend = st .LevelTrendComponent (
191+ name = "trend" , order = 1 , innovations_order = [0 ], observed_state_names = ["y1" , "y2" ]
192+ )
193+ exog = st .RegressionComponent (
194+ name = "exog" , # Name of this exogenous variable component
195+ k_exog = 1 , # Only one exogenous variable now
196+ innovations = False , # Typically fixed effect (no stochastic evolution)
197+ state_names = exog_data_mv [["x1" ]].columns .tolist (),
198+ observed_state_names = ["y1" , "y2" ],
199+ )
200+
201+ combined_model = level_trend + exog
202+ return combined_model .build ()
203+
204+
171205@pytest .fixture (scope = "session" )
172206def ss_mod_multi_component (rng ):
173207 ll = st .LevelTrendComponent (
@@ -208,6 +242,29 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
208242 return struct_model
209243
210244
245+ @pytest .fixture (scope = "session" )
246+ def exog_pymc_mod_mv (exog_ss_mod_mv , exog_data_mv ):
247+ # define pymc model
248+ with pm .Model (coords = exog_ss_mod_mv .coords ) as struct_model :
249+ P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 4 , dims = ["state" ])
250+ P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = ["state" , "state_aux" ])
251+
252+ initial_trend = pm .Normal (
253+ "initial_trend" , mu = [0 ], sigma = [0.005 ], dims = ["endog_trend" , "state_trend" ]
254+ )
255+
256+ data_exog = pm .Data (
257+ "data_exog" , exog_data_mv ["x1" ].values [:, None ], dims = ["time" , "state_exog" ]
258+ )
259+ beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["endog_exog" , "state_exog" ])
260+
261+ exog_ss_mod_mv .build_statespace_graph (
262+ exog_data_mv [["y1" , "y2" ]], save_kalman_filter_outputs_in_idata = True
263+ )
264+
265+ return struct_model
266+
267+
211268@pytest .fixture (scope = "session" )
212269def pymc_mod_no_exog (ss_mod_no_exog , rng ):
213270 y = pd .DataFrame (rng .normal (size = (100 , 1 )).astype (floatX ), columns = ["y" ])
@@ -299,6 +356,15 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
299356 return idata
300357
301358
359+ @pytest .fixture (scope = "session" )
360+ def idata_exog_mv (exog_pymc_mod_mv , rng , mock_pymc_sample ):
361+ with exog_pymc_mod_mv :
362+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
363+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
364+ idata .extend (idata_prior )
365+ return idata
366+
367+
302368@pytest .fixture (scope = "session" )
303369def idata_no_exog (pymc_mod_no_exog , rng , mock_pymc_sample ):
304370 with pymc_mod_no_exog :
@@ -1002,6 +1068,51 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
10021068 assert_allclose (regression_effect , regression_effect_expected )
10031069
10041070
1071+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
1072+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
1073+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
1074+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
1075+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
1076+ @pytest .mark .parametrize ("start" , [None , - 1 , 5 ])
1077+ def test_forecast_with_exog_data_mv (rng , exog_ss_mod_mv , idata_exog_mv , start ):
1078+ scenario = pd .DataFrame (np .zeros ((10 , 1 )), columns = ["x1" ])
1079+ scenario .iloc [5 , 0 ] = 1e9
1080+
1081+ forecast_idata = exog_ss_mod_mv .forecast (
1082+ idata_exog_mv , start = start , periods = 10 , random_seed = rng , scenario = scenario
1083+ )
1084+
1085+ components = exog_ss_mod_mv .extract_components_from_idata (forecast_idata )
1086+ level_y1 = components .forecast_latent .sel (state = "trend[level[y1]]" )
1087+ level_y2 = components .forecast_latent .sel (state = "trend[level[y2]]" )
1088+ betas_y1 = components .forecast_latent .sel (state = ["exog[x1[y1]]" ])
1089+ betas_y2 = components .forecast_latent .sel (state = ["exog[x1[y2]]" ])
1090+
1091+ scenario .index .name = "time"
1092+ scenario_xr_y1 = (
1093+ scenario .unstack ()
1094+ .to_xarray ()
1095+ .rename ({"level_0" : "state" })
1096+ .assign_coords (state = ["exog[x1[y1]]" ])
1097+ )
1098+
1099+ scenario_xr_y2 = (
1100+ scenario .unstack ()
1101+ .to_xarray ()
1102+ .rename ({"level_0" : "state" })
1103+ .assign_coords (state = ["exog[x1[y2]]" ])
1104+ )
1105+
1106+ regression_effect_y1 = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level_y1
1107+ regression_effect_expected_y1 = (betas_y1 * scenario_xr_y1 ).sum (dim = ["state" ])
1108+
1109+ regression_effect_y2 = forecast_idata .forecast_observed .isel (observed_state = 1 ) - level_y2
1110+ regression_effect_expected_y2 = (betas_y2 * scenario_xr_y2 ).sum (dim = ["state" ])
1111+
1112+ np .testing .assert_allclose (regression_effect_y1 , regression_effect_expected_y1 )
1113+ np .testing .assert_allclose (regression_effect_y2 , regression_effect_expected_y2 )
1114+
1115+
10051116@pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
10061117@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
10071118@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
0 commit comments