@@ -73,6 +73,7 @@ class HotppDataset(torch.utils.data.IterableDataset):
7373 min_length: Minimum sequence length. Use 0 to disable subsampling.
7474 max_length: Maximum sequence length. Disable limit if `None`.
7575 position: Sample position (`random` or `last`).
76+ rename: A dictionary for mapping field names during read.
7677 fields: A list of fields to keep in data. Other fields will be discarded.
7778 drop_nans: A list of fields to skip nans for.
7879 add_seq_fields: A dictionary with additional constant fields.
@@ -86,6 +87,7 @@ def __init__(self, data,
8687 random_part = "train" ,
8788 position = "random" ,
8889 min_required_length = None ,
90+ rename = None ,
8991 fields = None ,
9092 id_field = "id" ,
9193 timestamps_field = "timestamps" ,
@@ -121,6 +123,8 @@ def __init__(self, data,
121123 self .local_targets_fields = parse_fields (local_targets_fields )
122124 self .local_targets_indices_field = local_targets_indices_field
123125
126+ self .rename = rename or {}
127+
124128 if fields is not None :
125129 known_fields = [id_field , timestamps_field ] + list (self .global_target_fields ) + list (self .local_targets_fields )
126130 if local_targets_indices_field is not None :
@@ -196,6 +200,10 @@ def __iter__(self):
196200 if in_train ^ (self .random_part == "train" ):
197201 continue
198202 for rec in read_pyarrow_file (filename , use_threads = True ):
203+ for src , dst in self .rename .items ():
204+ if src not in rec :
205+ raise RuntimeError (f"The field `{ src } ` not found" )
206+ rec [dst ] = rec .pop (src )
199207 if (self .min_required_length is not None ) and (len (rec [self .timestamps_field ]) < self .min_required_length ):
200208 continue
201209 if self .fields is not None :
0 commit comments