Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions hotpp/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
"""
def __init__(self, data,
min_length=0, max_length=None,
random_split=1,
random_part="train",
position="random",
min_required_length=None,
fields=None,
Expand All @@ -101,6 +103,8 @@ def __init__(self, data,
raise RuntimeError("Empty dataset")
self.allow_empty = allow_empty
self.total_length = sum(map(get_parquet_length, self.filenames))
self.random_split = random_split
self.random_part = random_part
self.min_length = min_length
self.max_length = max_length
self.position = position
Expand Down Expand Up @@ -181,7 +185,15 @@ def __len__(self):
return self.total_length

def __iter__(self):
if self.filenames:
root = os.path.commonprefix(self.filenames)
for filename in self.filenames:
if (self.random_split != 1) or (self.random_part != "train"):
s = 1000000000
h = immutable_hash(os.path.relpath(filename, root))
in_train = h % s <= s * self.random_split
if in_train ^ (self.random_part == "train"):
continue
for rec in read_pyarrow_file(filename, use_threads=True):
if (self.min_required_length is not None) and (len(rec[self.timestamps_field]) < self.min_required_length):
continue
Expand Down