@@ -75,6 +75,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
7575 position: Sample position (`random` or `last`).
7676 rename: A dictionary for mapping field names during read.
7777 fields: A list of fields to keep in data. Other fields will be discarded.
78+ offset: Skip some initial records.
79+ limit: If set, limit the number of elements in the dataset.
7880 drop_nans: A list of fields to skip nans for.
7981 add_seq_fields: A dictionary with additional constant fields.
8082 global_target_fields: The name of the target field or a list of fields. Global targets are assigned to sequences.
@@ -91,11 +93,15 @@ def __init__(self, data,
9193 fields = None ,
9294 id_field = "id" ,
9395 timestamps_field = "timestamps" ,
96+ offset = 0 ,
97+ limit = None ,
9498 drop_nans = None ,
9599 add_seq_fields = None ,
96100 global_target_fields = None ,
97101 local_targets_fields = None ,
98102 local_targets_indices_field = None ):
103+ if (limit is not None ) and (min_required_length or drop_nans ):
104+ raise NotImplementedError ("Can't combine `limit` with input filters." )
99105 super ().__init__ ()
100106 if isinstance (data , str ):
101107 self .filenames = list (sorted (parquet_file_scan (data )))
@@ -105,9 +111,26 @@ def __init__(self, data,
105111 raise ValueError (f"Unknown data type: { type (data )} " )
106112 if not self .filenames :
107113 raise RuntimeError ("Empty dataset" )
108- self .total_length = sum (map (get_parquet_length , self .filenames ))
109- self .random_split = random_split
110- self .random_part = random_part
114+ if self .filenames and ((random_split != 1 ) or (random_part != "train" )):
115+ if limit is not None :
116+ raise NotImplementedError ("Can't combine `limit` with splitting." )
117+ if random_part not in {"train" , "val" }:
118+ raise ValueError (f"Unknown random part: { random_part } . Must be either `train` or `val`." )
119+ s = 1000000000
120+ root = os .path .commonprefix (self .filenames )
121+ selected_filenames = []
122+ for filename in self .filenames :
123+ h = immutable_hash (os .path .relpath (filename , root ))
124+ in_train = h % s <= s * random_split
125+ if not (in_train ^ (random_part == "train" )):
126+ selected_filenames .append (filename )
127+ self .filenames = selected_filenames
128+ self .offset = offset
129+ self .limit = limit
130+ self .total_length = max (0 , sum (map (get_parquet_length , self .filenames )) - offset )
131+ if self .limit is not None :
132+ self .total_length = min (self .limit , self .total_length )
133+
111134 self .min_length = min_length
112135 self .max_length = max_length
113136 self .position = position
@@ -134,7 +157,7 @@ def __init__(self, data,
134157
def replace_files(self, filenames, **kwargs):
    """Create a copy of this dataset that reads a different set of files.

    Every constructor argument except `data`, `random_split` and
    `random_part` is copied from the same-named attribute of this instance.
    The split parameters are not forwarded, so the constructor defaults
    apply to the new dataset (presumably because the file list has already
    been filtered by the split -- confirm against `__init__`).

    Args:
        filenames: The files the new dataset must read.
        **kwargs: Constructor arguments overriding the copied values.

    Returns:
        A new dataset of the same class as `self`.
    """
    # The signature is taken from the runtime class, so the new object is built
    # from that same class: a subclass keeps its type (and overridden methods)
    # and never receives arguments that only its own constructor understands.
    parameters = inspect.signature(self.__init__).parameters
    skipped = {"self", "data", "random_split", "random_part"}
    inherited = {name: getattr(self, name) for name in parameters if name not in skipped}
    return type(self)(filenames, **(inherited | kwargs))
140163
@@ -190,16 +213,12 @@ def __len__(self):
190213 return self .total_length
191214
192215 def __iter__ (self ):
193- if self .filenames :
194- root = os .path .commonprefix (self .filenames )
216+ total = 0
195217 for filename in self .filenames :
196- if (self .random_split != 1 ) or (self .random_part != "train" ):
197- s = 1000000000
198- h = immutable_hash (os .path .relpath (filename , root ))
199- in_train = h % s <= s * self .random_split
200- if in_train ^ (self .random_part == "train" ):
218+ for rec in read_pyarrow_file (filename ):
219+ total += 1
220+ if total <= self .offset :
201221 continue
202- for rec in read_pyarrow_file (filename , use_threads = True ):
203222 for src , dst in self .rename .items ():
204223 if src not in rec :
205224 raise RuntimeError (f"The field `{ src } ` not found" )
@@ -217,6 +236,8 @@ def __iter__(self):
217236 if skip :
218237 continue
219238 yield self .process (features )
239+ if (self .limit is not None ) and (total - self .offset == self .limit ):
240+ return
220241
221242 def _make_batch (self , by_name , batch_size , seq_feature_name = None ):
222243 # Compute lengths.
@@ -277,14 +298,16 @@ class ShuffledDistributedDataset(torch.utils.data.IterableDataset):
277298 Args:
278299 parallelize: Parallel reading mode, either `records` (better granularity) or `files` (faster).
279300 """
def __init__(self, dataset, rank=None, world_size=None, cache_size=None,
             parallelize=DEFAULT_PARALLELIZM, seed=0, drop_last=False):
    """Store the wrapped dataset and the distributed shuffling settings.

    Args:
        dataset: The dataset to wrap.
        rank: Stored as-is; `None` by default.
        world_size: Stored as-is; `None` by default.
        cache_size: Stored as-is; `None` by default.
        parallelize: Parallel reading mode, either `records` or `files`.
        seed: Stored as-is; `0` by default.
        drop_last: Stored as-is; the file-parallel reader consults it when
            deciding whether the last rank's record count is capped.
    """
    super().__init__()
    # Wrapped data source.
    self.dataset = dataset
    # Distributed layout.
    self.rank, self.world_size = rank, world_size
    # Reading and shuffling configuration.
    self.cache_size = cache_size
    self.parallelize = parallelize
    self.seed = seed
    self.drop_last = drop_last
    # Epoch counter starts from zero.
    self.epoch = 0
289312
290313 def _get_context (self ):
@@ -320,19 +343,34 @@ def _iter_shuffled_files(self, dataset, seed, rank, world_size):
320343 filenames = list (dataset .filenames )
321344 if not filenames :
322345 raise RuntimeError ("Empty dataset" )
323- root = os .path .commonprefix (filenames )
324- splits = [list () for _ in range (world_size )]
325- for filename in filenames :
326- splits [immutable_hash (os .path .relpath (filename , root )) % world_size ].append (filename )
327- if any ([len (split ) == 0 for split in splits ]):
328- if rank == 0 :
329- warnings .warn (f"Some workers got zero files, switch to record parallelizm" )
330- yield from self ._iter_shuffled_records (dataset , seed , rank , world_size )
331- return
332- dataset = dataset .replace_files (splits [rank ])
346+ rnd = Random (seed )
347+ rnd .shuffle (filenames )
348+ lengths = list (map (get_parquet_length , filenames ))
349+ records_per_worker = sum (lengths ) // world_size
350+ if records_per_worker == 0 :
351+ raise RuntimeError (f"Very small dataset for { world_size } workers" )
352+ offset = records_per_worker * rank
353+ skipped = 0
354+ accepted = 0
355+ selected_filenames = []
356+ for filename , length in zip (filenames , lengths ):
357+ if skipped + accepted + length <= offset :
358+ skipped += length
359+ elif accepted >= records_per_worker :
360+ break
361+ else :
362+ selected_filenames .append (filename )
363+ accepted += length - max (0 , offset - skipped - accepted )
364+ dataset = dataset .replace_files (selected_filenames ,
365+ offset = offset - skipped ,
366+ limit = records_per_worker if self .drop_last or rank != world_size - 1 else None )
333367 yield from self ._iter_shuffled_records_impl (dataset , seed )
334368
def _iter_shuffled_records(self, dataset, seed, rank, world_size):
    """Yield this rank's share of records, taken round-robin from a shuffled stream.

    The file order is shuffled with a generator seeded by `seed`, the dataset
    is rebuilt on the reordered files, and the record at stream position `i`
    is yielded only when `i % world_size == rank`.
    """
    shuffled_names = list(dataset.filenames)
    Random(seed).shuffle(shuffled_names)
    reordered = dataset.replace_files(shuffled_names)
    stream = self._iter_shuffled_records_impl(reordered, seed)
    for position, record in enumerate(stream):
        # Records owned by other ranks are consumed but not emitted.
        if position % world_size != rank:
            continue
        yield record
@@ -342,9 +380,6 @@ def _iter_shuffled_records_impl(self, dataset, seed):
342380 yield from dataset
343381 else :
344382 rnd = Random (seed )
345- filenames = list (dataset .filenames )
346- rnd .shuffle (filenames )
347- dataset = dataset .replace_files (filenames )
348383 cache = []
349384 for item in dataset :
350385 cache .append (item )
0 commit comments