3434 FILTER_OUTPUT_DIMS ,
3535 FILTER_OUTPUT_TYPES ,
3636 JITTER_DEFAULT ,
37- LONG_MATRIX_NAMES ,
3837 MATRIX_DIMS ,
38+ MATRIX_NAMES ,
3939 OBS_STATE_DIM ,
4040 SHOCK_DIM ,
4141 SHORT_NAME_TO_LONG ,
@@ -750,7 +750,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
750750 matrices = self .unpack_statespace ()
751751
752752 registered_matrices = []
753- for i , (matrix , name ) in enumerate (zip (matrices , LONG_MATRIX_NAMES )):
753+ for i , (matrix , name ) in enumerate (zip (matrices , MATRIX_NAMES )):
754754 time_varying_ndim = 2 if name in VECTOR_VALUED else 3
755755 if not getattr (pm_mod , name , None ):
756756 shape , dims = self ._get_matrix_shape_and_dims (name )
@@ -1471,7 +1471,7 @@ def sample_statespace_matrices(
14711471 _verify_group (group )
14721472
14731473 if matrix_names is None :
1474- matrix_names = LONG_MATRIX_NAMES
1474+ matrix_names = MATRIX_NAMES
14751475 elif isinstance (matrix_names , str ):
14761476 matrix_names = [matrix_names ]
14771477
@@ -1484,7 +1484,7 @@ def sample_statespace_matrices(
14841484
14851485 self ._insert_data_variables ()
14861486 matrices = self .unpack_statespace ()
1487- for short_name , matrix in zip (LONG_MATRIX_NAMES , matrices ):
1487+ for short_name , matrix in zip (MATRIX_NAMES , matrices ):
14881488 long_name = SHORT_NAME_TO_LONG [short_name ]
14891489 if (long_name in matrix_names ) or (short_name in matrix_names ):
14901490 name = long_name if long_name in matrix_names else short_name
@@ -2038,10 +2038,7 @@ def forecast(
20382038 }
20392039
20402040 matrices = graph_replace (matrices , replace = sub_dict , strict = True )
2041- [
2042- setattr (matrix , "name" , name )
2043- for name , matrix in zip (LONG_MATRIX_NAMES [2 :], matrices )
2044- ]
2041+ [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
20452042
20462043 _ = LinearGaussianStateSpace (
20472044 "forecast" ,
0 commit comments