Skip to content

Commit 152e962

Browse files
committed
Use pytensor block_diag for Cycle
1 parent 62d0750 commit 152e962

File tree

1 file changed

+21
-19
lines changed
  • pymc_extras/statespace/models/structural/components

1 file changed

+21
-19
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pytensor import tensor as pt
4+
from pytensor.tensor.slinalg import block_diag
45
from scipy import linalg
56

67
from pymc_extras.statespace.models.structural.core import Component
@@ -96,7 +97,6 @@ class CycleComponent(Component):
9697
9798
cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
9899
cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99-
100100
sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
101101
102102
ss_mod.build_statespace_graph(data)
@@ -124,13 +124,15 @@ class CycleComponent(Component):
124124
with pm.Model(coords=ss_mod.coords) as model:
125125
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
126126
# Initial states: shape (3, 2) for 3 variables, 2 states each
127-
cycle_init = pm.Normal('business_cycle', dims=('business_cycle_endog', 'business_cycle_state'))
127+
cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
128128
129129
# Dampening factor: scalar (shared across variables)
130-
dampening = pm.Uniform('business_cycle_dampening_factor', lower=0.8, upper=1.0)
130+
dampening = pm.Beta("business_cycle_dampening_factor", 2, 2)
131131
132132
# Innovation variances: shape (3,) for variable-specific variances
133-
sigma_cycle = pm.HalfNormal('sigma_business_cycle', dims=('business_cycle_endog',))
133+
sigma_cycle = pm.HalfNormal(
134+
"sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
135+
)
134136
135137
ss_mod.build_statespace_graph(data)
136138
idata = pm.sample()
@@ -220,12 +222,8 @@ def make_symbolic_graph(self) -> None:
220222
if self.k_endog == 1:
221223
self.ssm["transition", :, :] = T
222224
else:
223-
# can't make the linalg.block_diag logic work here
224-
# doing it manually for now
225-
for i in range(self.k_endog):
226-
start_idx = i * 2
227-
end_idx = (i + 1) * 2
228-
self.ssm["transition", start_idx:end_idx, start_idx:end_idx] = T
225+
transition = block_diag(*[T for _ in range(self.k_endog)])
226+
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
229227

230228
if self.innovations:
231229
if self.k_endog == 1:
@@ -235,16 +233,20 @@ def make_symbolic_graph(self) -> None:
235233
sigma_cycle = self.make_and_register_variable(
236234
f"sigma_{self.name}", shape=(self.k_endog,)
237235
)
238-
# can't make the linalg.block_diag logic work here
239-
# doing it manually for now
240-
for i in range(self.k_endog):
241-
start_idx = i * 2
242-
end_idx = (i + 1) * 2
243-
Q_block = pt.eye(2) * sigma_cycle[i] ** 2
244-
self.ssm["state_cov", start_idx:end_idx, start_idx:end_idx] = Q_block
236+
state_cov = block_diag(
237+
*[pt.eye(2) * sigma_cycle[i] ** 2 for i in range(self.k_endog)]
238+
)
239+
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
245240

246241
def populate_component_properties(self):
247-
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
242+
if self.k_endog == 1:
243+
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
244+
else:
245+
# For multivariate cycles, create state names for each observed state
246+
self.state_names = []
247+
for var_name in self.observed_state_names:
248+
self.state_names.extend([f"{self.name}_{var_name}_{f}" for f in ["Cos", "Sin"]])
249+
248250
self.param_names = [f"{self.name}"]
249251

250252
if self.k_endog == 1:
@@ -260,7 +262,7 @@ def populate_component_properties(self):
260262
else:
261263
self.param_dims = {self.name: (f"{self.name}_endog", f"{self.name}_state")}
262264
self.coords = {
263-
f"{self.name}_state": self.state_names,
265+
f"{self.name}_state": [f"{self.name}_Cos", f"{self.name}_Sin"],
264266
f"{self.name}_endog": self.observed_state_names,
265267
}
266268
self.param_info = {

0 commit comments

Comments
 (0)