@@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(
20472047
20482048 return scenario
20492049
2050+ def _build_forecast_model (
2051+ self , time_index , t0 , forecast_index , scenario , filter_output , mvn_method
2052+ ):
2053+ filter_time_dim = TIME_DIM
2054+ temp_coords = self ._fit_coords .copy ()
2055+
2056+ dims = None
2057+ if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
2058+ dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
2059+
2060+ t0_idx = np .flatnonzero (time_index == t0 )[0 ]
2061+
2062+ temp_coords ["data_time" ] = time_index
2063+ temp_coords [TIME_DIM ] = forecast_index
2064+
2065+ mu_dims , cov_dims = None , None
2066+ if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
2067+ mu_dims = ["data_time" , ALL_STATE_DIM ]
2068+ cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
2069+
2070+ with pm .Model (coords = temp_coords ) as forecast_model :
2071+ (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2072+ data_dims = ["data_time" , OBS_STATE_DIM ],
2073+ )
2074+
2075+ group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
2076+ mu , cov = grouped_outputs [group_idx ]
2077+
2078+ sub_dict = {
2079+ data_var : pt .as_tensor_variable (data_var .get_value (), name = "data" )
2080+ for data_var in forecast_model .data_vars
2081+ }
2082+
2083+ missing_data_vars = np .setdiff1d (
2084+ ar1 = [* self .data_names , "data" ], ar2 = [k .name for k , _ in sub_dict .items ()]
2085+ )
2086+ if missing_data_vars .size > 0 :
2087+ raise ValueError (f"{ missing_data_vars } data used for fitting not found!" )
2088+
2089+ mu_frozen , cov_frozen = graph_replace ([mu , cov ], replace = sub_dict , strict = True )
2090+
2091+ x0 = pm .Deterministic (
2092+ "x0_slice" , mu_frozen [t0_idx ], dims = mu_dims [1 :] if mu_dims is not None else None
2093+ )
2094+ P0 = pm .Deterministic (
2095+ "P0_slice" , cov_frozen [t0_idx ], dims = cov_dims [1 :] if cov_dims is not None else None
2096+ )
2097+
2098+ _ = LinearGaussianStateSpace (
2099+ "forecast" ,
2100+ x0 ,
2101+ P0 ,
2102+ * matrices ,
2103+ steps = len (forecast_index ),
2104+ dims = dims ,
2105+ sequence_names = self .kalman_filter .seq_names ,
2106+ k_endog = self .k_endog ,
2107+ append_x0 = False ,
2108+ method = mvn_method ,
2109+ )
2110+
2111+ return forecast_model
2112+
20502113 def forecast (
20512114 self ,
20522115 idata : InferenceData ,
@@ -2139,8 +2202,6 @@ def forecast(
21392202 the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
21402203
21412204 """
2142- filter_time_dim = TIME_DIM
2143-
21442205 _validate_filter_arg (filter_output )
21452206
21462207 compile_kwargs = kwargs .pop ("compile_kwargs" , {})
@@ -2185,58 +2246,23 @@ def forecast(
21852246 use_scenario_index = use_scenario_index ,
21862247 )
21872248 scenario = self ._finalize_scenario_initialization (scenario , forecast_index )
2188- temp_coords = self ._fit_coords .copy ()
2189-
2190- dims = None
2191- if all ([dim in temp_coords for dim in [filter_time_dim , ALL_STATE_DIM , OBS_STATE_DIM ]]):
2192- dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
2193-
2194- t0_idx = np .flatnonzero (time_index == t0 )[0 ]
2195-
2196- temp_coords ["data_time" ] = time_index
2197- temp_coords [TIME_DIM ] = forecast_index
2198-
2199- mu_dims , cov_dims = None , None
2200- if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]]):
2201- mu_dims = ["data_time" , ALL_STATE_DIM ]
2202- cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
2203-
2204- with pm .Model (coords = temp_coords ) as forecast_model :
2205- (_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2206- scenario = scenario ,
2207- data_dims = ["data_time" , OBS_STATE_DIM ],
2208- )
2209-
2210- for name in self .data_names :
2211- if name in scenario .keys ():
2212- pm .set_data (
2213- {"data" : np .zeros ((len (forecast_index ), self .k_endog ))},
2214- coords = {"data_time" : np .arange (len (forecast_index ))},
2215- )
2216- break
22172249
2218- group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
2219- mu , cov = grouped_outputs [group_idx ]
2220-
2221- x0 = pm .Deterministic (
2222- "x0_slice" , mu [t0_idx ], dims = mu_dims [1 :] if mu_dims is not None else None
2223- )
2224- P0 = pm .Deterministic (
2225- "P0_slice" , cov [t0_idx ], dims = cov_dims [1 :] if cov_dims is not None else None
2226- )
2250+ forecast_model = self ._build_forecast_model (
2251+ time_index = time_index ,
2252+ t0 = t0 ,
2253+ forecast_index = forecast_index ,
2254+ scenario = scenario ,
2255+ filter_output = filter_output ,
2256+ mvn_method = mvn_method ,
2257+ )
22272258
2228- _ = LinearGaussianStateSpace (
2229- "forecast" ,
2230- x0 ,
2231- P0 ,
2232- * matrices ,
2233- steps = len (forecast_index ),
2234- dims = dims ,
2235- sequence_names = self .kalman_filter .seq_names ,
2236- k_endog = self .k_endog ,
2237- append_x0 = False ,
2238- method = mvn_method ,
2239- )
2259+ with forecast_model :
2260+ if scenario is not None :
2261+ dummy_obs_data = np .zeros ((len (forecast_index ), self .k_endog ))
2262+ pm .set_data (
2263+ scenario | {"data" : dummy_obs_data },
2264+ coords = {"data_time" : np .arange (len (forecast_index ))},
2265+ )
22402266
22412267 forecast_model .rvs_to_initial_values = {
22422268 k : None for k in forecast_model .rvs_to_initial_values .keys ()
0 commit comments