Skip to content

Commit 90aa50e

Browse files
committed
Update statespace imports
1 parent 62fffa4 commit 90aa50e

File tree

9 files changed

+80
-76
lines changed

9 files changed

+80
-76
lines changed

pymc_extras/model/marginal/marginal_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525

2626
from pymc_extras.distributions import DiscreteMarkovChain
2727
from pymc_extras.model.marginal.distributions import (
28-
FiniteDiscreteMarginalRV,
29-
MarginalRV,
30-
MarginalizedRV,
28+
MarginalDiscreteMarkovChainRV,
29+
MarginalFiniteDiscreteRV,
30+
get_domain_of_finite_discrete_rv,
31+
reduce_batch_dependent_logps,
3132
)
3233
from pymc_extras.model.marginal.graph_analysis import (
33-
get_support_axes,
34-
get_support_shape,
35-
get_support_size,
36-
get_support_values,
37-
get_variable_support,
34+
find_conditional_dependent_rvs,
35+
find_conditional_input_rvs,
36+
is_conditional_dependent,
37+
subgraph_batch_dim_connection,
3838
)
3939

4040
ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]

pymc_extras/statespace/core/compile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import pytensor
44
import pytensor.tensor as pt
55

6-
from pymc_experimental.statespace.core import PyMCStateSpace
7-
from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace
8-
from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG
6+
from pymc_extras.statespace.core import PyMCStateSpace
7+
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
8+
from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
99

1010

