Skip to content

Commit dc44ac5

Browse files
committed
Build symbolic Graph from pytensor operations
1 parent 61373fe commit dc44ac5

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,14 @@ def make_symbolic_graph(self) -> None:
340340
if self.remove_first_state:
341341
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
342342
# all previous states.
343-
zero_d = np.zeros((self.duration, self.duration))
344-
id_d = np.eye(self.duration)
343+
zero_d = pt.zeros((self.duration, self.duration))
344+
id_d = pt.eye(self.duration)
345345

346-
blocks = []
346+
row_blocks = []
347347

348348
# First row: all -1_d blocks
349349
first_row = [-id_d for _ in range(self.season_length - 1)]
350-
blocks.append(first_row)
350+
row_blocks.append(pt.concatenate(first_row, axis=1))
351351

352352
# Rows 2 to season_length-1: shifted identity blocks
353353
for i in range(self.season_length - 2):
@@ -357,16 +357,15 @@ def make_symbolic_graph(self) -> None:
357357
row.append(id_d)
358358
else:
359359
row.append(zero_d)
360-
blocks.append(row)
360+
row_blocks.append(pt.concatenate(row, axis=1))
361361

362362
# Stack blocks
363-
T = np.block(blocks)
363+
T = pt.concatenate(row_blocks, axis=0)
364364
else:
365365
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
366366
# circulant matrix that cycles between the states.
367-
T = np.eye(k_states, k=1)
368-
T[-1, 0] = 1
369-
T = pt.as_tensor_variable(T)
367+
T = pt.eye(k_states, k=1)
368+
T = pt.set_subtensor(T[-1, 0], 1)
370369

371370
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
372371

0 commit comments

Comments
 (0)