Skip to content

Commit 7252d46

Browse files
committed
Small adjustments and refactoring after code review
1 parent ca1a86e commit 7252d46

File tree

2 files changed

+92
-56
lines changed

2 files changed

+92
-56
lines changed

pymc_extras/statespace/models/DFM.py

Lines changed: 91 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from collections.abc import Sequence
22
from typing import Any
33

4-
import numpy as np
54
import pytensor
65
import pytensor.tensor as pt
76

@@ -442,27 +441,25 @@ def state_names(self) -> list[str]:
442441
Returns the names of the hidden states: first factor states (with lags),
443442
idiosyncratic error states (with lags), then exogenous states.
444443
"""
445-
names = []
446-
447-
for i in range(self.k_factors):
448-
for lag in range(max(self.factor_order, 1)):
449-
names.append(f"L{lag}.factor_{i}")
444+
names = [
445+
f"L{lag}.factor_{i}"
446+
for i in range(self.k_factors)
447+
for lag in range(max(self.factor_order, 1))
448+
]
450449

451450
if self.error_order > 0:
452-
for i in range(self.k_endog):
453-
for lag in range(self.error_order):
454-
names.append(f"L{lag}.error_{i}")
451+
names.extend(
452+
f"L{lag}.error_{i}" for i in range(self.k_endog) for lag in range(self.error_order)
453+
)
455454

456455
if self.exog_flag:
457456
if self.shared_exog_states:
458457
names.extend([f"beta_{exog_name}[shared]" for exog_name in self.exog_names])
459458
else:
460459
names.extend(
461-
[
462-
f"beta_{exog_name}[{endog_name}]"
463-
for exog_name in self.exog_names
464-
for endog_name in self.endog_names
465-
]
460+
f"beta_{exog_name}[{endog_name}]"
461+
for exog_name in self.exog_names
462+
for endog_name in self.endog_names
466463
)
467464
return names
468465

@@ -494,24 +491,21 @@ def coords(self) -> dict[str, Sequence]:
494491
return coords
495492

496493
@property
497-
def shock_names(self):
498-
shock_names = []
499-
500-
for i in range(self.k_factors):
501-
shock_names.append(f"factor_shock_{i}")
494+
def shock_names(self) -> list[str]:
495+
shock_names = [f"factor_shock_{i}" for i in range(self.k_factors)]
502496

503497
if self.error_order > 0:
504-
for i in range(self.k_endog):
505-
shock_names.append(f"error_shock_{i}")
498+
shock_names.extend(f"error_shock_{i}" for i in range(self.k_endog))
506499

507500
if self.exog_flag:
508501
if self.shared_exog_states:
509-
for i in range(self.k_exog):
510-
shock_names.append(f"exog_shock_{i}.shared")
502+
shock_names.extend(f"exog_shock_{i}.shared" for i in range(self.k_exog))
511503
else:
512-
for i in range(self.k_exog):
513-
for j in range(self.k_endog):
514-
shock_names.append(f"exog_shock_{i}.endog_{j}")
504+
shock_names.extend(
505+
f"exog_shock_{i}.endog_{j}"
506+
for i in range(self.k_exog)
507+
for j in range(self.k_endog)
508+
)
515509

516510
return shock_names
517511

@@ -535,7 +529,7 @@ def param_dims(self):
535529
coord_map["error_sigma"] = (OBS_STATE_DIM,)
536530

537531
if self.error_cov_type == "unstructured":
538-
coord_map["error_sigma"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
532+
coord_map["error_cov"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
539533

540534
if self.measurement_error:
541535
coord_map["sigma_obs"] = (OBS_STATE_DIM,)
@@ -584,57 +578,102 @@ def make_symbolic_graph(self):
584578
)
585579
self.ssm["initial_state_cov", :, :] = P0
586580

587-
# Design matrix
581+
# Design matrix (Z)
582+
# Construction with block structure:
583+
# When factor_order <= 1 and error_order = 0:
584+
# [ A ] A is the factor loadings matrix with shape (k_endog, k_factors)
585+
#
586+
# When factor_order > 1, add a block of zeros for the factor lags:
587+
# [ A | 0 ] the zero block has shape (k_endog, k_factors * (factor_order - 1))
588+
#
589+
# When error_order > 0, add an identity matrix and an additional zero block for the error lags:
590+
# [ A | 0 | I | 0 ] I is the identity matrix (k_endog, k_endog) and the final zero block
591+
# has shape (k_endog, k_endog * (error_order - 1))
592+
#
593+
# When exog_flag=True, exogenous data (exog_data) is included and the design
594+
# matrix becomes 3D with the first dimension indexing time:
595+
# - shared_exog_states=True: exog_data is broadcast across all endogenous series
596+
# → shape (n_timepoints, k_endog, k_exog)
597+
# - shared_exog_states=False: each endogenous series gets its own exog block
598+
# → block-diagonal structure with shape (n_timepoints, k_endog, k_exog * k_endog)
599+
# In both cases, the base design matrix (factors + errors) is repeated over
600+
# time and concatenated with the exogenous block. The final design matrix
601+
# has shape (n_timepoints, k_endog, n_columns) and combines all components.
588602
factor_loadings = self.make_and_register_variable(
589603
"factor_loadings", shape=(self.k_endog, self.k_factors), dtype=floatX
590604
)
591-
605+
# Add factor loadings (A matrix)
592606
matrix_parts = [factor_loadings]
593607

594-
# Leaving space for higher-order factors
608+
# Add a zero block for the factor lags when factor_order > 1
595609
if self.factor_order > 1:
596610
matrix_parts.append(
597611
pt.zeros((self.k_endog, self.k_factors * (self.factor_order - 1)), dtype=floatX)
598612
)
599-
613+
# Add identity and zero blocks for error lags when error_order > 0
600614
if self.error_order > 0:
601615
error_matrix = pt.eye(self.k_endog, dtype=floatX)
602616
matrix_parts.append(error_matrix)
603617
matrix_parts.append(
604618
pt.zeros((self.k_endog, self.k_endog * (self.error_order - 1)), dtype=floatX)
605619
)
606620
if len(matrix_parts) == 1:
607-
design_matrix = factor_loadings * 1.0
621+
design_matrix = factor_loadings * 1.0 # copy to ensure a new PyTensor variable
608622
design_matrix.name = "design"
609623
else:
610624
design_matrix = pt.concatenate(matrix_parts, axis=1)
611625
design_matrix.name = "design"
612-
626+
# Handle exogenous variables (if any)
613627
if self.exog_flag:
628+
exog_data = self.make_and_register_data("exog_data", shape=(None, self.k_exog))
614629
if self.shared_exog_states:
615-
exog_data = self.make_and_register_data("exog_data", shape=(None, self.k_exog))
630+
# Shared exogenous states: same exog data is used across all endogenous variables
631+
# Shape becomes (n_timepoints, k_endog, k_exog)
616632
Z_exog = pt.specify_shape(
617633
pt.join(1, *[pt.expand_dims(exog_data, 1) for _ in range(self.k_endog)]),
618634
(None, self.k_endog, self.k_exog),
619635
)
620-
n_timepoints = Z_exog.shape[0]
621-
design_matrix_time = pt.tile(design_matrix, (n_timepoints, 1, 1))
622636
else:
623-
exog_data = self.make_and_register_data("exog_data", shape=(None, self.k_exog))
637+
# Separate exogenous states: each endogenous variable gets its own exog block
638+
# Create block-diagonal structure and reshape to (n_timepoints, k_endog, k_exog * k_endog)
624639
Z_exog = pt.linalg.block_diag(
625640
*[pt.expand_dims(exog_data, 1) for _ in range(self.k_endog)]
626-
) # (time, k_endog, k_exog)
641+
)
627642
Z_exog = pt.specify_shape(Z_exog, (None, self.k_endog, self.k_exog * self.k_endog))
628-
# Repeat design_matrix over time dimension
629-
n_timepoints = Z_exog.shape[0]
630-
design_matrix_time = pt.tile(design_matrix, (n_timepoints, 1, 1))
631643

644+
# Repeat base design_matrix over time dimension to match exogenous time series
645+
n_timepoints = Z_exog.shape[0]
646+
design_matrix_time = pt.tile(design_matrix, (n_timepoints, 1, 1))
647+
# Concatenate the repeated design matrix with exogenous matrix along the last axis
648+
# Final shape: (n_timepoints, k_endog, n_columns + n_exog_columns)
632649
design_matrix = pt.concatenate([design_matrix_time, Z_exog], axis=2)
633650

634651
self.ssm["design"] = design_matrix
635652

636-
# Transition matrix
637-
# auxiliary function to build transition matrix block
653+
# Transition matrix (T)
654+
# Construction with block-diagonal structure:
655+
# Each latent component (factors, errors, exogenous states) contributes its own transition block,
656+
# and the full transition matrix is assembled with block_diag.
657+
#
658+
# - Factors (block A):
659+
# If factor_order > 0, the factor AR coefficients are organized into a
660+
# VAR(p) companion matrix of size (k_factors * factor_order, k_factors * factor_order).
661+
# This block shifts lagged factor states and applies AR coefficients.
662+
# If factor_order = 0, a zero matrix is used instead.
663+
#
664+
# - Errors (block B):
665+
# If error_order > 0:
666+
# * error_var=True → build a full VAR(p) companion matrix (cross-series correlations allowed).
667+
# * error_var=False → build independent AR(p) companion matrices (no cross-series effects).
668+
#
669+
# - Exogenous states (block C):
670+
# If exog_flag=True, exogenous states are either constant or follow a random walk, modeled with an identity
671+
# transition block of size (k_exog_states, k_exog_states).
672+
#
673+
# The final transition matrix is block-diagonal, combining all active components:
674+
# Transition = block_diag(Factors, Errors, Exogenous)
675+
676+
# Auxiliary functions to build the transition matrix blocks
638677
def build_var_block_matrix(ar_coeffs, k_series, p):
639678
"""
640679
Build the VAR(p) companion matrix for the factors.
@@ -648,13 +687,13 @@ def build_var_block_matrix(ar_coeffs, k_series, p):
648687
block = pt.zeros((size, size), dtype=floatX)
649688

