2727)
2828from pymc_experimental .statespace .filters .distributions import (
2929 LinearGaussianStateSpace ,
30+ MvNormalSVD ,
3031 SequenceMvNormal ,
3132)
3233from pymc_experimental .statespace .filters .utilities import stabilize
@@ -876,9 +877,8 @@ def build_statespace_graph(
876877 cov_jitter = cov_jitter ,
877878 )
878879
879- outputs = filter_outputs
880- logp = outputs .pop (- 1 )
881- states , covs = outputs [:3 ], outputs [3 :]
880+ logp = filter_outputs .pop (- 1 )
881+ states , covs = filter_outputs [:3 ], filter_outputs [3 :]
882882 filtered_states , predicted_states , observed_states = states
883883 filtered_covariances , predicted_covariances , observed_covariances = covs
884884 if save_kalman_filter_outputs_in_idata :
@@ -2022,7 +2022,7 @@ def forecast(
20222022
20232023 with pm .Model (coords = temp_coords ) as forecast_model :
20242024 (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2025- data_dims = ["data_time" , OBS_STATE_DIM ]
2025+ data_dims = ["data_time" , OBS_STATE_DIM ],
20262026 )
20272027
20282028 group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
@@ -2038,7 +2038,7 @@ def forecast(
20382038 if scenario is not None :
20392039 sub_dict = {
20402040 forecast_model [data_name ]: pt .as_tensor_variable (
2041- scenario .get (data_name ), name = "data_var"
2041+ scenario .get (data_name ), name = data_name
20422042 )
20432043 for data_name in self .data_names
20442044 }
@@ -2185,16 +2185,16 @@ def impulse_response_function(
21852185 if use_posterior_cov :
21862186 Q = post_Q
21872187 if orthogonalize_shocks :
2188- Q = pt .linalg .cholesky (Q )
2188+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
21892189 elif shock_cov is not None :
21902190 Q = pt .as_tensor_variable (shock_cov )
21912191 if orthogonalize_shocks :
2192- Q = pt .linalg .cholesky (Q )
2192+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
21932193
21942194 if shock_trajectory is None :
21952195 shock_trajectory = pt .zeros ((n_steps , self .k_posdef ))
21962196 if Q is not None :
2197- init_shock = pm . MvNormal ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2197+ init_shock = MvNormalSVD ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
21982198 else :
21992199 init_shock = pm .Deterministic (
22002200 "initial_shock" ,
0 commit comments