1616from pytensor import Variable , graph_replace
1717from pytensor .compile import get_mode
1818
19- from pymc_experimental .statespace .core .representation import PytensorRepresentation
20- from pymc_experimental .statespace .filters import (
21- KalmanSmoother ,
22- SquareRootFilter ,
19+ from pymc_extras .statespace .core .representation import PytensorRepresentation
20+ from pymc_extras .statespace .filters import (
21+ CholeskyFilter ,
22+ SingleTimeseriesFilter ,
2323 StandardFilter ,
24+ SteadyStateFilter ,
2425 UnivariateFilter ,
2526)
26- from pymc_experimental .statespace .filters .distributions import (
27+ from pymc_extras .statespace .filters .distributions import (
2728 LinearGaussianStateSpace ,
28- MvNormalSVD ,
29- SequenceMvNormal ,
29+ LinearGaussianStateSpaceRV ,
3030)
31- from pymc_experimental .statespace .filters .utilities import stabilize
32- from pymc_experimental .statespace .utils .constants import (
33- ALL_STATE_AUX_DIM ,
34- ALL_STATE_DIM ,
35- FILTER_OUTPUT_DIMS ,
36- FILTER_OUTPUT_TYPES ,
31+ from pymc_extras .statespace .filters .utilities import stabilize
32+ from pymc_extras .statespace .utils .constants import (
3733 JITTER_DEFAULT ,
38- MATRIX_DIMS ,
39- MATRIX_NAMES ,
40- OBS_STATE_DIM ,
41- SHOCK_DIM ,
34+ LONG_MATRIX_NAMES ,
35+ MISSING_FILL ,
4236 SHORT_NAME_TO_LONG ,
43- TIME_DIM ,
44- VECTOR_VALUED ,
4537)
46- from pymc_experimental .statespace .utils .data_tools import register_data_with_pymc
38+ from pymc_extras .statespace .utils .data_tools import register_data_with_pymc
4739
4840_log = logging .getLogger ("pymc.experimental.statespace" )
4941
5042floatX = pytensor .config .floatX
5143FILTER_FACTORY = {
5244 "standard" : StandardFilter ,
5345 "univariate" : UnivariateFilter ,
54- "cholesky" : SquareRootFilter ,
46+ "cholesky" : CholeskyFilter ,
47+ "steady_state" : SteadyStateFilter ,
5548}
5649
5750
5851def _validate_filter_arg (filter_arg ):
59- if filter_arg .lower () not in FILTER_OUTPUT_TYPES :
52+ if filter_arg .lower () not in FILTER_FACTORY . keys () :
6053 raise ValueError (
61- f'filter_output should be one of { ", " .join (FILTER_OUTPUT_TYPES )} , received { filter_arg } '
54+ f'filter_output should be one of { ", " .join (FILTER_FACTORY . keys () )} , received { filter_arg } '
6255 )
6356
6457
@@ -752,7 +745,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
752745 matrices = self .unpack_statespace ()
753746
754747 registered_matrices = []
755- for i , (matrix , name ) in enumerate (zip (matrices , MATRIX_NAMES )):
748+ for i , (matrix , name ) in enumerate (zip (matrices , LONG_MATRIX_NAMES )):
756749 time_varying_ndim = 2 if name in VECTOR_VALUED else 3
757750 if not getattr (pm_mod , name , None ):
758751 shape , dims = self ._get_matrix_shape_and_dims (name )
@@ -1473,7 +1466,7 @@ def sample_statespace_matrices(
14731466 _verify_group (group )
14741467
14751468 if matrix_names is None :
1476- matrix_names = MATRIX_NAMES
1469+ matrix_names = LONG_MATRIX_NAMES
14771470 elif isinstance (matrix_names , str ):
14781471 matrix_names = [matrix_names ]
14791472
@@ -1486,7 +1479,7 @@ def sample_statespace_matrices(
14861479
14871480 self ._insert_data_variables ()
14881481 matrices = self .unpack_statespace ()
1489- for short_name , matrix in zip (MATRIX_NAMES , matrices ):
1482+ for short_name , matrix in zip (LONG_MATRIX_NAMES , matrices ):
14901483 long_name = SHORT_NAME_TO_LONG [short_name ]
14911484 if (long_name in matrix_names ) or (short_name in matrix_names ):
14921485 name = long_name if long_name in matrix_names else short_name
@@ -2040,7 +2033,7 @@ def forecast(
20402033 }
20412034
20422035 matrices = graph_replace (matrices , replace = sub_dict , strict = True )
2043- [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
2036+ [setattr (matrix , "name" , name ) for name , matrix in zip (LONG_MATRIX_NAMES [2 :], matrices )]
20442037
20452038 _ = LinearGaussianStateSpace (
20462039 "forecast" ,
0 commit comments