diff --git a/fsspec/caching.py b/fsspec/caching.py index 34231c881..3be267626 100644 --- a/fsspec/caching.py +++ b/fsspec/caching.py @@ -37,6 +37,7 @@ logger = logging.getLogger("fsspec") Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes +MultiFetcher = Callable[[list[tuple[int, int]]], list[bytes]] # Maps [(start, end)] to bytes class BaseCache: @@ -109,6 +110,26 @@ class MMapCache(BaseCache): Ensure there is enough disc space in the temporary location. This cache method might only work on posix + + Parameters + ---------- + blocksize: int + How far to read ahead in numbers of bytes + fetcher: Fetcher + Function of the form f(start, end) which gets bytes from remote as + specified + size: int + How big this file is + location: str + Where to create the temporary file. If None, a temporary file is + created using tempfile.TemporaryFile(). + blocks: set[int] + Set of block numbers that have already been fetched. If None, an empty + set is created. + multi_fetcher: MultiFetcher + Function of the form f([(start, end)]) which gets bytes from remote + as specified. This function is used to fetch multiple blocks at once. + If not specified, the fetcher function is used instead. """ name = "mmap" @@ -120,10 +141,12 @@ def __init__( size: int, location: str | None = None, blocks: set[int] | None = None, + multi_fetcher: MultiFetcher | None = None, ) -> None: super().__init__(blocksize, fetcher, size) self.blocks = set() if blocks is None else blocks self.location = location + self.multi_fetcher = multi_fetcher self.cache = self._makefile() def _makefile(self) -> mmap.mmap | bytearray: @@ -164,6 +187,8 @@ def _fetch(self, start: int | None, end: int | None) -> bytes: # Count the number of blocks already cached self.hit_count += sum(1 for i in block_range if i in self.blocks) + ranges = [] + # Consolidate needed blocks. # Algorithm adapted from Python 2.x itertools documentation. # We are grouping an enumerated sequence of blocks. 
By comparing when the difference @@ -185,13 +210,27 @@ def _fetch(self, start: int | None, end: int | None) -> bytes: logger.debug( f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})" ) - self.cache[sstart:send] = self.fetcher(sstart, send) + ranges.append((sstart, send)) # Update set of cached blocks self.blocks.update(_blocks) # Update cache statistics with number of blocks we had to cache self.miss_count += len(_blocks) + if not ranges: + return self.cache[start:end] + + if self.multi_fetcher: + logger.debug(f"MMap get blocks {ranges}") + for idx, r in enumerate(self.multi_fetcher(ranges)): + (sstart, send) = ranges[idx] + logger.debug(f"MMap copy block ({sstart}-{send})") + self.cache[sstart:send] = r + else: + for sstart, send in ranges: + logger.debug(f"MMap get block ({sstart}-{send})") + self.cache[sstart:send] = self.fetcher(sstart, send) + return self.cache[start:end] def __getstate__(self) -> dict[str, Any]: diff --git a/fsspec/implementations/cached.py b/fsspec/implementations/cached.py index 8b8b83295..bd58ad496 100644 --- a/fsspec/implementations/cached.py +++ b/fsspec/implementations/cached.py @@ -362,7 +362,19 @@ def _open( ) else: detail["blocksize"] = f.blocksize - f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks) + + def _fetch_ranges(ranges): + return self.fs.cat_ranges( + [path] * len(ranges), + [r[0] for r in ranges], + [r[1] for r in ranges], + **kwargs, + ) + + multi_fetcher = None if self.compression else _fetch_ranges + f.cache = MMapCache( + f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher + ) close = f.close f.close = lambda: self.close_and_update(f, close) self.save_cache() diff --git a/fsspec/tests/test_caches.py b/fsspec/tests/test_caches.py index 6176d8001..5299a0824 100644 --- a/fsspec/tests/test_caches.py +++ b/fsspec/tests/test_caches.py @@ -6,6 +6,7 @@ from fsspec.caching import ( BlockCache, FirstChunkCache, + MMapCache, ReadAheadCache, caches, register_cache, @@ -147,6 
+148,10 @@ def letters_fetcher(start, end): return string.ascii_letters[start:end].encode() +def multi_letters_fetcher(ranges): + return [string.ascii_letters[start:end].encode() for start, end in ranges] + + not_parts_caches = {k: v for k, v in caches.items() if k != "parts"} @@ -174,6 +179,38 @@ def test_cache_pickleable(Cache_imp): assert unpickled._fetch(0, 10) == b"0" * 10 +def test_first_cache(): + c = FirstChunkCache(5, letters_fetcher, 52) + assert c.cache is None + assert c._fetch(12, 15) == letters_fetcher(12, 15) + assert c.cache is None + assert c._fetch(3, 10) == letters_fetcher(3, 10) + assert c.cache == letters_fetcher(0, 5) + c.fetcher = None + assert c._fetch(1, 4) == letters_fetcher(1, 4) + + +def test_mmap_cache(mocker): + fetcher = mocker.Mock(wraps=letters_fetcher) + c = MMapCache(5, fetcher, 52) + assert c._fetch(6, 8) == letters_fetcher(6, 8) + assert fetcher.call_count == 1 + assert c._fetch(17, 22) == letters_fetcher(17, 22) + assert fetcher.call_count == 2 + assert c._fetch(1, 38) == letters_fetcher(1, 38) + assert fetcher.call_count == 5 + + multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher) + m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher) + assert m._fetch(6, 8) == letters_fetcher(6, 8) + assert multi_fetcher.call_count == 1 + assert m._fetch(17, 22) == letters_fetcher(17, 22) + assert multi_fetcher.call_count == 2 + assert m._fetch(1, 38) == letters_fetcher(1, 38) + assert multi_fetcher.call_count == 3 + assert fetcher.call_count == 5 + + @pytest.mark.parametrize( "size_requests", [[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],