Skip to content

Commit ab98abe

Browse files
committed
Improve cycle code with Jesse's feedback
1 parent f584e79 commit ab98abe

File tree

1 file changed

+20
-31
lines changed
  • pymc_extras/statespace/models/structural/components

1 file changed

+20
-31
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pytensor import tensor as pt
44
from pytensor.tensor.slinalg import block_diag
5-
from scipy import linalg
65

76
from pymc_extras.statespace.models.structural.core import Component
87
from pymc_extras.statespace.models.structural.utils import _frequency_transition_block
@@ -190,22 +189,17 @@ def __init__(
190189
)
191190

192191
def make_symbolic_graph(self) -> None:
193-
if self.k_endog == 1:
194-
self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
195-
self.ssm["selection", :, :] = np.eye(self.k_states)
196-
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,))
197-
198-
else:
199-
Z = np.array([1.0, 0.0]).reshape((1, -1))
200-
design_matrix = linalg.block_diag(*[Z for _ in range(self.k_endog)])
201-
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
192+
Z = np.array([1.0, 0.0]).reshape((1, -1))
193+
design_matrix = block_diag(*[Z for _ in range(self.k_endog)])
194+
self.ssm["design", :, :] = pt.as_tensor_variable(design_matrix)
202195

203-
R = np.eye(2) # 2x2 identity for each cycle component
204-
selection_matrix = linalg.block_diag(*[R for _ in range(self.k_endog)])
205-
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
206-
207-
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_endog, 2))
196+
R = np.eye(2) # 2x2 identity for each cycle component
197+
selection_matrix = block_diag(*[R for _ in range(self.k_endog)])
198+
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
208199

200+
init_state = self.make_and_register_variable(
201+
f"{self.name}", shape=(self.k_endog, 2) if self.k_endog > 1 else (self.k_states,)
202+
)
209203
self.ssm["initial_state", :] = init_state.ravel()
210204

211205
if self.estimate_cycle_length:
@@ -219,11 +213,8 @@ def make_symbolic_graph(self) -> None:
219213
rho = 1
220214

221215
T = rho * _frequency_transition_block(lamb, j=1)
222-
if self.k_endog == 1:
223-
self.ssm["transition", :, :] = T
224-
else:
225-
transition = block_diag(*[T for _ in range(self.k_endog)])
226-
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
216+
transition = block_diag(*[T for _ in range(self.k_endog)])
217+
self.ssm["transition"] = pt.specify_shape(transition, (self.k_states, self.k_states))
227218

228219
if self.innovations:
229220
if self.k_endog == 1:
@@ -239,13 +230,11 @@ def make_symbolic_graph(self) -> None:
239230
self.ssm["state_cov"] = pt.specify_shape(state_cov, (self.k_states, self.k_states))
240231

241232
def populate_component_properties(self):
242-
if self.k_endog == 1:
243-
self.state_names = [f"{self.name}_{f}" for f in ["Cos", "Sin"]]
244-
else:
245-
# For multivariate cycles, create state names for each observed state
246-
self.state_names = []
247-
for var_name in self.observed_state_names:
248-
self.state_names.extend([f"{self.name}_{var_name}_{f}" for f in ["Cos", "Sin"]])
233+
self.state_names = [
234+
f"{self.name}_{f}[{var_name}]" if self.k_endog > 1 else f"{self.name}_{f}"
235+
for var_name in self.observed_state_names
236+
for f in ["Cos", "Sin"]
237+
]
249238

250239
self.param_names = [f"{self.name}"]
251240

@@ -276,17 +265,17 @@ def populate_component_properties(self):
276265
if self.estimate_cycle_length:
277266
self.param_names += [f"{self.name}_length"]
278267
self.param_info[f"{self.name}_length"] = {
279-
"shape": (),
268+
"shape": () if self.k_endog == 1 else (self.k_endog,),
280269
"constraints": "Positive, non-zero",
281-
"dims": None,
270+
"dims": None if self.k_endog == 1 else f"{self.name}_endog",
282271
}
283272

284273
if self.dampen:
285274
self.param_names += [f"{self.name}_dampening_factor"]
286275
self.param_info[f"{self.name}_dampening_factor"] = {
287-
"shape": (),
276+
"shape": () if self.k_endog == 1 else (self.k_endog,),
288277
"constraints": "0 < x ≤ 1",
289-
"dims": None,
278+
"dims": None if self.k_endog == 1 else f"{self.name}_endog",
290279
}
291280

292281
if self.innovations:

0 commit comments

Comments
 (0)