@@ -154,6 +154,23 @@ def exog_data(rng):
154
154
return df .set_index ("date" )
155
155
156
156
157
+ @pytest .fixture (scope = "session" )
158
+ def exog_data_mv (rng ):
159
+ # simulate data
160
+ df = pd .DataFrame (
161
+ {
162
+ "date" : pd .date_range (start = "2023-05-01" , end = "2023-05-10" , freq = "D" ),
163
+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
164
+ "y1" : rng .normal (size = (10 ,)),
165
+ "y2" : rng .normal (size = (10 ,)),
166
+ }
167
+ )
168
+
169
+ df .loc [[1 , 3 , 9 ], ["y1" ]] = np .nan
170
+ df .loc [[3 , 5 , 7 ], ["y2" ]] = np .nan
171
+ return df .set_index ("date" )
172
+
173
+
157
174
@pytest .fixture (scope = "session" )
158
175
def exog_ss_mod (exog_data ):
159
176
level_trend = st .LevelTrendComponent (name = "trend" , order = 1 , innovations_order = [0 ])
@@ -168,6 +185,23 @@ def exog_ss_mod(exog_data):
168
185
return combined_model .build ()
169
186
170
187
188
+ @pytest .fixture (scope = "session" )
189
+ def exog_ss_mod_mv (exog_data_mv ):
190
+ level_trend = st .LevelTrendComponent (
191
+ name = "trend" , order = 1 , innovations_order = [0 ], observed_state_names = ["y1" , "y2" ]
192
+ )
193
+ exog = st .RegressionComponent (
194
+ name = "exog" , # Name of this exogenous variable component
195
+ k_exog = 1 , # Only one exogenous variable now
196
+ innovations = False , # Typically fixed effect (no stochastic evolution)
197
+ state_names = exog_data_mv [["x1" ]].columns .tolist (),
198
+ observed_state_names = ["y1" , "y2" ],
199
+ )
200
+
201
+ combined_model = level_trend + exog
202
+ return combined_model .build ()
203
+
204
+
171
205
@pytest .fixture (scope = "session" )
172
206
def ss_mod_multi_component (rng ):
173
207
ll = st .LevelTrendComponent (
@@ -208,6 +242,29 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
208
242
return struct_model
209
243
210
244
245
+ @pytest .fixture (scope = "session" )
246
+ def exog_pymc_mod_mv (exog_ss_mod_mv , exog_data_mv ):
247
+ # define pymc model
248
+ with pm .Model (coords = exog_ss_mod_mv .coords ) as struct_model :
249
+ P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 4 , dims = ["state" ])
250
+ P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = ["state" , "state_aux" ])
251
+
252
+ initial_trend = pm .Normal (
253
+ "initial_trend" , mu = [0 ], sigma = [0.005 ], dims = ["endog_trend" , "state_trend" ]
254
+ )
255
+
256
+ data_exog = pm .Data (
257
+ "data_exog" , exog_data_mv ["x1" ].values [:, None ], dims = ["time" , "state_exog" ]
258
+ )
259
+ beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["endog_exog" , "state_exog" ])
260
+
261
+ exog_ss_mod_mv .build_statespace_graph (
262
+ exog_data_mv [["y1" , "y2" ]], save_kalman_filter_outputs_in_idata = True
263
+ )
264
+
265
+ return struct_model
266
+
267
+
211
268
@pytest .fixture (scope = "session" )
212
269
def pymc_mod_no_exog (ss_mod_no_exog , rng ):
213
270
y = pd .DataFrame (rng .normal (size = (100 , 1 )).astype (floatX ), columns = ["y" ])
@@ -299,6 +356,15 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
299
356
return idata
300
357
301
358
359
+ @pytest .fixture (scope = "session" )
360
+ def idata_exog_mv (exog_pymc_mod_mv , rng , mock_pymc_sample ):
361
+ with exog_pymc_mod_mv :
362
+ idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
363
+ idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
364
+ idata .extend (idata_prior )
365
+ return idata
366
+
367
+
302
368
@pytest .fixture (scope = "session" )
303
369
def idata_no_exog (pymc_mod_no_exog , rng , mock_pymc_sample ):
304
370
with pymc_mod_no_exog :
@@ -1002,6 +1068,51 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
1002
1068
assert_allclose (regression_effect , regression_effect_expected )
1003
1069
1004
1070
1071
+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
1072
+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
1073
+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
1074
+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
1075
+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
1076
+ @pytest .mark .parametrize ("start" , [None , - 1 , 5 ])
1077
+ def test_forecast_with_exog_data_mv (rng , exog_ss_mod_mv , idata_exog_mv , start ):
1078
+ scenario = pd .DataFrame (np .zeros ((10 , 1 )), columns = ["x1" ])
1079
+ scenario .iloc [5 , 0 ] = 1e9
1080
+
1081
+ forecast_idata = exog_ss_mod_mv .forecast (
1082
+ idata_exog_mv , start = start , periods = 10 , random_seed = rng , scenario = scenario
1083
+ )
1084
+
1085
+ components = exog_ss_mod_mv .extract_components_from_idata (forecast_idata )
1086
+ level_y1 = components .forecast_latent .sel (state = "trend[level[y1]]" )
1087
+ level_y2 = components .forecast_latent .sel (state = "trend[level[y2]]" )
1088
+ betas_y1 = components .forecast_latent .sel (state = ["exog[x1[y1]]" ])
1089
+ betas_y2 = components .forecast_latent .sel (state = ["exog[x1[y2]]" ])
1090
+
1091
+ scenario .index .name = "time"
1092
+ scenario_xr_y1 = (
1093
+ scenario .unstack ()
1094
+ .to_xarray ()
1095
+ .rename ({"level_0" : "state" })
1096
+ .assign_coords (state = ["exog[x1[y1]]" ])
1097
+ )
1098
+
1099
+ scenario_xr_y2 = (
1100
+ scenario .unstack ()
1101
+ .to_xarray ()
1102
+ .rename ({"level_0" : "state" })
1103
+ .assign_coords (state = ["exog[x1[y2]]" ])
1104
+ )
1105
+
1106
+ regression_effect_y1 = forecast_idata .forecast_observed .isel (observed_state = 0 ) - level_y1
1107
+ regression_effect_expected_y1 = (betas_y1 * scenario_xr_y1 ).sum (dim = ["state" ])
1108
+
1109
+ regression_effect_y2 = forecast_idata .forecast_observed .isel (observed_state = 1 ) - level_y2
1110
+ regression_effect_expected_y2 = (betas_y2 * scenario_xr_y2 ).sum (dim = ["state" ])
1111
+
1112
+ np .testing .assert_allclose (regression_effect_y1 , regression_effect_expected_y1 )
1113
+ np .testing .assert_allclose (regression_effect_y2 , regression_effect_expected_y2 )
1114
+
1115
+
1005
1116
@pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
1006
1117
@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
1007
1118
@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
0 commit comments