Skip to content

Commit ee03383

Browse files
Fix: redundant chunk index download request in BinaryReader when the dataset is in iter mode (#535)
* Remove redundant chunk index download request in BinaryReader * update the condition * Reset last chunk index and queued download state on close * add test case for dataset as iterator and non iterator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * update comment for clarity on chunk download conditions --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4ff18da commit ee03383

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

src/litdata/streaming/reader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def __init__(
299299
self._prepare_thread: Optional[PrepareChunksThread] = None
300300
self._item_loader = item_loader or PyTreeLoader()
301301
self._last_chunk_index: Optional[int] = None
302+
self._chunks_queued_for_download = False
302303
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
303304
self._storage_options = storage_options
304305
self._max_pre_download = max_pre_download
@@ -368,9 +369,12 @@ def read(self, index: ChunkedIndex) -> Any:
368369
self._prepare_thread.start()
369370
if index.chunk_indexes:
370371
self._prepare_thread.download(index.chunk_indexes)
372+
self._chunks_queued_for_download = True
371373

372-
# If the chunk_index is new, request for it to be downloaded.
373-
if index.chunk_index != self._last_chunk_index:
374+
# Only request individual chunk download if:
375+
# 1. We haven't already queued all chunks for the download
376+
# 2. We're processing a new chunk (different from the last one)
377+
if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index:
374378
assert self._prepare_thread
375379
self._prepare_thread.download([index.chunk_index])
376380

@@ -417,6 +421,8 @@ def read(self, index: ChunkedIndex) -> Any:
417421
self._prepare_thread.stop()
418422
self._prepare_thread = None
419423
self._item_loader.close(self._last_chunk_index)
424+
self._last_chunk_index = None
425+
self._chunks_queued_for_download = False
420426

421427
return item
422428

tests/streaming/test_dataset.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,3 +1505,57 @@ def mock_read(self, index):
15051505
assert len(indexes) == 1, "Expected exactly one index with is_last_index=True"
15061506
assert indexes[0].is_last_index, "Expected is_last_index=True for the last item"
15071507
assert indexes[0].chunk_index == worker_chunks[-1], "Expected to match the last chunk"
1508+
1509+
1510+
@pytest.mark.parametrize("local", [True, False])
1511+
@pytest.mark.parametrize("shuffle", [True, False])
1512+
def test_dataset_as_iterator_and_non_iterator(tmpdir, local, shuffle):
1513+
"""Test that _chunks_queued_for_download flag is correctly set and reset in reader.
1514+
1515+
This test verifies that:
1516+
1. When iterating with a `local:`-prefixed input dir, _chunks_queued_for_download is enabled during iteration but reset once the last item is read
1517+
2. When accessing by index, _chunks_queued_for_download is never enabled
1518+
"""
1519+
# Create directories
1520+
cache_dir = os.path.join(tmpdir, "cache_dir")
1521+
data_dir = os.path.join(tmpdir, "data_dir")
1522+
os.makedirs(cache_dir)
1523+
os.makedirs(data_dir)
1524+
1525+
# Create a dataset with 50 items, 10 items per chunk
1526+
cache = Cache(str(data_dir), chunk_size=10)
1527+
for i in range(50):
1528+
cache[i] = i
1529+
cache.done()
1530+
cache.merge()
1531+
1532+
# Create dataset with appropriate configuration
1533+
input_dir = f"local:{data_dir}" if local else str(data_dir)
1534+
dataset = StreamingDataset(input_dir, cache_dir=str(cache_dir) if local else None, shuffle=shuffle)
1535+
dataset_length = len(dataset)
1536+
assert dataset_length == 50
1537+
1538+
# ACT & ASSERT - Test iterator mode
1539+
for i, data in enumerate(dataset):
1540+
assert data is not None
1541+
if local and i < dataset_length - 1:
1542+
# With the `local:` prefix, _chunks_queued_for_download should stay enabled until the last item is read
1543+
assert (
1544+
dataset.cache._reader._chunks_queued_for_download is True
1545+
), "_chunks_queued_for_download should be enabled during iteration"
1546+
else:
1547+
assert dataset.cache._reader._chunks_queued_for_download is False, (
1548+
"_chunks_queued_for_download should be disabled when used as local dir without `local:` prefix"
1549+
" or when iteration is done"
1550+
)
1551+
# After iteration, _chunks_queued_for_download should be reset
1552+
assert dataset.cache._reader._chunks_queued_for_download is False
1553+
1554+
# ACT & ASSERT - Test indexed access mode
1555+
for i in range(dataset_length):
1556+
data = dataset[i]
1557+
assert data is not None
1558+
# In indexed access mode, _chunks_queued_for_download should never be enabled
1559+
assert dataset.cache._reader._chunks_queued_for_download is False
1560+
1561+
assert dataset.cache._reader._chunks_queued_for_download is False

0 commit comments

Comments
 (0)