2727)
2828from pymc_experimental .statespace .filters .distributions import (
2929 LinearGaussianStateSpace ,
30+ MvNormalSVD ,
3031 SequenceMvNormal ,
3132)
3233from pymc_experimental .statespace .filters .utilities import stabilize
@@ -864,9 +865,8 @@ def build_statespace_graph(
864865 cov_jitter = cov_jitter ,
865866 )
866867
867- outputs = filter_outputs
868- logp = outputs .pop (- 1 )
869- states , covs = outputs [:3 ], outputs [3 :]
868+ logp = filter_outputs .pop (- 1 )
869+ states , covs = filter_outputs [:3 ], filter_outputs [3 :]
870870 filtered_states , predicted_states , observed_states = states
871871 filtered_covariances , predicted_covariances , observed_covariances = covs
872872 if save_kalman_filter_outputs_in_idata :
@@ -2010,7 +2010,7 @@ def forecast(
20102010
20112011 with pm .Model (coords = temp_coords ) as forecast_model :
20122012 (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2013- data_dims = ["data_time" , OBS_STATE_DIM ]
2013+ data_dims = ["data_time" , OBS_STATE_DIM ],
20142014 )
20152015
20162016 group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
@@ -2026,7 +2026,7 @@ def forecast(
20262026 if scenario is not None :
20272027 sub_dict = {
20282028 forecast_model [data_name ]: pt .as_tensor_variable (
2029- scenario .get (data_name ), name = "data_var"
2029+ scenario .get (data_name ), name = data_name
20302030 )
20312031 for data_name in self .data_names
20322032 }
@@ -2173,16 +2173,16 @@ def impulse_response_function(
21732173 if use_posterior_cov :
21742174 Q = post_Q
21752175 if orthogonalize_shocks :
2176- Q = pt .linalg .cholesky (Q )
2176+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
21772177 elif shock_cov is not None :
21782178 Q = pt .as_tensor_variable (shock_cov )
21792179 if orthogonalize_shocks :
2180- Q = pt .linalg .cholesky (Q )
2180+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
21812181
21822182 if shock_trajectory is None :
21832183 shock_trajectory = pt .zeros ((n_steps , self .k_posdef ))
21842184 if Q is not None :
2185- init_shock = pm . MvNormal ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2185+ init_shock = MvNormalSVD ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
21862186 else :
21872187 init_shock = pm .Deterministic (
21882188 "initial_shock" ,
0 commit comments