diff --git a/ts_benchmark/baselines/deep_forecasting_model_base.py b/ts_benchmark/baselines/deep_forecasting_model_base.py index 51c4929..fa7df90 100644 --- a/ts_benchmark/baselines/deep_forecasting_model_base.py +++ b/ts_benchmark/baselines/deep_forecasting_model_base.py @@ -715,6 +715,7 @@ def forecast( output = output[:, -config.horizon:, :series_dim] break + column_num = output.shape[-1] temp = output.cpu().numpy().reshape(-1, column_num)[-config.horizon:] @@ -873,7 +874,7 @@ def _perform_rolling_predictions( """ rolling_time = 0 input_np, target_np, input_mark_np, target_mark_np = self._get_rolling_data( - input_np, None, all_mark, rolling_time + input_np, None, None, all_mark, rolling_time ) if exog_future is not None: rolling_time_sum = horizon // self.config.horizon + 1 @@ -933,12 +934,13 @@ def _perform_rolling_predictions( break rolling_time += 1 output = output.cpu().numpy()[:, -self.config.horizon :, :] + # 这里的_get_rolling_data函数需要修改: ( input_np, target_np, input_mark_np, target_mark_np, - ) = self._get_rolling_data(input_np, output, all_mark, rolling_time) + ) = self._get_rolling_data(input_np, output, exog_future, all_mark, rolling_time) answers = np.concatenate(answers, axis=1) return answers[:, -horizon:, :] @@ -947,6 +949,7 @@ def _get_rolling_data( self, input_np: np.ndarray, output: Optional[np.ndarray], + exog_future: Optional[torch.Tensor], all_mark: np.ndarray, rolling_time: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: @@ -960,6 +963,15 @@ def _get_rolling_data( :return: Updated input data, target data, input marks, and target marks for rolling prediction. """ if rolling_time > 0: + if exog_future!= None: + exog_output = exog_future[ + :, + rolling_time + * self.config.horizon : (rolling_time + 1) + * self.config.horizon, + :, + ].cpu().numpy() + output = np.concatenate((output, exog_output), axis=2) input_np = np.concatenate((input_np, output), axis=1) input_np = input_np[:, -self.config.seq_len :, :] target_np = np.zeros(