1111
def compile_statespace(

pymc_extras/statespace/core/representation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import pytensor
55
import pytensor.tensor as pt
66

7-
from pymc_experimental.statespace.utils.constants import (
8-
NEVER_TIME_VARYING,
9-
VECTOR_VALUED,
7+
from pymc_extras.statespace.utils.constants import (
8+
JITTER_DEFAULT,
9+
LONG_MATRIX_NAMES,
10+
MISSING_FILL,
11+
SHORT_NAME_TO_LONG,
1012
)
1113

1214
floatX = pytensor.config.floatX

pymc_extras/statespace/core/statespace.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,49 +16,42 @@
1616
from pytensor import Variable, graph_replace
1717
from pytensor.compile import get_mode
1818

19-
from pymc_experimental.statespace.core.representation import PytensorRepresentation
20-
from pymc_experimental.statespace.filters import (
21-
KalmanSmoother,
22-
SquareRootFilter,
19+
from pymc_extras.statespace.core.representation import PytensorRepresentation
20+
from pymc_extras.statespace.filters import (
21+
CholeskyFilter,
22+
SingleTimeseriesFilter,
2323
StandardFilter,
24+
SteadyStateFilter,
2425
UnivariateFilter,
2526
)
26-
from pymc_experimental.statespace.filters.distributions import (
27+
from pymc_extras.statespace.filters.distributions import (
2728
LinearGaussianStateSpace,
28-
MvNormalSVD,
29-
SequenceMvNormal,
29+
LinearGaussianStateSpaceRV,
3030
)
31-
from pymc_experimental.statespace.filters.utilities import stabilize
32-
from pymc_experimental.statespace.utils.constants import (
33-
ALL_STATE_AUX_DIM,
34-
ALL_STATE_DIM,
35-
FILTER_OUTPUT_DIMS,
36-
FILTER_OUTPUT_TYPES,
31+
from pymc_extras.statespace.filters.utilities import stabilize
32+
from pymc_extras.statespace.utils.constants import (
3733
JITTER_DEFAULT,
38-
MATRIX_DIMS,
39-
MATRIX_NAMES,
40-
OBS_STATE_DIM,
41-
SHOCK_DIM,
34+
LONG_MATRIX_NAMES,
35+
MISSING_FILL,
4236
SHORT_NAME_TO_LONG,
43-
TIME_DIM,
44-
VECTOR_VALUED,
4537
)
46-
from pymc_experimental.statespace.utils.data_tools import register_data_with_pymc
38+
from pymc_extras.statespace.utils.data_tools import register_data_with_pymc
4739

4840
_log = logging.getLogger("pymc.experimental.statespace")
4941

5042
floatX = pytensor.config.floatX
5143
FILTER_FACTORY = {
5244
"standard": StandardFilter,
5345
"univariate": UnivariateFilter,
54-
"cholesky": SquareRootFilter,
46+
"cholesky": CholeskyFilter,
47+
"steady_state": SteadyStateFilter,
5548
}
5649

5750

5851
def _validate_filter_arg(filter_arg):
59-
if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
52+
if filter_arg.lower() not in FILTER_FACTORY.keys():
6053
raise ValueError(
61-
f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}'
54+
f'filter_output should be one of {", ".join(FILTER_FACTORY.keys())}, received {filter_arg}'
6255
)
6356

6457

@@ -752,7 +745,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
752745
matrices = self.unpack_statespace()
753746

754747
registered_matrices = []
755-
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
748+
for i, (matrix, name) in enumerate(zip(matrices, LONG_MATRIX_NAMES)):
756749
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
757750
if not getattr(pm_mod, name, None):
758751
shape, dims = self._get_matrix_shape_and_dims(name)
@@ -1473,7 +1466,7 @@ def sample_statespace_matrices(
14731466
_verify_group(group)
14741467

14751468
if matrix_names is None:
1476-
matrix_names = MATRIX_NAMES
1469+
matrix_names = LONG_MATRIX_NAMES
14771470
elif isinstance(matrix_names, str):
14781471
matrix_names = [matrix_names]
14791472

@@ -1486,7 +1479,7 @@ def sample_statespace_matrices(
14861479

14871480
self._insert_data_variables()
14881481
matrices = self.unpack_statespace()
1489-
for short_name, matrix in zip(MATRIX_NAMES, matrices):
1482+
for short_name, matrix in zip(LONG_MATRIX_NAMES, matrices):
14901483
long_name = SHORT_NAME_TO_LONG[short_name]
14911484
if (long_name in matrix_names) or (short_name in matrix_names):
14921485
name = long_name if long_name in matrix_names else short_name
@@ -2040,7 +2033,7 @@ def forecast(
20402033
}
20412034

20422035
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
2043-
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
2036+
[setattr(matrix, "name", name) for name, matrix in zip(LONG_MATRIX_NAMES[2:], matrices)]
20442037

20452038
_ = LinearGaussianStateSpace(
20462039
"forecast",

pymc_extras/statespace/models/ETS.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
from pytensor import graph_replace
88
from pytensor.tensor.slinalg import solve_discrete_lyapunov
99

10-
from pymc_experimental.statespace.core.statespace import PyMCStateSpace, floatX
11-
from pymc_experimental.statespace.models.utilities import make_default_coords
12-
from pymc_experimental.statespace.utils.constants import (
10+
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
11+
from pymc_extras.statespace.models.utilities import make_default_coords
12+
from pymc_extras.statespace.utils.constants import (
13+
JITTER_DEFAULT,
14+
LONG_MATRIX_NAMES,
15+
MISSING_FILL,
16+
SHORT_NAME_TO_LONG,
1317
ALL_STATE_AUX_DIM,
1418
ALL_STATE_DIM,
1519
ETS_SEASONAL_DIM,

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,18 @@
66

77
from pytensor.tensor.slinalg import solve_discrete_lyapunov
88

9-
from pymc_experimental.statespace.core.statespace import PyMCStateSpace, floatX
10-
from pymc_experimental.statespace.models.utilities import (
9+
from pymc_extras.statespace.core.statespace import PyMCStateSpace, floatX
10+
from pymc_extras.statespace.models.utilities import (
1111
make_default_coords,
12-
make_harvey_state_names,
13-
make_SARIMA_transition_matrix,
12+
make_seasonal_harmonics,
1413
)
15-
from pymc_experimental.statespace.utils.constants import (
16-
ALL_STATE_AUX_DIM,
17-
ALL_STATE_DIM,
18-
AR_PARAM_DIM,
19-
MA_PARAM_DIM,
20-
OBS_STATE_DIM,
21-
SARIMAX_STATE_STRUCTURES,
22-
SEASONAL_AR_PARAM_DIM,
23-
SEASONAL_MA_PARAM_DIM,
14+
from pymc_extras.statespace.utils.constants import (
15+
JITTER_DEFAULT,
16+
LONG_MATRIX_NAMES,
17+
MISSING_FILL,
18+
SHORT_NAME_TO_LONG,
2419
)
20+
import pymc_extras.statespace as pmss
2521

2622

2723
def _verify_order(p, d, q, P, D, Q, S):
@@ -147,7 +143,7 @@ class BayesianSARIMA(PyMCStateSpace):
147143
148144
.. code:: python
149145
150-
import pymc_experimental.statespace as pmss
146+
import pymc_extras.statespace as pmss
151147
import pymc as pm
152148
153149
ss_mod = pmss.BayesianSARIMA(order=(1, 0, 1), verbose=True)

pymc_extras/statespace/models/VARMAX.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
from pytensor.tensor.slinalg import solve_discrete_lyapunov
99

10-
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
11-
from pymc_experimental.statespace.models.utilities import make_default_coords
12-
from pymc_experimental.statespace.utils.constants import (
10+
from pymc_extras.statespace.core.statespace import PyMCStateSpace
11+
from pymc_extras.statespace.models.utilities import make_default_coords
12+
from pymc_extras.statespace.utils.constants import (
1313
ALL_STATE_AUX_DIM,
1414
ALL_STATE_DIM,
1515
AR_PARAM_DIM,
@@ -18,7 +18,12 @@
1818
OBS_STATE_DIM,
1919
SHOCK_AUX_DIM,
2020
SHOCK_DIM,
21+
JITTER_DEFAULT,
22+
LONG_MATRIX_NAMES,
23+
MISSING_FILL,
24+
SHORT_NAME_TO_LONG,
2125
)
26+
import pymc_extras.statespace as pmss
2227

2328
floatX = pytensor.config.floatX
2429

@@ -110,7 +115,7 @@ class BayesianVARMAX(PyMCStateSpace):
110115
111116
.. code:: python
112117
113-
import pymc_experimental.statespace as pmss
118+
import pymc_extras.statespace as pmss
114119
import pymc as pm
115120
116121
# Create VAR Statespace Model

pymc_extras/statespace/models/structural.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313

1414
from pytensor import Variable
1515

16-
from pymc_experimental.statespace.core import PytensorRepresentation
17-
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
18-
from pymc_experimental.statespace.models.utilities import (
19-
conform_time_varying_and_time_invariant_matrices,
16+
from pymc_extras.statespace.core import PytensorRepresentation
17+
from pymc_extras.statespace.core.statespace import PyMCStateSpace
18+
from pymc_extras.statespace.models.utilities import (
2019
make_default_coords,
20+
make_seasonal_harmonics,
2121
)
22-
from pymc_experimental.statespace.utils.constants import (
23-
ALL_STATE_AUX_DIM,
24-
ALL_STATE_DIM,
25-
AR_PARAM_DIM,
22+
from pymc_extras.statespace.utils.constants import (
23+
JITTER_DEFAULT,
2624
LONG_MATRIX_NAMES,
27-
POSITION_DERIVATIVE_NAMES,
28-
TIME_DIM,
25+
MISSING_FILL,
26+
SHORT_NAME_TO_LONG,
2927
)
28+
from pymc_extras.statespace import structural as st
3029

3130
_log = logging.getLogger("pymc.experimental.statespace")
3231

@@ -1481,9 +1480,11 @@ def __init__(
14811480
k_endog=k_endog,
14821481
k_states=k_states,
14831482
k_posdef=k_posdef,
1483+
state_names=self.state_names,
14841484
measurement_error=False,
14851485
combine_hidden_states=True,
1486-
obs_state_idxs=obs_state_idx,
1486+
exog_names=[f"data_{name}"],
1487+
obs_state_idxs=np.ones(k_states),
14871488
)
14881489

14891490
def make_symbolic_graph(self) -> None:

pymc_extras/statespace/models/utilities.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytensor.tensor as pt
33

4-
from pymc_experimental.statespace.utils.constants import (
4+
from pymc_extras.statespace.utils.constants import (
55
ALL_STATE_AUX_DIM,
66
ALL_STATE_DIM,
77
LONG_MATRIX_NAMES,
@@ -11,6 +11,9 @@
1111
SHOCK_AUX_DIM,
1212
SHOCK_DIM,
1313
VECTOR_VALUED,
14+
JITTER_DEFAULT,
15+
MISSING_FILL,
16+
SHORT_NAME_TO_LONG,
1417
)
1518

1619

@@ -233,8 +236,8 @@ def make_SARIMA_transition_matrix(
233236
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}
234237
235238
When ARIMA differences and seasonal differences are mixed, the seasonal differences will be written in terms of the
236-
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA differences,
237-
as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
239+
highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA
240+
differences, as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA
238241
differences from :math:`x_t^\star`. Here is the differencing block for a SARIMA(0,2,0)x(0,2,0,4) -- the identites
239242
of the states is left an exercise for the motivated reader:
240243

0 commit comments

Comments
 (0)