@@ -79,6 +79,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
7979 """
8080 def __init__ (self , data ,
8181 min_length = 0 , max_length = None ,
82+ random_split = 1 ,
83+ random_part = "train" ,
8284 position = "random" ,
8385 min_required_length = None ,
8486 fields = None ,
@@ -101,6 +103,8 @@ def __init__(self, data,
101103 raise RuntimeError ("Empty dataset" )
102104 self .allow_empty = allow_empty
103105 self .total_length = sum (map (get_parquet_length , self .filenames ))
106+ self .random_split = random_split
107+ self .random_part = random_part
104108 self .min_length = min_length
105109 self .max_length = max_length
106110 self .position = position
@@ -181,7 +185,15 @@ def __len__(self):
181185 return self .total_length
182186
183187 def __iter__ (self ):
188+ if self .filenames :
189+ root = os .path .commonprefix (self .filenames )
184190 for filename in self .filenames :
191+ if (self .random_split != 1 ) or (self .random_part != "train" ):
192+ s = 1000000000
193+ h = immutable_hash (os .path .relpath (filename , root ))
194+ in_train = h % s <= s * self .random_split
195+ if in_train ^ (self .random_part == "train" ):
196+ continue
185197 for rec in read_pyarrow_file (filename , use_threads = True ):
186198 if (self .min_required_length is not None ) and (len (rec [self .timestamps_field ]) < self .min_required_length ):
187199 continue
0 commit comments