@@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(
2047
2047
2048
2048
return scenario
2049
2049
2050
+ def _build_forecast_model (
2051
+ self , time_index , t0 , forecast_index , scenario , filter_output , mvn_method
2052
+ ):
2053
+ filter_time_dim = TIME_DIM
2054
+ temp_coords = self ._fit_coords .copy ()
2055
+
2056
+ dims = None
2057
+ if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
2058
+ dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
2059
+
2060
+ t0_idx = np .flatnonzero (time_index == t0 )[0 ]
2061
+
2062
+ temp_coords ["data_time" ] = time_index
2063
+ temp_coords [TIME_DIM ] = forecast_index
2064
+
2065
+ mu_dims , cov_dims = None , None
2066
+ if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
2067
+ mu_dims = ["data_time" , ALL_STATE_DIM ]
2068
+ cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
2069
+
2070
+ with pm .Model (coords = temp_coords ) as forecast_model :
2071
+ (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2072
+ data_dims = ["data_time" , OBS_STATE_DIM ],
2073
+ )
2074
+
2075
+ group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
2076
+ mu , cov = grouped_outputs [group_idx ]
2077
+
2078
+ sub_dict = {
2079
+ data_var : pt .as_tensor_variable (data_var .get_value (), name = "data" )
2080
+ for data_var in forecast_model .data_vars
2081
+ }
2082
+
2083
+ missing_data_vars = np .setdiff1d (
2084
+ ar1 = [* self .data_names , "data" ], ar2 = [k .name for k , _ in sub_dict .items ()]
2085
+ )
2086
+ if missing_data_vars .size > 0 :
2087
+ raise ValueError (f"{ missing_data_vars } data used for fitting not found!" )
2088
+
2089
+ mu_frozen , cov_frozen = graph_replace ([mu , cov ], replace = sub_dict , strict = True )
2090
+
2091
+ x0 = pm .Deterministic (
2092
+ "x0_slice" , mu_frozen [t0_idx ], dims = mu_dims [1 :] if mu_dims is not None else None
2093
+ )
2094
+ P0 = pm .Deterministic (
2095
+ "P0_slice" , cov_frozen [t0_idx ], dims = cov_dims [1 :] if cov_dims is not None else None
2096
+ )
2097
+
2098
+ _ = LinearGaussianStateSpace (
2099
+ "forecast" ,
2100
+ x0 ,
2101
+ P0 ,
2102
+ * matrices ,
2103
+ steps = len (forecast_index ),
2104
+ dims = dims ,
2105
+ sequence_names = self .kalman_filter .seq_names ,
2106
+ k_endog = self .k_endog ,
2107
+ append_x0 = False ,
2108
+ method = mvn_method ,
2109
+ )
2110
+
2111
+ return forecast_model
2112
+
2050
2113
def forecast (
2051
2114
self ,
2052
2115
idata : InferenceData ,
@@ -2139,8 +2202,6 @@ def forecast(
2139
2202
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
2140
2203
2141
2204
"""
2142
- filter_time_dim = TIME_DIM
2143
-
2144
2205
_validate_filter_arg (filter_output )
2145
2206
2146
2207
compile_kwargs = kwargs .pop ("compile_kwargs" , {})
@@ -2185,58 +2246,23 @@ def forecast(
2185
2246
use_scenario_index = use_scenario_index ,
2186
2247
)
2187
2248
scenario = self ._finalize_scenario_initialization (scenario , forecast_index )
2188
- temp_coords = self ._fit_coords .copy ()
2189
-
2190
- dims = None
2191
- if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
2192
- dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
2193
-
2194
- t0_idx = np .flatnonzero (time_index == t0 )[0 ]
2195
-
2196
- temp_coords ["data_time" ] = time_index
2197
- temp_coords [TIME_DIM ] = forecast_index
2198
-
2199
- mu_dims , cov_dims = None , None
2200
- if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
2201
- mu_dims = ["data_time" , ALL_STATE_DIM ]
2202
- cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
2203
-
2204
- with pm .Model (coords = temp_coords ) as forecast_model :
2205
- (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2206
- scenario = scenario ,
2207
- data_dims = ["data_time" , OBS_STATE_DIM ],
2208
- )
2209
-
2210
- for name in self .data_names :
2211
- if name in scenario .keys ():
2212
- pm .set_data (
2213
- {"data" : np .zeros ((len (forecast_index ), self .k_endog ))},
2214
- coords = {"data_time" : np .arange (len (forecast_index ))},
2215
- )
2216
- break
2217
2249
2218
- group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
2219
- mu , cov = grouped_outputs [group_idx ]
2220
-
2221
- x0 = pm .Deterministic (
2222
- "x0_slice" , mu [t0_idx ], dims = mu_dims [1 :] if mu_dims is not None else None
2223
- )
2224
- P0 = pm .Deterministic (
2225
- "P0_slice" , cov [t0_idx ], dims = cov_dims [1 :] if cov_dims is not None else None
2226
- )
2250
+ forecast_model = self ._build_forecast_model (
2251
+ time_index = time_index ,
2252
+ t0 = t0 ,
2253
+ forecast_index = forecast_index ,
2254
+ scenario = scenario ,
2255
+ filter_output = filter_output ,
2256
+ mvn_method = mvn_method ,
2257
+ )
2227
2258
2228
- _ = LinearGaussianStateSpace (
2229
- "forecast" ,
2230
- x0 ,
2231
- P0 ,
2232
- * matrices ,
2233
- steps = len (forecast_index ),
2234
- dims = dims ,
2235
- sequence_names = self .kalman_filter .seq_names ,
2236
- k_endog = self .k_endog ,
2237
- append_x0 = False ,
2238
- method = mvn_method ,
2239
- )
2259
+ with forecast_model :
2260
+ if scenario is not None :
2261
+ dummy_obs_data = np .zeros ((len (forecast_index ), self .k_endog ))
2262
+ pm .set_data (
2263
+ scenario | {"data" : dummy_obs_data },
2264
+ coords = {"data_time" : np .arange (len (forecast_index ))},
2265
+ )
2240
2266
2241
2267
forecast_model .rvs_to_initial_values = {
2242
2268
k : None for k in forecast_model .rvs_to_initial_values .keys ()
0 commit comments