Skip to content

Commit 61373fe

Browse files
committed
Handle initial states
1 parent 452f1cb commit 61373fe

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ def populate_component_properties(self):
332332

333333
def make_symbolic_graph(self) -> None:
334334
k_states = self.k_states // self.k_endog
335+
duration = self.duration
336+
k_unique_states = k_states // duration
335337
k_posdef = self.k_posdef // self.k_endog
336338
k_endog = self.k_endog
337339

@@ -364,16 +366,23 @@ def make_symbolic_graph(self) -> None:
364366
# circulant matrix that cycles between the states.
365367
T = np.eye(k_states, k=1)
366368
T[-1, 0] = 1
369+
T = pt.as_tensor_variable(T)
367370

368371
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
369372

370373
Z = pt.zeros((1, k_states))[0, 0].set(1)
371374
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
372375

373376
initial_states = self.make_and_register_variable(
374-
f"coefs_{self.name}", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
377+
f"coefs_{self.name}",
378+
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
375379
)
376-
self.ssm["initial_state", :] = initial_states.ravel()
380+
if k_endog == 1:
381+
self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
382+
else:
383+
self.ssm["initial_state", :] = pt.extra_ops.repeat(
384+
initial_states, duration, axis=1
385+
).ravel()
377386

378387
if self.innovations:
379388
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)

0 commit comments

Comments
 (0)