File tree: 2 files changed (+53, −2 lines)
lines changed Original file line number Diff line number Diff line change @@ -223,10 +223,32 @@ def on_demand_bytes(self, value: bool) -> None:
223223 self .cache ._reader .on_demand_bytes = value
224224
225225 def set_shuffle (self , shuffle : bool ) -> None :
226- self .shuffle = shuffle
226+ """Set the shuffle parameter.
227+
228+ Invalidates the shuffler cache when the parameter changes to ensure
229+ subsequent length calculations reflect the new shuffle setting.
230+
231+ Args:
232+ shuffle: Whether to shuffle the dataset.
233+
234+ """
235+ if self .shuffle != shuffle :
236+ self .shuffle = shuffle
237+ self .shuffler = None # Reset shuffler to pick up new shuffle setting
227238
228239 def set_drop_last (self , drop_last : bool ) -> None :
229- self .drop_last = drop_last
240+ """Set the drop_last parameter.
241+
242+ Invalidates the shuffler cache when the parameter changes to ensure
243+ subsequent length calculations reflect the new drop_last setting.
244+
245+ Args:
246+ drop_last: Whether to drop the last incomplete batch.
247+
248+ """
249+ if self .drop_last != drop_last :
250+ self .drop_last = drop_last
251+ self .shuffler = None # Reset shuffler to pick up new drop_last setting
230252
231253 def set_epoch (self , current_epoch : int ) -> None :
232254 """Set the current epoch to the dataset on epoch starts.
Original file line number Diff line number Diff line change @@ -544,6 +544,35 @@ def test_dataset_cache_recreation(tmpdir):
544544 assert dataset .shuffler is shuffler # shuffler gets reused
545545
546546
@pytest.mark.timeout(30)
def test_len_called_before_dataloader_drop_last(tmpdir):
    """Calling ``len()`` on the dataset before iterating must not break drop_last."""
    # Build a small optimized dataset of 100 integer items in chunks of 10.
    cache = Cache(str(tmpdir), chunk_size=10)
    for index in range(100):
        cache[index] = index
    cache.done()
    cache.merge()

    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False)
    # Trigger the early length computation that used to poison drop_last math.
    _ = len(dataset)

    batch_size = 8
    dataloader = StreamingDataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        shuffle=False,
    )

    expected_batches = len(dataloader)
    produced = [batch for batch in dataloader]

    # With drop_last=True and 100 items: 100 // 8 = 12 full batches (4 items dropped)
    assert expected_batches == 12
    assert len(produced) == expected_batches
    for batch in produced:
        assert len(batch) == batch_size
575+
547576def test_dataset_for_text_tokens (tmpdir ):
548577 seed_everything (42 )
549578
You can’t perform that action at this time.
0 commit comments