Skip to content
8 changes: 8 additions & 0 deletions hotpp/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class HotppDataset(torch.utils.data.IterableDataset):
min_length: Minimum sequence length. Use 0 to disable subsampling.
max_length: Maximum sequence length. Disable limit if `None`.
position: Sample position (`random` or `last`).
rename: A dictionary for mapping field names during read.
fields: A list of fields to keep in data. Other fields will be discarded.
drop_nans: A list of fields to skip nans for.
add_seq_fields: A dictionary with additional constant fields.
Expand All @@ -86,6 +87,7 @@ def __init__(self, data,
random_part="train",
position="random",
min_required_length=None,
rename=None,
fields=None,
id_field="id",
timestamps_field="timestamps",
Expand Down Expand Up @@ -121,6 +123,8 @@ def __init__(self, data,
self.local_targets_fields = parse_fields(local_targets_fields)
self.local_targets_indices_field = local_targets_indices_field

self.rename = rename or {}

if fields is not None:
known_fields = [id_field, timestamps_field] + list(self.global_target_fields) + list(self.local_targets_fields)
if local_targets_indices_field is not None:
Expand Down Expand Up @@ -196,6 +200,10 @@ def __iter__(self):
if in_train ^ (self.random_part == "train"):
continue
for rec in read_pyarrow_file(filename, use_threads=True):
for src, dst in self.rename.items():
if src not in rec:
raise RuntimeError(f"The field `{src}` not found")
rec[dst] = rec.pop(src)
if (self.min_required_length is not None) and (len(rec[self.timestamps_field]) < self.min_required_length):
continue
if self.fields is not None:
Expand Down