99import statsmodels .api as sm
1010
1111from numpy .testing import assert_allclose , assert_array_less
12+ from pymc .model .transform .optimization import freeze_dims_and_data
1213
1314from pymc_extras .statespace import BayesianVARMAX
1415from pymc_extras .statespace .utils .constants import SHORT_NAME_TO_LONG
@@ -203,6 +204,10 @@ def test_create_varmax_with_exogenous(data):
203204 assert mod .k_exog == 2
204205 assert mod .exog_state_names == ["exogenous_0" , "exogenous_1" ]
205206 assert mod .data_names == ["exogenous_data" ]
207+ assert mod .param_dims ["beta_exog" ] == ("observed_state" , "exogenous" )
208+ assert mod .coords ["exogenous" ] == ["exogenous_0" , "exogenous_1" ]
209+ assert mod .param_info ["beta_exog" ]["shape" ] == (mod .k_endog , 2 )
210+ assert mod .param_info ["beta_exog" ]["dims" ] == ("observed_state" , "exogenous" )
206211
207212 # Case 2: exog_state_names as list, k_exog is None
208213 mod = BayesianVARMAX (
@@ -216,6 +221,10 @@ def test_create_varmax_with_exogenous(data):
216221 assert mod .k_exog == 2
217222 assert mod .exog_state_names == ["foo" , "bar" ]
218223 assert mod .data_names == ["exogenous_data" ]
224+ assert mod .param_dims ["beta_exog" ] == ("observed_state" , "exogenous" )
225+ assert mod .coords ["exogenous" ] == ["foo" , "bar" ]
226+ assert mod .param_info ["beta_exog" ]["shape" ] == (mod .k_endog , 2 )
227+ assert mod .param_info ["beta_exog" ]["dims" ] == ("observed_state" , "exogenous" )
219228
220229 # Case 3: k_exog as int, exog_state_names as list (matching)
221230 mod = BayesianVARMAX (
@@ -230,6 +239,10 @@ def test_create_varmax_with_exogenous(data):
230239 assert mod .k_exog == 2
231240 assert mod .exog_state_names == ["a" , "b" ]
232241 assert mod .data_names == ["exogenous_data" ]
242+ assert mod .param_dims ["beta_exog" ] == ("observed_state" , "exogenous" )
243+ assert mod .coords ["exogenous" ] == ["a" , "b" ]
244+ assert mod .param_info ["beta_exog" ]["shape" ] == (mod .k_endog , 2 )
245+ assert mod .param_info ["beta_exog" ]["dims" ] == ("observed_state" , "exogenous" )
233246
234247 # Case 4: k_exog as dict, exog_state_names is None
235248 k_exog = {"observed_0" : 2 , "observed_1" : 1 , "observed_2" : 0 }
@@ -252,6 +265,25 @@ def test_create_varmax_with_exogenous(data):
252265 "observed_1_exogenous_data" ,
253266 "observed_2_exogenous_data" ,
254267 ]
268+ assert mod .param_dims ["beta_observed_0" ] == ("exogenous_observed_0" ,)
269+ assert mod .param_dims ["beta_observed_1" ] == ("exogenous_observed_1" ,)
270+ assert (
271+ "beta_observed_2" not in mod .param_dims
272+ or mod .param_info .get ("beta_observed_2" ) is None
273+ or mod .param_info .get ("beta_observed_2" , {}).get ("shape" , (0 ,))[0 ] == 0
274+ )
275+
276+ assert mod .coords ["exogenous_observed_0" ] == [
277+ "observed_0_exogenous_0" ,
278+ "observed_0_exogenous_1" ,
279+ ]
280+ assert mod .coords ["exogenous_observed_1" ] == ["observed_1_exogenous_0" ]
281+ assert "exogenous_observed_2" in mod .coords and mod .coords ["exogenous_observed_2" ] == []
282+
283+ assert mod .param_info ["beta_observed_0" ]["shape" ] == (2 ,)
284+ assert mod .param_info ["beta_observed_0" ]["dims" ] == ("exogenous_observed_0" ,)
285+ assert mod .param_info ["beta_observed_1" ]["shape" ] == (1 ,)
286+ assert mod .param_info ["beta_observed_1" ]["dims" ] == ("exogenous_observed_1" ,)
255287
256288 # Case 5: exog_state_names as dict, k_exog is None
257289 exog_state_names = {"observed_0" : ["a" , "b" ], "observed_1" : ["c" ], "observed_2" : []}
@@ -270,6 +302,22 @@ def test_create_varmax_with_exogenous(data):
270302 "observed_1_exogenous_data" ,
271303 "observed_2_exogenous_data" ,
272304 ]
305+ assert mod .param_dims ["beta_observed_0" ] == ("exogenous_observed_0" ,)
306+ assert mod .param_dims ["beta_observed_1" ] == ("exogenous_observed_1" ,)
307+ assert (
308+ "beta_observed_2" not in mod .param_dims
309+ or mod .param_info .get ("beta_observed_2" ) is None
310+ or mod .param_info .get ("beta_observed_2" , {}).get ("shape" , (0 ,))[0 ] == 0
311+ )
312+
313+ assert mod .coords ["exogenous_observed_0" ] == ["a" , "b" ]
314+ assert mod .coords ["exogenous_observed_1" ] == ["c" ]
315+ assert "exogenous_observed_2" in mod .coords and mod .coords ["exogenous_observed_2" ] == []
316+
317+ assert mod .param_info ["beta_observed_0" ]["shape" ] == (2 ,)
318+ assert mod .param_info ["beta_observed_0" ]["dims" ] == ("exogenous_observed_0" ,)
319+ assert mod .param_info ["beta_observed_1" ]["shape" ] == (1 ,)
320+ assert mod .param_info ["beta_observed_1" ]["dims" ] == ("exogenous_observed_1" ,)
273321
274322 # Case 6: k_exog as dict, exog_state_names as dict (matching)
275323 k_exog = {"observed_0" : 2 , "observed_1" : 1 }
@@ -286,6 +334,14 @@ def test_create_varmax_with_exogenous(data):
286334 assert mod .k_exog == k_exog
287335 assert mod .exog_state_names == exog_state_names
288336 assert mod .data_names == ["observed_0_exogenous_data" , "observed_1_exogenous_data" ]
337+ assert mod .param_dims ["beta_observed_0" ] == ("exogenous_observed_0" ,)
338+ assert mod .param_dims ["beta_observed_1" ] == ("exogenous_observed_1" ,)
339+ assert mod .coords ["exogenous_observed_0" ] == ["a" , "b" ]
340+ assert mod .coords ["exogenous_observed_1" ] == ["c" ]
341+ assert mod .param_info ["beta_observed_0" ]["shape" ] == (2 ,)
342+ assert mod .param_info ["beta_observed_0" ]["dims" ] == ("exogenous_observed_0" ,)
343+ assert mod .param_info ["beta_observed_1" ]["shape" ] == (1 ,)
344+ assert mod .param_info ["beta_observed_1" ]["dims" ] == ("exogenous_observed_1" ,)
289345
290346 # Error: k_exog as int, exog_state_names as list (length mismatch)
291347 with pytest .raises (
@@ -348,3 +404,108 @@ def test_create_varmax_with_exogenous(data):
348404 measurement_error = False ,
349405 stationary_initialization = False ,
350406 )
407+
408+
@pytest.mark.parametrize(
    "k_exog, exog_state_names",
    [
        (2, None),
        (None, ["foo", "bar"]),
        (None, {"y1": ["a", "b"], "y2": ["c"]}),
    ],
    ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
)
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_varmax_with_exog(rng, k_exog, exog_state_names):
    """Check the exogenous contribution to the observed series of a VARMAX model.

    A three-variable ``BayesianVARMAX`` with ``order=(1, 0)`` is built for each
    way of declaring exogenous regressors (an integer count, a shared list of
    names, or a per-variable dict of names). Ten prior draws are taken, the
    conditional prior is sampled, and the gap between
    ``filtered_prior_observed`` and ``filtered_prior`` is compared against the
    product of the exogenous data and the sampled ``beta`` coefficients.

    Parameters
    ----------
    rng
        Random generator fixture; used for the synthetic data and as the
        sampling seed.
    k_exog
        Number of exogenous regressors, or ``None`` when names are given.
    exog_state_names
        ``None``, a list of regressor names shared by all observed series, or
        a dict mapping an observed series name to its own regressor names.
    """
    endog_names = ["y1", "y2", "y3"]
    n_obs = 50
    time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")

    y = rng.normal(size=(n_obs, len(endog_names)))
    df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)

    if isinstance(exog_state_names, dict):
        # One exogenous data frame per observed series that declares regressors.
        exog_data = {
            f"{name}_exogenous_data": pd.DataFrame(
                rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
                columns=exog_names,
                index=time_idx,
            )
            for name, exog_names in exog_state_names.items()
        }
    else:
        # A single exogenous data frame shared by every observed series. When
        # only a count is given, fall back to the "exogenous_<i>" naming that
        # the model-creation test in this file expects as the default.
        exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
        exog_data = {
            "exogenous_data": pd.DataFrame(
                rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
                columns=exog_names,
                index=time_idx,
            )
        }

    mod = BayesianVARMAX(
        endog_names=endog_names,
        order=(1, 0),
        k_exog=k_exog,
        exog_state_names=exog_state_names,
        verbose=True,
        measurement_error=False,
        stationary_initialization=False,
        mode="JAX",
    )

    with pm.Model(coords=mod.coords) as m:
        for var_name, data in exog_data.items():
            pm.Data(var_name, data, dims=mod.data_info[var_name]["dims"])

        # Variables are registered on the model by name, so the return values
        # only need to be bound where they feed a later expression.
        pm.Deterministic("x0", pt.zeros(mod.k_states), dims=mod.param_dims["x0"])
        P0_diag = pm.Exponential("P0_diag", 1.0, dims=mod.param_dims["P0"][0])
        pm.Deterministic("P0", pt.diag(P0_diag), dims=mod.param_dims["P0"])

        pm.Normal("ar_params", mu=0, sigma=1, dims=mod.param_dims["ar_params"])
        state_cov_diag = pm.Exponential("state_cov_diag", 1.0, dims=mod.param_dims["state_cov"][0])
        pm.Deterministic("state_cov", pt.diag(state_cov_diag), dims=mod.param_dims["state_cov"])

        # Exogenous priors
        if isinstance(mod.exog_state_names, list):
            pm.Normal("beta_exog", mu=0, sigma=1, dims=mod.param_dims["beta_exog"])
        elif isinstance(mod.exog_state_names, dict):
            for name, names in mod.exog_state_names.items():
                # Series with an empty regressor list get no beta prior.
                if names:
                    pm.Normal(f"beta_{name}", mu=0, sigma=1, dims=mod.param_dims[f"beta_{name}"])

        mod.build_statespace_graph(data=df)

    with freeze_dims_and_data(m):
        prior = pm.sample_prior_predictive(
            draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
        )

    prior_cond = mod.sample_conditional_prior(prior, mvn_method="eigh")

    # NOTE(review): presumably this difference is the observation intercept
    # contributed by the exogenous regressors -- confirm against the model.
    beta_dot_data = prior_cond.filtered_prior_observed.values - prior_cond.filtered_prior.values

    if isinstance(exog_state_names, list) or k_exog is not None:
        beta = prior.prior.beta_exog
        assert beta.shape == (1, 10, 3, 2)

        np.testing.assert_allclose(
            beta_dot_data,
            np.einsum("tx,...sx->...ts", exog_data["exogenous_data"].values, beta),
            atol=1e-2,
        )

    elif isinstance(exog_state_names, dict):
        assert prior.prior.beta_y1.shape == (1, 10, 2)
        assert prior.prior.beta_y2.shape == (1, 10, 1)

        obs_intercept = [
            np.einsum("tx,...x->...t", exog_data[f"{name}_exogenous_data"].values, beta)
            for name, beta in zip(["y1", "y2"], [prior.prior.beta_y1, prior.prior.beta_y2])
        ]

        # y3 has no exogenous variables
        obs_intercept.append(np.zeros_like(obs_intercept[0]))

        np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)
0 commit comments