Skip to content

Commit c2252a6

Browse files
authored
Fix StreamingDataset len after drop_last update (#778)
1 parent c864e14 commit c2252a6

File tree

2 files changed: +53 additions, -2 deletions

src/litdata/streaming/dataset.py

Lines changed: 24 additions & 2 deletions

@@ -223,10 +223,32 @@ def on_demand_bytes(self, value: bool) -> None:
         self.cache._reader.on_demand_bytes = value

     def set_shuffle(self, shuffle: bool) -> None:
-        self.shuffle = shuffle
+        """Set the shuffle parameter.
+
+        Invalidates the shuffler cache when the parameter changes to ensure
+        subsequent length calculations reflect the new shuffle setting.
+
+        Args:
+            shuffle: Whether to shuffle the dataset.
+
+        """
+        if self.shuffle != shuffle:
+            self.shuffle = shuffle
+            self.shuffler = None  # Reset shuffler to pick up new shuffle setting

     def set_drop_last(self, drop_last: bool) -> None:
-        self.drop_last = drop_last
+        """Set the drop_last parameter.
+
+        Invalidates the shuffler cache when the parameter changes to ensure
+        subsequent length calculations reflect the new drop_last setting.
+
+        Args:
+            drop_last: Whether to drop the last incomplete batch.
+
+        """
+        if self.drop_last != drop_last:
+            self.drop_last = drop_last
+            self.shuffler = None  # Reset shuffler to pick up new drop_last setting

     def set_epoch(self, current_epoch: int) -> None:
         """Set the current epoch to the dataset on epoch starts.

tests/streaming/test_dataset.py

Lines changed: 29 additions & 0 deletions

@@ -544,6 +544,35 @@ def test_dataset_cache_recreation(tmpdir):
     assert dataset.shuffler is shuffler  # shuffler gets reused


+@pytest.mark.timeout(30)
+def test_len_called_before_dataloader_drop_last(tmpdir):
+    cache = Cache(str(tmpdir), chunk_size=10)
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), shuffle=False)
+    _ = len(dataset)
+
+    batch_size = 8
+    dataloader = StreamingDataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=4,
+        drop_last=True,
+        shuffle=False,
+    )
+
+    expected_batches = len(dataloader)
+    batches = list(dataloader)
+
+    # With drop_last=True and 100 items: 100 // 8 = 12 full batches (4 items dropped)
+    assert expected_batches == 12
+    assert len(batches) == expected_batches
+    assert all(len(batch) == batch_size for batch in batches)
+
 def test_dataset_for_text_tokens(tmpdir):
     seed_everything(42)

Comments (0)