|
20 | 20 | from time import sleep |
21 | 21 | from typing import Any, Dict, Optional |
22 | 22 | from unittest import mock |
| 23 | +from unittest.mock import patch |
23 | 24 |
|
24 | 25 | import numpy as np |
25 | 26 | import pytest |
|
41 | 42 | _replay_sampling, |
42 | 43 | ) |
43 | 44 | from litdata.streaming.item_loader import TokensLoader |
| 45 | +from litdata.streaming.reader import BinaryReader |
44 | 46 | from litdata.streaming.shuffle import FullShuffle, NoShuffle |
45 | 47 | from litdata.utilities import dataset_utilities as dataset_utilities_module |
46 | 48 | from litdata.utilities.dataset_utilities import load_index_file |
@@ -1459,3 +1461,47 @@ def test_dataset_with_mosaic_mds_data(tmpdir): |
1459 | 1461 | assert len(batch["image"]) == 4 |
1460 | 1462 | assert list(batch["class"]) == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3] |
1461 | 1463 | i += 1 |
| 1464 | + |
| 1465 | + |
@pytest.mark.parametrize("shuffle", [True, False])
def test_is_last_index_for_chunked_index_with_dataset(tmpdir, shuffle):
    """Verify ``StreamingDataset`` flags exactly one ``ChunkedIndex`` as the last.

    Builds a 50-item dataset split into chunks of 10, intercepts every call to
    ``BinaryReader.read`` to capture the ``ChunkedIndex`` objects the dataset
    emits, then checks that:

    * ``read`` is called once per item (50 times),
    * the first captured index carries the worker's chunk schedule, and
    * exactly one captured index has ``is_last_index=True`` and it refers to
      the worker's final chunk.
    """
    # Create a dataset with 50 items, 10 items per chunk (=> 5 chunks).
    cache = Cache(str(tmpdir), chunk_size=10)
    for i in range(50):
        cache[i] = i
    cache.done()
    cache.merge()

    # Collect every ChunkedIndex passed to BinaryReader.read.
    captured_indexes = []

    # Keep a reference to the real method so the tracking wrapper can delegate.
    original_read = BinaryReader.read

    def tracking_read(self, index):
        captured_indexes.append(index)
        return original_read(self, index)  # still perform the real read

    # Patch the read method directly on the BinaryReader class.
    with patch("litdata.streaming.reader.BinaryReader.read", tracking_read):
        dataset = StreamingDataset(str(tmpdir), shuffle=shuffle)
        assert len(dataset) == 50

        # Iterating the dataset triggers one BinaryReader.read call per item.
        for _ in dataset:
            pass

    # Ensure BinaryReader.read was called 50 times (once for each item).
    assert len(captured_indexes) == 50, "Expected 50 calls to BinaryReader.read"

    # The first ChunkedIndex carries the chunk schedule assigned to the worker.
    worker_chunks = captured_indexes[0].chunk_indexes
    assert worker_chunks == dataset.worker_chunks, "Expected chunk_indexes to match dataset.worker_chunks"

    # Exactly one index must be flagged as the last one, and it must belong to
    # the worker's final chunk. (Re-asserting `is_last_index` on the filtered
    # list would be a tautology, so it is intentionally omitted.)
    last_indexes = [idx for idx in captured_indexes if idx.is_last_index]
    assert len(last_indexes) == 1, "Expected exactly one index with is_last_index=True"
    assert last_indexes[0].chunk_index == worker_chunks[-1], "Expected to match the last chunk"
0 commit comments