Skip to content

Commit 2128169

Browse files
authored
subset-selection
Add on-the-fly data subset selection.
1 parent 7bdac5d commit 2128169

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

hotpp/data/dataset.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
7979
"""
8080
def __init__(self, data,
8181
min_length=0, max_length=None,
82+
random_split=1,
83+
random_part="train",
8284
position="random",
8385
min_required_length=None,
8486
fields=None,
@@ -101,6 +103,8 @@ def __init__(self, data,
101103
raise RuntimeError("Empty dataset")
102104
self.allow_empty = allow_empty
103105
self.total_length = sum(map(get_parquet_length, self.filenames))
106+
self.random_split = random_split
107+
self.random_part = random_part
104108
self.min_length = min_length
105109
self.max_length = max_length
106110
self.position = position
@@ -181,7 +185,15 @@ def __len__(self):
181185
return self.total_length
182186

183187
def __iter__(self):
188+
if self.filenames:
189+
root = os.path.commonprefix(self.filenames)
184190
for filename in self.filenames:
191+
if (self.random_split != 1) or (self.random_part != "train"):
192+
s = 1000000000
193+
h = immutable_hash(os.path.relpath(filename, root))
194+
in_train = h % s <= s * self.random_split
195+
if in_train ^ (self.random_part == "train"):
196+
continue
185197
for rec in read_pyarrow_file(filename, use_threads=True):
186198
if (self.min_required_length is not None) and (len(rec[self.timestamps_field]) < self.min_required_length):
187199
continue

0 commit comments

Comments
 (0)