Skip to content

Commit 4ff18da

Browse files
Fixes the logic for is_last_index. (#531)
* fix the logic for last index * Add unit test for is_last_index in dataset for chunked indexes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cfac30a commit 4ff18da

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/litdata/streaming/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __next__(self) -> Any:
417417
chunk_index=self.worker_chunks[self.chunk_index - 1],
418418
# We provide the chunks indexes only one the first
419419
chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :],
420-
is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
420+
is_last_index=(self.chunk_index) == len(self.worker_intervals) and len(self.current_indexes) == 0,
421421
)
422422
)
423423

tests/streaming/test_dataset.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from time import sleep
2121
from typing import Any, Dict, Optional
2222
from unittest import mock
23+
from unittest.mock import patch
2324

2425
import numpy as np
2526
import pytest
@@ -41,6 +42,7 @@
4142
_replay_sampling,
4243
)
4344
from litdata.streaming.item_loader import TokensLoader
45+
from litdata.streaming.reader import BinaryReader
4446
from litdata.streaming.shuffle import FullShuffle, NoShuffle
4547
from litdata.utilities import dataset_utilities as dataset_utilities_module
4648
from litdata.utilities.dataset_utilities import load_index_file
@@ -1459,3 +1461,47 @@ def test_dataset_with_mosaic_mds_data(tmpdir):
14591461
assert len(batch["image"]) == 4
14601462
assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
14611463
i += 1
1464+
1465+
1466+
@pytest.mark.parametrize("shuffle", [True, False])
1467+
def test_is_last_index_for_chunked_index_with_dataset(tmpdir, shuffle):
1468+
# Create a dataset with 50 items, 10 items per chunk
1469+
cache = Cache(str(tmpdir), chunk_size=10)
1470+
for i in range(50):
1471+
cache[i] = i
1472+
cache.done()
1473+
cache.merge()
1474+
1475+
# List to store all ChunkedIndex objects passed to BinaryReader.read
1476+
chunked_indexes = []
1477+
1478+
# Patch the BinaryReader.read method to track the indices
1479+
original_read = BinaryReader.read
1480+
1481+
# Create a mock function that will capture the indices but still call the original
1482+
def mock_read(self, index):
1483+
chunked_indexes.append(index)
1484+
return original_read(self, index) # Call the original read method
1485+
1486+
# Patch the read method directly in the BinaryReader class
1487+
with patch("litdata.streaming.reader.BinaryReader.read", mock_read):
1488+
dataset = StreamingDataset(str(tmpdir), shuffle=shuffle)
1489+
assert len(dataset) == 50
1490+
1491+
# Iterate through the dataset to trigger BinaryReader.read
1492+
for _ in dataset:
1493+
pass
1494+
1495+
# Assertions
1496+
# Ensure BinaryReader.read was called 50 times (once for each item)
1497+
assert len(chunked_indexes) == 50, "Expected 50 calls to BinaryReader.read"
1498+
1499+
# first chunked index has the chunk_indexes from dataset worker
1500+
worker_chunks = chunked_indexes[0].chunk_indexes
1501+
assert worker_chunks == dataset.worker_chunks, "Expected chunk_indexes to match dataset.worker_chunks"
1502+
1503+
# Verify that exactly one index has is_last_index=True
1504+
indexes = [idx for idx in chunked_indexes if idx.is_last_index]
1505+
assert len(indexes) == 1, "Expected exactly one index with is_last_index=True"
1506+
assert indexes[0].is_last_index, "Expected is_last_index=True for the last item"
1507+
assert indexes[0].chunk_index == worker_chunks[-1], "Expected to match the last chunk"

0 commit comments

Comments
 (0)