Skip to content

Commit 66aeb5e

Browse files
marcopeixelephaint
andauthored
[FIX] Apply windows batch size (#1429)
Co-authored-by: elephaint <osprangers@gmail.com>
1 parent a2bcd96 commit 66aeb5e

File tree

1 file changed

+60
-52
lines changed

1 file changed

+60
-52
lines changed

neuralforecast/common/_base_model.py

Lines changed: 60 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,6 @@ def _create_windows(self, batch, step):
757757
insample_condition >= min_insample_points
758758
)
759759

760-
windows = windows[final_condition]
761-
762760
# Parse Static data to match windows
763761
static = batch.get("static", None)
764762
static_cols = batch.get("static_cols", None)
@@ -768,13 +766,14 @@ def _create_windows(self, batch, step):
768766
static = torch.repeat_interleave(
769767
static, repeats=windows_per_serie, dim=0
770768
)
771-
static = static[final_condition]
772769

773770
# Protection of empty windows
774771
if final_condition.sum() == 0:
775772
raise Exception("No windows available for training")
776773

777-
return windows, static, static_cols
774+
final_condition = torch.nonzero(final_condition).squeeze(-1)
775+
776+
return windows, static, static_cols, final_condition
778777

779778
elif step in ["predict", "val"]:
780779

@@ -837,7 +836,9 @@ def _create_windows(self, batch, step):
837836
static, repeats=windows_per_serie, dim=0
838837
)
839838

840-
return windows, static, static_cols
839+
final_condition = torch.arange(windows.shape[0], device=windows.device)
840+
841+
return windows, static, static_cols, final_condition
841842
else:
842843
raise ValueError(f"Unknown step {step}")
843844

@@ -885,21 +886,13 @@ def _inv_normalization(self, y_hat, y_idx):
885886
return y_hat
886887

887888
def _sample_windows(
888-
self, windows_temporal, static, static_cols, temporal_cols, step, w_idxs=None
889+
self, windows_temporal, static, static_cols, temporal_cols, w_idxs, final_condition,
889890
):
890-
if step == "train" and self.windows_batch_size is not None:
891-
n_windows = windows_temporal.shape[0]
892-
w_idxs = np.random.choice(
893-
n_windows,
894-
size=self.windows_batch_size,
895-
replace=(n_windows < self.windows_batch_size),
896-
)
897-
windows_sample = windows_temporal
898-
if w_idxs is not None:
899-
windows_sample = windows_temporal[w_idxs]
891+
w_idxs_final = final_condition[w_idxs]
892+
windows_sample = windows_temporal[w_idxs_final]
900893

901-
if static is not None and not self.MULTIVARIATE:
902-
static = static[w_idxs]
894+
if static is not None and not self.MULTIVARIATE:
895+
static = static[w_idxs_final]
903896

