import os
import shutil
from contextlib import suppress

import pytest
from filelock import FileLock, Timeout

from litdata.constants import _ZSTD_AVAILABLE
from litdata.streaming.cache import Cache
from litdata.streaming.config import ChunkedIndex
from litdata.streaming.downloader import LocalDownloader, register_downloader, unregister_downloader
from litdata.streaming.reader import BinaryReader
from litdata.streaming.resolver import Dir


class LocalDownloaderNoLockCleanup(LocalDownloader):
    """A LocalDownloader variant that does NOT remove the `.lock` file after download.

    This simulates the behavior of non-local downloaders, where the lock file persists
    on disk until the Reader's cleanup runs. Used to verify the centralized lock
    cleanup; a sketch of that idea follows this class.
    """

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:  # type: ignore[override]
        # Strip the custom scheme used for testing to map to the local FS
        if remote_filepath.startswith("s3+local://"):
            remote_filepath = remote_filepath.replace("s3+local://", "")
        if not os.path.exists(remote_filepath):
            raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")

        with (
            suppress(Timeout, FileNotFoundError),
            FileLock(local_filepath + ".lock", timeout=0),
        ):
            if remote_filepath == local_filepath or os.path.exists(local_filepath):
                return
            temp_file_path = local_filepath + ".tmp"
            shutil.copy(remote_filepath, temp_file_path)
            os.rename(temp_file_path, local_filepath)
        # Intentionally do NOT remove `local_filepath + ".lock"` here
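

# A minimal sketch of the centralized lock cleanup the test below relies on:
# stale `.lock` files left beside downloaded chunks are swept from the cache
# directory. NOTE: `_cleanup_stale_locks` is a hypothetical helper shown for
# illustration only; the real cleanup lives inside litdata's reader and may
# differ in detail.
def _cleanup_stale_locks(cache_dir: str) -> None:
    for name in os.listdir(cache_dir):
        if name.endswith(".lock"):
            # A lock may vanish between listing and removal; ignore that race.
            with suppress(FileNotFoundError):
                os.remove(os.path.join(cache_dir, name))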


@pytest.mark.skipif(not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")
def test_reader_lock_cleanup_with_nonlocal_like_downloader(tmpdir):
    cache_dir = os.path.join(tmpdir, "cache_dir")
    remote_dir = os.path.join(tmpdir, "remote_dir")
    os.makedirs(cache_dir, exist_ok=True)

    # Build a small compressed dataset
    cache = Cache(input_dir=Dir(path=cache_dir, url=None), chunk_size=3, compression="zstd")
    for i in range(10):
        cache[i] = i
    cache.done()
    cache.merge()

    # Copy it to a "remote" directory
    shutil.copytree(cache_dir, remote_dir)

    # Use a custom scheme that we register to our test downloader
    prefix = "s3+local://"
    remote_url = prefix + remote_dir

    # Register the downloader and ensure we unregister afterwards
    register_downloader(prefix, LocalDownloaderNoLockCleanup, overwrite=True)
    try:
        # Start from a fresh cache dir for reading
        shutil.rmtree(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)

        reader = BinaryReader(cache_dir=cache_dir, remote_input_dir=remote_url, compression="zstd", max_cache_size=1)

        # Iterate across enough samples to trigger multiple chunk downloads and deletions
        for i in range(10):
            idx = reader._get_chunk_index_from_index(i)
            chunk_idx = ChunkedIndex(index=idx[0], chunk_index=idx[1], is_last_index=(i == 9))
            reader.read(chunk_idx)

        # At the end, no chunk-related lock files should remain
        leftover_locks = [f for f in os.listdir(cache_dir) if f.endswith(".lock") and f.startswith("chunk-")]
        assert leftover_locks == []
    finally:
        unregister_downloader(prefix)
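

# To run only this test from the repository root (selection by name, so no
# file path is needed):
#   pytest -k test_reader_lock_cleanup_with_nonlocal_like_downloader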