@@ -67,6 +67,11 @@ def _assert_coord_shapes_match_matrices(mod, params):
67
67
n_shocks = max (1 , len (mod .coords [SHOCK_DIM ]))
68
68
n_obs = len (mod .coords [OBS_STATE_DIM ])
69
69
70
+ print (f"{ mod .coords [ALL_STATE_DIM ] = } " )
71
+ print (f"{ mod .coords [SHOCK_DIM ] = } " )
72
+ print (f"{ mod .coords [OBS_STATE_DIM ] = } " )
73
+ print (f"{ R = } " )
74
+
70
75
assert x0 .shape [- 1 :] == (
71
76
n_states ,
72
77
), f"x0 expected to have shape (n_states, ), found { x0 .shape [- 1 :]} "
@@ -104,12 +109,12 @@ def _assert_keys_match(test_dict, expected_dict):
104
109
expected_keys = list (expected_dict .keys ())
105
110
param_keys = list (test_dict .keys ())
106
111
key_diff = set (expected_keys ) - set (param_keys )
107
- assert len (key_diff ) == 0 , f' { ", " .join (key_diff )} were not found in the test_dict keys.'
112
+ assert len (key_diff ) == 0 , f" { ', ' .join (key_diff )} were not found in the test_dict keys."
108
113
109
114
key_diff = set (param_keys ) - set (expected_keys )
110
115
assert (
111
116
len (key_diff ) == 0
112
- ), f' { ", " .join (key_diff )} were keys of the tests_dict not in expected_dict.'
117
+ ), f" { ', ' .join (key_diff )} were keys of the tests_dict not in expected_dict."
113
118
114
119
115
120
def _assert_param_dims_correct (param_dims , expected_dims ):
@@ -296,8 +301,8 @@ def create_structural_model_and_equivalent_statsmodel(
296
301
if seasonal is not None :
297
302
state_names = [f"seasonal_{ i } " for i in range (seasonal )][1 :]
298
303
seasonal_coefs = rng .normal (size = (seasonal - 1 ,)).astype (floatX )
299
- params ["coefs_seasonal " ] = seasonal_coefs
300
- expected_param_dims ["coefs_seasonal " ] += ("state_seasonal" ,)
304
+ params ["params_seasonal " ] = seasonal_coefs
305
+ expected_param_dims ["params_seasonal " ] += ("state_seasonal" ,)
301
306
302
307
expected_coords ["state_seasonal" ] += tuple (state_names )
303
308
expected_coords [ALL_STATE_DIM ] += state_names
@@ -335,8 +340,8 @@ def create_structural_model_and_equivalent_statsmodel(
335
340
336
341
seasonal_params = rng .normal (size = n_states ).astype (floatX )
337
342
338
- params [f"seasonal_ { s } " ] = seasonal_params
339
- expected_param_dims [f"seasonal_ { s } " ] += (f"state_seasonal_{ s } " ,)
343
+ params [f"params_seasonal_ { s } " ] = seasonal_params
344
+ expected_param_dims [f"params_seasonal_ { s } " ] += (f"state_seasonal_{ s } " ,)
340
345
expected_coords [ALL_STATE_DIM ] += state_names
341
346
expected_coords [ALL_STATE_AUX_DIM ] += state_names
342
347
expected_coords [f"state_seasonal_{ s } " ] += (
@@ -404,7 +409,7 @@ def create_structural_model_and_equivalent_statsmodel(
404
409
components .append (comp )
405
410
406
411
if autoregressive is not None :
407
- ar_names = [f"L{ i + 1 } " for i in range (autoregressive )]
412
+ ar_names = [f"L{ i + 1 } " for i in range (autoregressive )]
408
413
params_ar = rng .normal (size = (autoregressive ,)).astype (floatX )
409
414
if autoregressive == 1 :
410
415
params_ar = params_ar .item ()
@@ -421,8 +426,8 @@ def create_structural_model_and_equivalent_statsmodel(
421
426
422
427
sm_params ["sigma2.ar" ] = sigma2
423
428
for i , rho in enumerate (params_ar ):
424
- sm_init [f"ar.L{ i + 1 } " ] = 0
425
- sm_params [f"ar.L{ i + 1 } " ] = rho
429
+ sm_init [f"ar.L{ i + 1 } " ] = 0
430
+ sm_params [f"ar.L{ i + 1 } " ] = rho
426
431
427
432
comp = st .AutoregressiveComponent (name = "ar" , order = autoregressive )
428
433
components .append (comp )
@@ -439,7 +444,7 @@ def create_structural_model_and_equivalent_statsmodel(
439
444
440
445
for i , beta in enumerate (betas ):
441
446
sm_params [f"beta.x{ i + 1 } " ] = beta
442
- sm_init [f"beta.x{ i + 1 } " ] = beta
447
+ sm_init [f"beta.x{ i + 1 } " ] = beta
443
448
comp = st .RegressionComponent (name = "exog" , state_names = names )
444
449
components .append (comp )
445
450
0 commit comments