650689
# First block row: the AR coefficient matrices for each lag
651-
block = pt.set_subtensor(block[0:k_series, 0 : k_series * p], ar_coeffs)
690+
block = block[0:k_series, 0 : k_series * p].set(ar_coeffs)
652691

653692
# Sub-diagonal identity blocks (shift structure)
654693
if p > 1:
655694
# Create the identity pattern for all sub-diagonal blocks
656695
identity_pattern = pt.eye(k_series * (p - 1), dtype=floatX)
657-
block = pt.set_subtensor(block[k_series:, : k_series * (p - 1)], identity_pattern)
696+
block = block[k_series:, : k_series * (p - 1)].set(identity_pattern)
658697

659698
return block
660699

@@ -684,7 +723,7 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
684723
return block
685724

686725
transition_blocks = []
687-
726+
# Block A: Factors
688727
if self.factor_order > 0:
689728
factor_ar = self.make_and_register_variable(
690729
"factor_ar",
@@ -696,7 +735,7 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
696735
)
697736
else:
698737
transition_blocks.append(pt.zeros((self.k_factors, self.k_factors), dtype=floatX))
699-
738+
# Block B: Errors
700739
if self.error_order > 0 and self.error_var:
701740
error_ar = self.make_and_register_variable(
702741
"error_ar", shape=(self.k_endog, self.error_order * self.k_endog), dtype=floatX
@@ -711,13 +750,13 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
711750
transition_blocks.append(
712751
build_independent_var_block_matrix(error_ar, self.k_endog, self.error_order)
713752
)
714-
# Exogenous variables are either constant or follow a random walk, so identity matrix
753+
# Block C: Exogenous states
715754
if self.exog_flag:
716755
transition_blocks.append(pt.eye(self.k_exog_states, dtype=floatX))
717756

718757
self.ssm["transition", :, :] = pt.linalg.block_diag(*transition_blocks)
719758

720-
# Selection matrix
759+
# Selection matrix (R)
721760
for i in range(self.k_factors):
722761
self.ssm["selection", i, i] = 1.0
723762

@@ -746,11 +785,8 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
746785

747786
# Handle error_sigma and error_cov depending on error_cov_type
748787
if self.error_cov_type == "scalar":
749-
base_error_sigma = self.make_and_register_variable(
750-
"error_sigma", shape=(), dtype=floatX
751-
)
752-
error_sigma = base_error_sigma * np.ones(self.k_endog, dtype=floatX)
753-
error_cov = pt.diag(error_sigma)
788+
error_sigma = self.make_and_register_variable("error_sigma", shape=(), dtype=floatX)
789+
error_cov = pt.eye(self.k_endog) * error_sigma
754790
elif self.error_cov_type == "diagonal":
755791
error_sigma = self.make_and_register_variable(
756792
"error_sigma", shape=(self.k_endog,), dtype=floatX
@@ -796,7 +832,7 @@ def build_independent_var_block_matrix(ar_coeffs, k_series, p):
796832
"sigma_obs", shape=(self.k_endog,), dtype=floatX
797833
)
798834
self.ssm["obs_cov", :, :] = pt.diag(sigma_obs)
799-
# else: obs_cov remains zero (no measurement noise and idiosyncratic noise captured in state)
835+
# else: obs_cov remains zero (no measurement noise and idiosyncratic noise captured in state)
800836
else:
801837
if self.measurement_error:
802838
# TODO: check this decision

tests/statespace/models/test_DFM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def test_dynamic_factor_ar2_error_var_unstructured(self):
537537
assert mod.coords[k] == v
538538
assert len(mod.state_names) == 7
539539
assert mod.observed_states == ["y0", "y1", "y2"]
540-
assert len(mod.shock_names) == 7
540+
assert len(mod.shock_names) == 4
541541

542542
def test_exog_shared_exog_states_exog_innovations(self):
543543
mod = BayesianDynamicFactor(

0 commit comments

Comments
 (0)