Skip to content

Commit 7250f51

Browse files
authored
Future (#76)
Add field renaming to the dataset.
1 parent 9c079fd commit 7250f51

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

hotpp/data/dataset.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ class HotppDataset(torch.utils.data.IterableDataset):
7373
min_length: Minimum sequence length. Use 0 to disable subsampling.
7474
max_length: Maximum sequence length. Disable limit if `None`.
7575
position: Sample position (`random` or `last`).
76+
rename: A dictionary mapping source field names to new field names, applied to each record during read.
7677
fields: A list of fields to keep in data. Other fields will be discarded.
7778
drop_nans: A list of fields to skip nans for.
7879
add_seq_fields: A dictionary with additional constant fields.
@@ -86,6 +87,7 @@ def __init__(self, data,
8687
random_part="train",
8788
position="random",
8889
min_required_length=None,
90+
rename=None,
8991
fields=None,
9092
id_field="id",
9193
timestamps_field="timestamps",
@@ -121,6 +123,8 @@ def __init__(self, data,
121123
self.local_targets_fields = parse_fields(local_targets_fields)
122124
self.local_targets_indices_field = local_targets_indices_field
123125

126+
self.rename = rename or {}
127+
124128
if fields is not None:
125129
known_fields = [id_field, timestamps_field] + list(self.global_target_fields) + list(self.local_targets_fields)
126130
if local_targets_indices_field is not None:
@@ -196,6 +200,10 @@ def __iter__(self):
196200
if in_train ^ (self.random_part == "train"):
197201
continue
198202
for rec in read_pyarrow_file(filename, use_threads=True):
203+
for src, dst in self.rename.items():
204+
if src not in rec:
205+
raise RuntimeError(f"The field `{src}` not found")
206+
rec[dst] = rec.pop(src)
199207
if (self.min_required_length is not None) and (len(rec[self.timestamps_field]) < self.min_required_length):
200208
continue
201209
if self.fields is not None:

0 commit comments

Comments
 (0)