904897
windows_batch = dict(
905898
temporal=windows_sample,
@@ -1293,10 +1286,10 @@ def _predict_step_direct_batch(
12931286
def _predict_step_recurrent(self, batch, batch_idx):
12941287
self.input_size = self.inference_input_size
12951288
temporal_cols = batch["temporal_cols"]
1296-
windows_temporal, static, static_cols = self._create_windows(
1289+
windows_temporal, static, static_cols, final_condition = self._create_windows(
12971290
batch, step="predict"
12981291
)
1299-
n_windows = len(windows_temporal)
1292+
n_windows = len(final_condition)
13001293
y_idx = batch["y_idx"]
13011294

13021295
# Number of windows in batch
@@ -1316,16 +1309,16 @@ def _predict_step_recurrent(self, batch, batch_idx):
13161309

13171310
for i in range(n_batches):
13181311
# Create and normalize windows [Ws, L+H, C]
1319-
w_idxs = np.arange(
1320-
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows)
1312+
w_idxs = torch.arange(
1313+
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows), device=windows_temporal.device
13211314
)
13221315
windows = self._sample_windows(
1323-
windows_temporal,
1324-
static,
1325-
static_cols,
1326-
temporal_cols,
1327-
step="predict",
1316+
windows_temporal=windows_temporal,
1317+
static=static,
1318+
static_cols=static_cols,
1319+
temporal_cols=temporal_cols,
13281320
w_idxs=w_idxs,
1321+
final_condition=final_condition,
13291322
)
13301323
windows = self._normalization(windows=windows, y_idx=y_idx)
13311324

@@ -1405,10 +1398,10 @@ def _compute_explanations_for_step(
14051398
):
14061399
"""Compute explanations for a single prediction step."""
14071400
# Create windows and normalize for explanations
1408-
windows_temporal, static, static_cols = self._create_windows(
1401+
windows_temporal, static, static_cols, final_condition = self._create_windows(
14091402
batch, step="predict"
14101403
)
1411-
n_windows = len(windows_temporal)
1404+
n_windows = len(final_condition)
14121405

14131406
# Process windows in batches
14141407
windows_batch_size = self.inference_windows_batch_size
@@ -1423,16 +1416,16 @@ def _compute_explanations_for_step(
14231416
step_baseline_predictions = []
14241417

14251418
for j in range(n_batches):
1426-
w_idxs = np.arange(
1427-
j * windows_batch_size, min((j + 1) * windows_batch_size, n_windows)
1419+
w_idxs = torch.arange(
1420+
j * windows_batch_size, min((j + 1) * windows_batch_size, n_windows), device=windows_temporal.device
14281421
)
14291422
windows = self._sample_windows(
1430-
windows_temporal,
1431-
static,
1432-
static_cols,
1433-
temporal_cols,
1434-
step="predict",
1423+
windows_temporal=windows_temporal,
1424+
static=static,
1425+
static_cols=static_cols,
1426+
temporal_cols=temporal_cols,
14351427
w_idxs=w_idxs,
1428+
final_condition=final_condition,
14361429
)
14371430
windows = self._normalization(windows=windows, y_idx=y_idx)
14381431

@@ -1584,11 +1577,11 @@ def _predict_step_direct(self, batch, batch_idx, recursive=False):
15841577

15851578
else:
15861579
# Non-recursive case remains unchanged
1587-
windows_temporal, static, static_cols = self._create_windows(
1580+
windows_temporal, static, static_cols, final_condition = self._create_windows(
15881581
batch,
15891582
step="predict",
15901583
)
1591-
n_windows = len(windows_temporal)
1584+
n_windows = len(final_condition)
15921585
y_idx = batch["y_idx"]
15931586

15941587
# Number of windows in batch
@@ -1607,16 +1600,16 @@ def _predict_step_direct(self, batch, batch_idx, recursive=False):
16071600

16081601
for i in range(n_batches):
16091602
# Create and normalize windows [Ws, L+H, C]
1610-
w_idxs = np.arange(
1611-
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows)
1603+
w_idxs = torch.arange(
1604+
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows), device=windows_temporal.device
16121605
)
16131606
windows = self._sample_windows(
1614-
windows_temporal,
1615-
static,
1616-
static_cols,
1617-
temporal_cols,
1618-
step="predict",
1607+
windows_temporal=windows_temporal,
1608+
static=static,
1609+
static_cols=static_cols,
1610+
temporal_cols=temporal_cols,
16191611
w_idxs=w_idxs,
1612+
final_condition=final_condition,
16201613
)
16211614
windows = self._normalization(windows=windows, y_idx=y_idx)
16221615

@@ -1706,11 +1699,26 @@ def training_step(self, batch, batch_idx):
17061699
y_idx = batch["y_idx"]
17071700

17081701
temporal_cols = batch["temporal_cols"]
1709-
windows_temporal, static, static_cols = self._create_windows(
1702+
windows_temporal, static, static_cols, final_condition = self._create_windows(
17101703
batch, step="train"
17111704
)
1705+
n_windows = len(final_condition)
1706+
if self.windows_batch_size is not None:
1707+
if n_windows < self.windows_batch_size:
1708+
w_idxs = torch.randint(
1709+
0,
1710+
n_windows,
1711+
size=(self.windows_batch_size,),
1712+
device=windows_temporal.device,
1713+
)
1714+
else:
1715+
w_idxs = torch.randperm(n_windows, device=windows_temporal.device)[
1716+
: self.windows_batch_size
1717+
]
1718+
else:
1719+
w_idxs = torch.arange(n_windows, device=windows_temporal.device)
17121720
windows = self._sample_windows(
1713-
windows_temporal, static, static_cols, temporal_cols, step="train"
1721+
windows_temporal=windows_temporal, static=static, static_cols=static_cols, temporal_cols=temporal_cols, w_idxs=w_idxs, final_condition=final_condition
17141722
)
17151723
original_outsample_y = torch.clone(
17161724
windows["temporal"][:, self.input_size :, y_idx]
@@ -1777,8 +1785,8 @@ def validation_step(self, batch, batch_idx):
17771785
return np.nan
17781786

17791787
temporal_cols = batch["temporal_cols"]
1780-
windows_temporal, static, static_cols = self._create_windows(batch, step="val")
1781-
n_windows = len(windows_temporal)
1788+
windows_temporal, static, static_cols, final_condition = self._create_windows(batch, step="val")
1789+
n_windows = len(final_condition)
17821790
y_idx = batch["y_idx"]
17831791

17841792
# Number of windows in batch
@@ -1791,16 +1799,16 @@ def validation_step(self, batch, batch_idx):
17911799
batch_sizes = []
17921800
for i in range(n_batches):
17931801
# Create and normalize windows [Ws, L + h, C, n_series]
1794-
w_idxs = np.arange(
1795-
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows)
1802+
w_idxs = torch.arange(
1803+
i * windows_batch_size, min((i + 1) * windows_batch_size, n_windows), device=windows_temporal.device
17961804
)
17971805
windows = self._sample_windows(
17981806
windows_temporal,
17991807
static,
18001808
static_cols,
18011809
temporal_cols,
1802-
step="val",
18031810
w_idxs=w_idxs,
1811+
final_condition=final_condition,
18041812
)
18051813
original_outsample_y = torch.clone(
18061814
windows["temporal"][:, self.input_size :, y_idx]

0 commit comments

Comments
 (0)