@@ -67,6 +67,11 @@ def _assert_coord_shapes_match_matrices(mod, params):
6767 n_shocks = max (1 , len (mod .coords [SHOCK_DIM ]))
6868 n_obs = len (mod .coords [OBS_STATE_DIM ])
6969
70+ print (f"{ mod .coords [ALL_STATE_DIM ] = } " )
71+ print (f"{ mod .coords [SHOCK_DIM ] = } " )
72+ print (f"{ mod .coords [OBS_STATE_DIM ] = } " )
73+ print (f"{ R = } " )
74+
7075 assert x0 .shape [- 1 :] == (
7176 n_states ,
7277 ), f"x0 expected to have shape (n_states, ), found { x0 .shape [- 1 :]} "
@@ -104,12 +109,12 @@ def _assert_keys_match(test_dict, expected_dict):
104109 expected_keys = list (expected_dict .keys ())
105110 param_keys = list (test_dict .keys ())
106111 key_diff = set (expected_keys ) - set (param_keys )
107- assert len (key_diff ) == 0 , f' { ", " .join (key_diff )} were not found in the test_dict keys.'
112+ assert len (key_diff ) == 0 , f" { ', ' .join (key_diff )} were not found in the test_dict keys."
108113
109114 key_diff = set (param_keys ) - set (expected_keys )
110115 assert (
111116 len (key_diff ) == 0
112- ), f' { ", " .join (key_diff )} were keys of the tests_dict not in expected_dict.'
117+ ), f" { ', ' .join (key_diff )} were keys of the tests_dict not in expected_dict."
113118
114119
115120def _assert_param_dims_correct (param_dims , expected_dims ):
@@ -296,8 +301,8 @@ def create_structural_model_and_equivalent_statsmodel(
296301 if seasonal is not None :
297302 state_names = [f"seasonal_{ i } " for i in range (seasonal )][1 :]
298303 seasonal_coefs = rng .normal (size = (seasonal - 1 ,)).astype (floatX )
299- params ["coefs_seasonal " ] = seasonal_coefs
300- expected_param_dims ["coefs_seasonal " ] += ("state_seasonal" ,)
304+ params ["params_seasonal " ] = seasonal_coefs
305+ expected_param_dims ["params_seasonal " ] += ("state_seasonal" ,)
301306
302307 expected_coords ["state_seasonal" ] += tuple (state_names )
303308 expected_coords [ALL_STATE_DIM ] += state_names
@@ -335,8 +340,8 @@ def create_structural_model_and_equivalent_statsmodel(
335340
336341 seasonal_params = rng .normal (size = n_states ).astype (floatX )
337342
338- params [f"seasonal_ { s } " ] = seasonal_params
339- expected_param_dims [f"seasonal_ { s } " ] += (f"state_seasonal_{ s } " ,)
343+ params [f"params_seasonal_ { s } " ] = seasonal_params
344+ expected_param_dims [f"params_seasonal_ { s } " ] += (f"state_seasonal_{ s } " ,)
340345 expected_coords [ALL_STATE_DIM ] += state_names
341346 expected_coords [ALL_STATE_AUX_DIM ] += state_names
342347 expected_coords [f"state_seasonal_{ s } " ] += (
@@ -404,7 +409,7 @@ def create_structural_model_and_equivalent_statsmodel(
404409 components .append (comp )
405410
406411 if autoregressive is not None :
407- ar_names = [f"L{ i + 1 } " for i in range (autoregressive )]
412+ ar_names = [f"L{ i + 1 } " for i in range (autoregressive )]
408413 params_ar = rng .normal (size = (autoregressive ,)).astype (floatX )
409414 if autoregressive == 1 :
410415 params_ar = params_ar .item ()
@@ -421,8 +426,8 @@ def create_structural_model_and_equivalent_statsmodel(
421426
422427 sm_params ["sigma2.ar" ] = sigma2
423428 for i , rho in enumerate (params_ar ):
424- sm_init [f"ar.L{ i + 1 } " ] = 0
425- sm_params [f"ar.L{ i + 1 } " ] = rho
429+ sm_init [f"ar.L{ i + 1 } " ] = 0
430+ sm_params [f"ar.L{ i + 1 } " ] = rho
426431
427432 comp = st .AutoregressiveComponent (name = "ar" , order = autoregressive )
428433 components .append (comp )
@@ -439,7 +444,7 @@ def create_structural_model_and_equivalent_statsmodel(
439444
440445 for i , beta in enumerate (betas ):
441446 sm_params [f"beta.x{ i + 1 } " ] = beta
442- sm_init [f"beta.x{ i + 1 } " ] = beta
447+ sm_init [f"beta.x{ i + 1 } " ] = beta
443448 comp = st .RegressionComponent (name = "exog" , state_names = names )
444449 components .append (comp )
445450
0 commit comments