Skip to content

Commit a898eb6

Browse files
committed
Allow multiple observed in Cycle component
1 parent 480f4fb commit a898eb6

File tree

3 files changed

+211
-22
lines changed

3 files changed

+211
-22
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 120 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pytensor import tensor as pt
4+
from scipy import linalg
45

56
from pymc_extras.statespace.models.structural.core import Component
67
from pymc_extras.statespace.models.structural.utils import _frequency_transition_block
@@ -10,6 +11,10 @@ class CycleComponent(Component):
1011
r"""
1112
A component for modeling longer-term cyclical effects
1213
14+
Supports both univariate and multivariate time series. For multivariate time series,
15+
each endogenous variable gets its own independent cycle component with separate
16+
cosine/sine states and optional variable-specific innovation variances.
17+
1318
Parameters
1419
----------
1520
name: str
@@ -32,6 +37,11 @@ class CycleComponent(Component):
3237
innovations: bool, default True
3338
Whether to include stochastic innovations in the strength of the seasonal effect. If True, an additional
3439
parameter, ``sigma_{name}`` will be added to the model.
40+
For multivariate time series, this is a vector (variable-specific innovation variances).
41+
42+
observed_state_names: list[str], optional
43+
Names of the observed state variables. For univariate time series, defaults to ``["data"]``.
44+
For multivariate time series, specify a list of names for each endogenous variable.
3545
3646
Notes
3747
-----
@@ -51,8 +61,16 @@ class CycleComponent(Component):
5161
5262
Unlike a FrequencySeasonality component, the length of a CycleComponent can be estimated.
5363
64+
**Multivariate Support:**
65+
For multivariate time series with k endogenous variables, the component creates:
66+
- 2k states (cosine and sine components for each variable)
67+
- Block diagonal transition and selection matrices
68+
- Variable-specific innovation variances (optional)
69+
- Proper parameter shapes: (k, 2) for initial states, (k,) for innovation variances
70+
5471
Examples
5572
--------
73+
**Univariate Example:**
5674
Estimate a business cycle with length between 6 and 12 years:
5775
5876
.. code:: python
@@ -84,6 +102,35 @@ class CycleComponent(Component):
84102
85103
idata = pm.sample(nuts_sampler='numpyro')
86104
105+
**Multivariate Example:**
106+
Model cycles for multiple economic indicators with variable-specific innovation variances:
107+
108+
.. code:: python
109+
110+
# Multivariate cycle component
111+
cycle = st.CycleComponent(
112+
name='business_cycle',
113+
cycle_length=12,
114+
estimate_cycle_length=False,
115+
innovations=True,
116+
dampen=True,
117+
observed_state_names=['gdp', 'unemployment', 'inflation']
118+
)
119+
120+
# Build the model
121+
ss_mod = cycle.build()
122+
123+
# In PyMC model:
124+
with pm.Model(coords=ss_mod.coords) as model:
125+
# Initial states: shape (3, 2) for 3 variables, 2 states each
126+
cycle_init = pm.Normal('business_cycle', dims=('business_cycle_endog', 'business_cycle_state'))
127+
128+
# Dampening factor: scalar (shared across variables)
129+
dampening = pm.Uniform('business_cycle_dampening_factor', lower=0.8, upper=1.0)
130+
131+
# Innovation variances: shape (3,) for variable-specific variances
132+
sigma_cycle = pm.HalfNormal('sigma_business_cycle', dims=('business_cycle_endog',))
133+
87134
References
88135
----------
89136
.. [1] Durbin, James, and Siem Jan Koopman. 2012.
@@ -137,14 +184,23 @@ def __init__(
137184
)
138185

139186
def make_symbolic_graph(self) -> None:
140-
self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
141-
self.ssm["selection", :, :] = np.eye(self.k_states)
142-
self.param_dims = {self.name: (f"{self.name}_state",)}
143-
self.coords = {f"{self.name}_state": self.state_names}
187+
if self.k_endog == 1:
188+
self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
189+
self.ssm["selection", :, :] = np.eye(self.k_states)
190+
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,))
191+
192+
else:
193+
Z = np.array([1.0, 0.0]).reshape((1, -1))
194+
design_matrix = linalg.block_diag(*[Z for _ in range(self.k_endog)])
195+
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
144196

145-
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,))
197+
R = np.eye(2) # 2x2 identity for each cycle component
198+
selection_matrix = linalg.block_diag(*[R for _ in range(self.k_endog)])
199+
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
146200

147-
self.ssm["initial_state", :] = init_state
201+
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_endog, 2))
202+
203+
self.ssm["initial_state", :] = init_state.ravel()
148204

149205
if self.estimate_cycle_length:
150206
lamb = self.make_and_register_variable(f"{self.name}_length", shape=())
@@ -157,23 +213,59 @@ def make_symbolic_graph(self) -> None:
157213
rho = 1
158214

159215
T = rho * _frequency_transition_block(lamb, j=1)
160-
self.ssm["transition", :, :] = T
216+
if self.k_endog == 1:
217+
self.ssm["transition", :, :] = T
218+
else:
219+
# can't make the linalg.block_diag logic work here
220+
# doing it manually for now
221+
for i in range(self.k_endog):
222+
start_idx = i * 2
223+
end_idx = (i + 1) * 2
224+
self.ssm["transition", start_idx:end_idx, start_idx:end_idx] = T
161225

162226
if self.innovations:
163-
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
164-
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
227+
if self.k_endog == 1:
228+
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
229+
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
230+
else:
231+
sigma_cycle = self.make_and_register_variable(
232+
f"sigma_{self.name}", shape=(self.k_endog,)
233+
)
234+
# can't make the linalg.block_diag logic work here
235+
# doing it manually for now
236+
for i in range(self.k_endog):
237+
start_idx = i * 2
238+
end_idx = (i + 1) * 2
239+
Q_block = pt.eye(2) * sigma_cycle[i] ** 2
240+
self.ssm["state_cov", start_idx:end_idx, start_idx:end_idx] = Q_block
165241

166242
def populate_component_properties(self):
167243
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
168244
self.param_names = [f"{self.name}"]
169245

170-
self.param_info = {
171-
f"{self.name}": {
172-
"shape": (2,),
173-
"constraints": None,
174-
"dims": (f"{self.name}_state",),
246+
if self.k_endog == 1:
247+
self.param_dims = {self.name: (f"{self.name}_state",)}
248+
self.coords = {f"{self.name}_state": self.state_names}
249+
self.param_info = {
250+
f"{self.name}": {
251+
"shape": (2,),
252+
"constraints": None,
253+
"dims": (f"{self.name}_state",),
254+
}
255+
}
256+
else:
257+
self.param_dims = {self.name: (f"{self.name}_endog", f"{self.name}_state")}
258+
self.coords = {
259+
f"{self.name}_state": self.state_names,
260+
f"{self.name}_endog": self.observed_state_names,
261+
}
262+
self.param_info = {
263+
f"{self.name}": {
264+
"shape": (self.k_endog, 2),
265+
"constraints": None,
266+
"dims": (f"{self.name}_endog", f"{self.name}_state"),
267+
}
175268
}
176-
}
177269

178270
if self.estimate_cycle_length:
179271
self.param_names += [f"{self.name}_length"]
@@ -193,9 +285,17 @@ def populate_component_properties(self):
193285

194286
if self.innovations:
195287
self.param_names += [f"sigma_{self.name}"]
196-
self.param_info[f"sigma_{self.name}"] = {
197-
"shape": (),
198-
"constraints": "Positive",
199-
"dims": None,
200-
}
288+
if self.k_endog == 1:
289+
self.param_info[f"sigma_{self.name}"] = {
290+
"shape": (),
291+
"constraints": "Positive",
292+
"dims": None,
293+
}
294+
else:
295+
self.param_dims[f"sigma_{self.name}"] = (f"{self.name}_endog",)
296+
self.param_info[f"sigma_{self.name}"] = {
297+
"shape": (self.k_endog,),
298+
"constraints": "Positive",
299+
"dims": (f"{self.name}_endog",),
300+
}
201301
self.shock_names = self.state_names.copy()

tests/statespace/models/structural/components/test_cycle.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,96 @@ def test_cycle_component_with_innovations_and_cycle_length(rng):
4545
"cycle_dampening_factor": 0.95,
4646
"sigma_cycle": 1.0,
4747
}
48+
x, y = simulate_from_numpy_model(cycle, rng, params)
49+
50+
cycle.build(verbose=False)
51+
_assert_basic_coords_correct(cycle)
52+
53+
54+
def test_cycle_multivariate_deterministic(rng):
55+
"""Test multivariate cycle component with deterministic cycles."""
56+
cycle = st.CycleComponent(
57+
name="cycle",
58+
cycle_length=12,
59+
estimate_cycle_length=False,
60+
innovations=False,
61+
observed_state_names=["data_1", "data_2", "data_3"],
62+
)
63+
params = {"cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX)}
64+
x, y = simulate_from_numpy_model(cycle, rng, params, steps=12 * 12)
65+
66+
# Check that each variable has a cyclical pattern with the expected period
67+
for i in range(3):
68+
assert_pattern_repeats(y[:, i], 12, atol=ATOL, rtol=RTOL)
69+
70+
# Check that the cycles have different amplitudes (different initial states)
71+
assert np.std(y[:, 0]) > 0
72+
assert np.std(y[:, 1]) > 0
73+
assert np.std(y[:, 2]) > 0
74+
# The second and third variables should have larger amplitudes due to larger initial states
75+
assert np.std(y[:, 1]) > np.std(y[:, 0])
76+
assert np.std(y[:, 2]) > np.std(y[:, 0])
77+
4878

79+
def test_cycle_multivariate_with_dampening(rng):
80+
"""Test multivariate cycle component with dampening."""
81+
cycle = st.CycleComponent(
82+
name="cycle",
83+
cycle_length=12,
84+
estimate_cycle_length=False,
85+
innovations=False,
86+
dampen=True,
87+
observed_state_names=["data_1", "data_2", "data_3"],
88+
)
89+
params = {
90+
"cycle": np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]], dtype=config.floatX),
91+
"cycle_dampening_factor": 0.75,
92+
}
93+
x, y = simulate_from_numpy_model(cycle, rng, params, steps=100)
94+
95+
# Check that all cycles dampen to zero over time
96+
for i in range(3):
97+
assert_allclose(y[-1, i], 0.0, atol=ATOL, rtol=RTOL)
98+
99+
# Check that the dampening pattern is consistent across variables
100+
# The variables should dampen at the same rate but with different initial amplitudes
101+
for i in range(1, 3):
102+
# The ratio of final to initial values should be similar across variables
103+
ratio_0 = abs(y[-1, 0] / y[0, 0]) if y[0, 0] != 0 else 0
104+
ratio_i = abs(y[-1, i] / y[0, i]) if y[0, i] != 0 else 0
105+
assert_allclose(ratio_0, ratio_i, atol=1e-2, rtol=1e-2)
106+
107+
108+
def test_cycle_multivariate_with_innovations_and_cycle_length(rng):
109+
"""Test multivariate cycle component with innovations and estimated cycle length."""
110+
cycle = st.CycleComponent(
111+
name="cycle",
112+
estimate_cycle_length=True,
113+
innovations=True,
114+
dampen=True,
115+
observed_state_names=["data_1", "data_2", "data_3"],
116+
)
117+
params = {
118+
"cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX),
119+
"cycle_length": 12.0,
120+
"cycle_dampening_factor": 0.95,
121+
"sigma_cycle": np.array([0.5, 1.0, 1.5]), # Different innovation variances per variable
122+
}
49123
x, y = simulate_from_numpy_model(cycle, rng, params)
50124

51125
cycle.build(verbose=False)
52126
_assert_basic_coords_correct(cycle)
127+
128+
assert cycle.coords["cycle_state"] == ["cycle_Cos", "cycle_Sin"]
129+
assert cycle.coords["cycle_endog"] == ["data_1", "data_2", "data_3"]
130+
131+
assert cycle.k_endog == 3
132+
assert cycle.k_states == 6 # 2 states per variable
133+
assert cycle.k_posdef == 6 # 2 innovations per variable
134+
135+
# Check that the data has the expected shape
136+
assert y.shape[1] == 3 # 3 variables
137+
138+
# Check that each variable shows some variation (due to innovations)
139+
for i in range(3):
140+
assert np.std(y[:, i]) > 0

tests/statespace/models/structural/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ def _assert_basic_coords_correct(mod):
2323
assert mod.coords[ALL_STATE_AUX_DIM] == mod.state_names
2424
assert mod.coords[SHOCK_DIM] == mod.shock_names
2525
assert mod.coords[SHOCK_AUX_DIM] == mod.shock_names
26-
assert mod.coords[OBS_STATE_DIM] == ["data"]
27-
assert mod.coords[OBS_STATE_AUX_DIM] == ["data"]
26+
expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ["data"]
27+
assert mod.coords[OBS_STATE_DIM] == expected_obs
28+
assert mod.coords[OBS_STATE_AUX_DIM] == expected_obs

0 commit comments

Comments
 (0)