27
27
)
28
28
from pymc_experimental .statespace .filters .distributions import (
29
29
LinearGaussianStateSpace ,
30
+ MvNormalSVD ,
30
31
SequenceMvNormal ,
31
32
)
32
33
from pymc_experimental .statespace .filters .utilities import stabilize
@@ -876,9 +877,8 @@ def build_statespace_graph(
876
877
cov_jitter = cov_jitter ,
877
878
)
878
879
879
- outputs = filter_outputs
880
- logp = outputs .pop (- 1 )
881
- states , covs = outputs [:3 ], outputs [3 :]
880
+ logp = filter_outputs .pop (- 1 )
881
+ states , covs = filter_outputs [:3 ], filter_outputs [3 :]
882
882
filtered_states , predicted_states , observed_states = states
883
883
filtered_covariances , predicted_covariances , observed_covariances = covs
884
884
if save_kalman_filter_outputs_in_idata :
@@ -2022,7 +2022,7 @@ def forecast(
2022
2022
2023
2023
with pm .Model (coords = temp_coords ) as forecast_model :
2024
2024
(_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2025
- data_dims = ["data_time" , OBS_STATE_DIM ]
2025
+ data_dims = ["data_time" , OBS_STATE_DIM ],
2026
2026
)
2027
2027
2028
2028
group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
@@ -2038,7 +2038,7 @@ def forecast(
2038
2038
if scenario is not None :
2039
2039
sub_dict = {
2040
2040
forecast_model [data_name ]: pt .as_tensor_variable (
2041
- scenario .get (data_name ), name = "data_var"
2041
+ scenario .get (data_name ), name = data_name
2042
2042
)
2043
2043
for data_name in self .data_names
2044
2044
}
@@ -2185,16 +2185,16 @@ def impulse_response_function(
2185
2185
if use_posterior_cov :
2186
2186
Q = post_Q
2187
2187
if orthogonalize_shocks :
2188
- Q = pt .linalg .cholesky (Q )
2188
+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
2189
2189
elif shock_cov is not None :
2190
2190
Q = pt .as_tensor_variable (shock_cov )
2191
2191
if orthogonalize_shocks :
2192
- Q = pt .linalg .cholesky (Q )
2192
+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
2193
2193
2194
2194
if shock_trajectory is None :
2195
2195
shock_trajectory = pt .zeros ((n_steps , self .k_posdef ))
2196
2196
if Q is not None :
2197
- init_shock = pm . MvNormal ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2197
+ init_shock = MvNormalSVD ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2198
2198
else :
2199
2199
init_shock = pm .Deterministic (
2200
2200
"initial_shock" ,
0 commit comments