@@ -757,8 +757,6 @@ def _create_windows(self, batch, step):
757757 insample_condition >= min_insample_points
758758 )
759759
760- windows = windows [final_condition ]
761-
762760 # Parse Static data to match windows
763761 static = batch .get ("static" , None )
764762 static_cols = batch .get ("static_cols" , None )
@@ -768,13 +766,14 @@ def _create_windows(self, batch, step):
768766 static = torch .repeat_interleave (
769767 static , repeats = windows_per_serie , dim = 0
770768 )
771- static = static [final_condition ]
772769
773770 # Protection of empty windows
774771 if final_condition .sum () == 0 :
775772 raise Exception ("No windows available for training" )
776773
777- return windows , static , static_cols
774+ final_condition = torch .nonzero (final_condition ).squeeze (- 1 )
775+
776+ return windows , static , static_cols , final_condition
778777
779778 elif step in ["predict" , "val" ]:
780779
@@ -837,7 +836,9 @@ def _create_windows(self, batch, step):
837836 static , repeats = windows_per_serie , dim = 0
838837 )
839838
840- return windows , static , static_cols
839+ final_condition = torch .arange (windows .shape [0 ], device = windows .device )
840+
841+ return windows , static , static_cols , final_condition
841842 else :
842843 raise ValueError (f"Unknown step { step } " )
843844
@@ -885,21 +886,13 @@ def _inv_normalization(self, y_hat, y_idx):
885886 return y_hat
886887
887888 def _sample_windows (
888- self , windows_temporal , static , static_cols , temporal_cols , step , w_idxs = None
889+ self , windows_temporal , static , static_cols , temporal_cols , w_idxs , final_condition ,
889890 ):
890- if step == "train" and self .windows_batch_size is not None :
891- n_windows = windows_temporal .shape [0 ]
892- w_idxs = np .random .choice (
893- n_windows ,
894- size = self .windows_batch_size ,
895- replace = (n_windows < self .windows_batch_size ),
896- )
897- windows_sample = windows_temporal
898- if w_idxs is not None :
899- windows_sample = windows_temporal [w_idxs ]
891+ w_idxs_final = final_condition [w_idxs ]
892+ windows_sample = windows_temporal [w_idxs_final ]
900893
901- if static is not None and not self .MULTIVARIATE :
902- static = static [w_idxs ]
894+ if static is not None and not self .MULTIVARIATE :
895+ static = static [w_idxs_final ]
903896
904897 windows_batch = dict (
905898 temporal = windows_sample ,
@@ -1293,10 +1286,10 @@ def _predict_step_direct_batch(
12931286 def _predict_step_recurrent (self , batch , batch_idx ):
12941287 self .input_size = self .inference_input_size
12951288 temporal_cols = batch ["temporal_cols" ]
1296- windows_temporal , static , static_cols = self ._create_windows (
1289+ windows_temporal , static , static_cols , final_condition = self ._create_windows (
12971290 batch , step = "predict"
12981291 )
1299- n_windows = len (windows_temporal )
1292+ n_windows = len (final_condition )
13001293 y_idx = batch ["y_idx" ]
13011294
13021295 # Number of windows in batch
@@ -1316,16 +1309,16 @@ def _predict_step_recurrent(self, batch, batch_idx):
13161309
13171310 for i in range (n_batches ):
13181311 # Create and normalize windows [Ws, L+H, C]
1319- w_idxs = np .arange (
1320- i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows )
1312+ w_idxs = torch .arange (
1313+ i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows ), device = windows_temporal . device
13211314 )
13221315 windows = self ._sample_windows (
1323- windows_temporal ,
1324- static ,
1325- static_cols ,
1326- temporal_cols ,
1327- step = "predict" ,
1316+ windows_temporal = windows_temporal ,
1317+ static = static ,
1318+ static_cols = static_cols ,
1319+ temporal_cols = temporal_cols ,
13281320 w_idxs = w_idxs ,
1321+ final_condition = final_condition ,
13291322 )
13301323 windows = self ._normalization (windows = windows , y_idx = y_idx )
13311324
@@ -1405,10 +1398,10 @@ def _compute_explanations_for_step(
14051398 ):
14061399 """Compute explanations for a single prediction step."""
14071400 # Create windows and normalize for explanations
1408- windows_temporal , static , static_cols = self ._create_windows (
1401+ windows_temporal , static , static_cols , final_condition = self ._create_windows (
14091402 batch , step = "predict"
14101403 )
1411- n_windows = len (windows_temporal )
1404+ n_windows = len (final_condition )
14121405
14131406 # Process windows in batches
14141407 windows_batch_size = self .inference_windows_batch_size
@@ -1423,16 +1416,16 @@ def _compute_explanations_for_step(
14231416 step_baseline_predictions = []
14241417
14251418 for j in range (n_batches ):
1426- w_idxs = np .arange (
1427- j * windows_batch_size , min ((j + 1 ) * windows_batch_size , n_windows )
1419+ w_idxs = torch .arange (
1420+ j * windows_batch_size , min ((j + 1 ) * windows_batch_size , n_windows ), device = windows_temporal . device
14281421 )
14291422 windows = self ._sample_windows (
1430- windows_temporal ,
1431- static ,
1432- static_cols ,
1433- temporal_cols ,
1434- step = "predict" ,
1423+ windows_temporal = windows_temporal ,
1424+ static = static ,
1425+ static_cols = static_cols ,
1426+ temporal_cols = temporal_cols ,
14351427 w_idxs = w_idxs ,
1428+ final_condition = final_condition ,
14361429 )
14371430 windows = self ._normalization (windows = windows , y_idx = y_idx )
14381431
@@ -1584,11 +1577,11 @@ def _predict_step_direct(self, batch, batch_idx, recursive=False):
15841577
15851578 else :
15861579 # Non-recursive case remains unchanged
1587- windows_temporal , static , static_cols = self ._create_windows (
1580+ windows_temporal , static , static_cols , final_condition = self ._create_windows (
15881581 batch ,
15891582 step = "predict" ,
15901583 )
1591- n_windows = len (windows_temporal )
1584+ n_windows = len (final_condition )
15921585 y_idx = batch ["y_idx" ]
15931586
15941587 # Number of windows in batch
@@ -1607,16 +1600,16 @@ def _predict_step_direct(self, batch, batch_idx, recursive=False):
16071600
16081601 for i in range (n_batches ):
16091602 # Create and normalize windows [Ws, L+H, C]
1610- w_idxs = np .arange (
1611- i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows )
1603+ w_idxs = torch .arange (
1604+ i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows ), device = windows_temporal . device
16121605 )
16131606 windows = self ._sample_windows (
1614- windows_temporal ,
1615- static ,
1616- static_cols ,
1617- temporal_cols ,
1618- step = "predict" ,
1607+ windows_temporal = windows_temporal ,
1608+ static = static ,
1609+ static_cols = static_cols ,
1610+ temporal_cols = temporal_cols ,
16191611 w_idxs = w_idxs ,
1612+ final_condition = final_condition ,
16201613 )
16211614 windows = self ._normalization (windows = windows , y_idx = y_idx )
16221615
@@ -1706,11 +1699,26 @@ def training_step(self, batch, batch_idx):
17061699 y_idx = batch ["y_idx" ]
17071700
17081701 temporal_cols = batch ["temporal_cols" ]
1709- windows_temporal , static , static_cols = self ._create_windows (
1702+ windows_temporal , static , static_cols , final_condition = self ._create_windows (
17101703 batch , step = "train"
17111704 )
1705+ n_windows = len (final_condition )
1706+ if self .windows_batch_size is not None :
1707+ if n_windows < self .windows_batch_size :
1708+ w_idxs = torch .randint (
1709+ 0 ,
1710+ n_windows ,
1711+ size = (self .windows_batch_size ,),
1712+ device = windows_temporal .device ,
1713+ )
1714+ else :
1715+ w_idxs = torch .randperm (n_windows , device = windows_temporal .device )[
1716+ : self .windows_batch_size
1717+ ]
1718+ else :
1719+ w_idxs = torch .arange (n_windows , device = windows_temporal .device )
17121720 windows = self ._sample_windows (
1713- windows_temporal , static , static_cols , temporal_cols , step = "train"
1721+ windows_temporal = windows_temporal , static = static , static_cols = static_cols , temporal_cols = temporal_cols , w_idxs = w_idxs , final_condition = final_condition
17141722 )
17151723 original_outsample_y = torch .clone (
17161724 windows ["temporal" ][:, self .input_size :, y_idx ]
@@ -1777,8 +1785,8 @@ def validation_step(self, batch, batch_idx):
17771785 return np .nan
17781786
17791787 temporal_cols = batch ["temporal_cols" ]
1780- windows_temporal , static , static_cols = self ._create_windows (batch , step = "val" )
1781- n_windows = len (windows_temporal )
1788+ windows_temporal , static , static_cols , final_condition = self ._create_windows (batch , step = "val" )
1789+ n_windows = len (final_condition )
17821790 y_idx = batch ["y_idx" ]
17831791
17841792 # Number of windows in batch
@@ -1791,16 +1799,16 @@ def validation_step(self, batch, batch_idx):
17911799 batch_sizes = []
17921800 for i in range (n_batches ):
17931801 # Create and normalize windows [Ws, L + h, C, n_series]
1794- w_idxs = np .arange (
1795- i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows )
1802+ w_idxs = torch .arange (
1803+ i * windows_batch_size , min ((i + 1 ) * windows_batch_size , n_windows ), device = windows_temporal . device
17961804 )
17971805 windows = self ._sample_windows (
17981806 windows_temporal ,
17991807 static ,
18001808 static_cols ,
18011809 temporal_cols ,
1802- step = "val" ,
18031810 w_idxs = w_idxs ,
1811+ final_condition = final_condition ,
18041812 )
18051813 original_outsample_y = torch .clone (
18061814 windows ["temporal" ][:, self .input_size :, y_idx ]
0 commit comments