@@ -230,11 +230,12 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
230230 targ_list = {}
231231 for va in self .listed_vals :
232232 # We need to exclude the index column on one end and the series id column on the other
233- t = torch .Tensor (va .iloc [idx : self .forecast_history + idx ].values )[:, 1 :- 1 ]
234- print (t .shape )
233+
235234 targ_start_idx = idx + self .forecast_history
236235 idx2 = va [self .series_id_col ].iloc [0 ]
237- targ = torch .Tensor (va .iloc [targ_start_idx : targ_start_idx + self .forecast_length ].to_numpy ())[:, 1 :- 1 ]
236+ va_returned = va [va .columns .difference ([self .series_id_col ], sort = False )]
237+ t = torch .Tensor (va_returned .iloc [idx : self .forecast_history + idx ].values )[:, 1 :]
238+ targ = torch .Tensor (va_returned .iloc [targ_start_idx : targ_start_idx + self .forecast_length ].to_numpy ())[:, 1 :] # noqa
238239 src_list [self .unique_dict [idx2 ]] = t
239240 targ_list [self .unique_dict [idx2 ]] = targ
240241 return src_list , targ_list
@@ -249,7 +250,7 @@ def __len__(self) -> int:
249250 if self .return_all_series :
250251 return len (self .listed_vals [0 ]) - self .forecast_history - self .forecast_length - 1
251252 else :
252- raise NotImplementedError ("Current code only supports returning all the series at each iteration" )
253+ raise NotImplementedError ("Current code only supports returning all the series at once at each iteration" )
253254
254255
255256class CSVTestLoader (CSVDataLoader ):
@@ -667,11 +668,11 @@ class SeriesIDTestLoader(CSVSeriesIDLoader):
667668 def __init__ (self , series_id_col : str , main_params : dict , return_method : str , forecast_total = 336 , return_all = True ):
668669 """_summary_
669670
670- :param series_id_col: _de
671+ :param series_id_col: The column that contains the series_id
671672 :type series_id_col: str
672- :param main_params: _description_
673+ :param main_params: The core params used to instantiate the CSVSeriesIDLoader
673674 :type main_params: dict
674- :param return_method: _description_
675+ :param return_method: _description_D
675676 :type return_method: str
676677 :param return_all: _description_, defaults to True
677678 :type return_all: bool, optional
0 commit comments