Skip to content

Commit 5026c48

Browse files
committed
Vectorized construction of transition, design, and selection matrices
1 parent db59b7d commit 5026c48

File tree

1 file changed

+35
-45
lines changed
  • pymc_extras/statespace/models

1 file changed

+35
-45
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -365,81 +365,71 @@ def param_dims(self):
365365
return coord_map
366366

367367
def make_symbolic_graph(self):
368-
# initial states
368+
# Initial states
369369
x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)
370370

371371
self.ssm["initial_state", :] = x0
372372

373-
# initial covariance
373+
# Initial covariance
374374
P0 = self.make_and_register_variable(
375375
"P0", shape=(self.k_states, self.k_states), dtype=floatX
376376
)
377377

378378
self.ssm["initial_state_cov", :, :] = P0
379379

380-
# TODO vectorize the design matrix
381380
# Design matrix
382-
self.ssm["design", :, :] = 0.0
383-
384381
factor_loadings = self.make_and_register_variable(
385382
"factor_loadings", shape=(self.k_endog, self.k_factors), dtype=floatX
386383
)
387384

388-
for i in range(self.k_endog):
389-
for j in range(self.k_factors):
390-
# Loadings for each observed variable on the latent factors
391-
self.ssm["design", i, j * self.factor_order] = factor_loadings[i, j]
392-
393-
for i in range(self.k_endog):
394-
# Loadings for each observed variable on the latent factors
395-
self.ssm["design", i, self.k_factors * self.factor_order + i * self.error_order] = 1.0
396-
397-
# TODO vectorize the transition matrix or use block matrices (reordering states, check the VAR implementation)
398-
if self.factor_order > 0:
399-
# Transition matrix
400-
factor_ar = self.make_and_register_variable(
401-
"factor_ar", shape=(self.k_factors, self.factor_order), dtype=floatX
402-
)
403-
404-
self.ssm["transition", :, :] = 0.0
385+
self.ssm["design", :, :] = 0.0
405386

406-
for j in range(self.k_factors):
407-
block_start = j * self.factor_order
408-
for i in range(self.factor_order):
409-
# Assign AR coefficients to the first row of each block
410-
self.ssm["transition", block_start, block_start + i] = factor_ar[j, i]
387+
for j in range(self.k_factors):
388+
col_idx = j * self.factor_order
389+
self.ssm["design", :, col_idx] = factor_loadings[:, j]
411390

412-
# Fill the subdiagonal with ones, only for rows 1 to p-1
413-
if i < self.factor_order - 1:
414-
self.ssm["transition", block_start + i + 1, block_start + i] = 1.0
391+
for i in range(self.k_endog):
392+
col_idx = self.k_factors * self.factor_order + i * self.error_order
393+
self.ssm["design", i, col_idx] = 1.0
394+
395+
# Transition matrix
396+
# auxiliary function to build transition matrix block
397+
def build_ar_block_matrix(ar_coeffs):
398+
# ar_coeffs: (p,)
399+
p = ar_coeffs.shape[0]
400+
top_row = pt.reshape(ar_coeffs, (1, p))
401+
below = pt.eye(p - 1, p, k=0)
402+
return pt.concatenate([top_row, below], axis=0)
403+
404+
transition_blocks = []
405+
406+
factor_ar = self.make_and_register_variable(
407+
"factor_ar", shape=(self.k_factors, self.factor_order), dtype=floatX
408+
)
409+
for j in range(self.k_factors):
410+
transition_blocks.append(build_ar_block_matrix(factor_ar[j]))
415411

416412
if self.error_order > 0:
417413
error_ar = self.make_and_register_variable(
418414
"error_ar", shape=(self.k_endog, self.error_order), dtype=floatX
419415
)
420-
421416
for j in range(self.k_endog):
422-
block_start = self.k_factors * self.factor_order + j * self.error_order
423-
for i in range(self.error_order):
424-
# Set AR coefficients for the top row of each error AR(q) block
425-
self.ssm["transition", block_start, block_start + i] = error_ar[j, i]
417+
transition_blocks.append(build_ar_block_matrix(error_ar[j]))
426418

427-
# Set subdiagonal 1.0s, except last row
428-
if i < self.error_order - 1:
429-
self.ssm["transition", block_start + i + 1, block_start + i] = 1.0
419+
# Final block diagonal transition matrix
420+
self.ssm["transition", :, :] = pt.linalg.block_diag(*transition_blocks)
430421

431-
# TODO vectorize/block matrices (reorder the states accordingly)
432422
# Selection matrix
433423
self.ssm["selection", :, :] = 0.0
424+
434425
for i in range(self.k_factors):
435-
self.ssm["selection", i * self.factor_order, i] = 1.0
426+
row = i * self.factor_order
427+
self.ssm["selection", row, i] = 1.0
436428

437429
for i in range(self.k_endog):
438-
self.ssm[
439-
"selection",
440-
self.k_factors * self.factor_order + i * self.error_order,
441-
self.k_factors + i,
442-
] = 1.0
430+
row = self.k_factors * self.factor_order + i * self.error_order
431+
col = self.k_factors + i
432+
self.ssm["selection", row, col] = 1.0
443433

444434
# State covariance matrix
445435
factor_sigma = self.make_and_register_variable(

0 commit comments

Comments
 (0)