Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions ts_benchmark/baselines/deep_forecasting_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ def forecast(
output = output[:, -config.horizon:, :series_dim]
break


column_num = output.shape[-1]
temp = output.cpu().numpy().reshape(-1, column_num)[-config.horizon:]

Expand Down Expand Up @@ -873,7 +874,7 @@ def _perform_rolling_predictions(
"""
rolling_time = 0
input_np, target_np, input_mark_np, target_mark_np = self._get_rolling_data(
input_np, None, all_mark, rolling_time
input_np, None, None, all_mark, rolling_time
)
if exog_future is not None:
rolling_time_sum = horizon // self.config.horizon + 1
Expand Down Expand Up @@ -933,12 +934,13 @@ def _perform_rolling_predictions(
break
rolling_time += 1
output = output.cpu().numpy()[:, -self.config.horizon :, :]
# 这里的_get_rolling_data函数需要修改:
(
input_np,
target_np,
input_mark_np,
target_mark_np,
) = self._get_rolling_data(input_np, output, all_mark, rolling_time)
) = self._get_rolling_data(input_np, output, exog_future, all_mark, rolling_time)

answers = np.concatenate(answers, axis=1)
return answers[:, -horizon:, :]
Expand All @@ -947,6 +949,7 @@ def _get_rolling_data(
self,
input_np: np.ndarray,
output: Optional[np.ndarray],
exog_future: Optional[torch.Tensor],
all_mark: np.ndarray,
rolling_time: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
Expand All @@ -960,6 +963,15 @@ def _get_rolling_data(
:return: Updated input data, target data, input marks, and target marks for rolling prediction.
"""
if rolling_time > 0:
if exog_future!= None:
exog_output = exog_future[
:,
rolling_time
* self.config.horizon : (rolling_time + 1)
* self.config.horizon,
:,
].cpu().numpy()
output = np.concatenate((output, exog_output), axis=2)
input_np = np.concatenate((input_np, output), axis=1)
input_np = input_np[:, -self.config.seq_len :, :]
target_np = np.zeros(
Expand Down