Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion fsspec/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
logger = logging.getLogger("fsspec")

Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes
MultiFetcher = Callable[list[[int, int]], bytes] # Maps [(start, end)] to bytes

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @monken @martindurant,

Unfortunately, this change breaks the usage of this library in Python 3.8. It's simple to reproduce:

from typing import Callable
X = Callable[list[[int, int]], bytes]

Would it be possible to yank 2025.3.1 (at least while support for 3.8 is still being maintained)? Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will yank, but I 2025.3.1 will henceforth only support >=3.9; it should have probably been done before now.



class BaseCache:
Expand Down Expand Up @@ -109,6 +110,26 @@ class MMapCache(BaseCache):
Ensure there is enough disc space in the temporary location.

This cache method might only work on posix

Parameters
----------
blocksize: int
How far to read ahead in numbers of bytes
fetcher: Fetcher
Function of the form f(start, end) which gets bytes from remote as
specified
size: int
How big this file is
location: str
Where to create the temporary file. If None, a temporary file is
created using tempfile.TemporaryFile().
blocks: set[int]
Set of block numbers that have already been fetched. If None, an empty
set is created.
multi_fetcher: MultiFetcher
Function of the form f([(start, end)]) which gets bytes from remote
as specified. This function is used to fetch multiple blocks at once.
If not specified, the fetcher function is used instead.
"""

name = "mmap"
Expand All @@ -120,10 +141,12 @@ def __init__(
size: int,
location: str | None = None,
blocks: set[int] | None = None,
multi_fetcher: MultiFetcher | None = None,
) -> None:
super().__init__(blocksize, fetcher, size)
self.blocks = set() if blocks is None else blocks
self.location = location
self.multi_fetcher = multi_fetcher
self.cache = self._makefile()

def _makefile(self) -> mmap.mmap | bytearray:
Expand Down Expand Up @@ -164,6 +187,8 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
# Count the number of blocks already cached
self.hit_count += sum(1 for i in block_range if i in self.blocks)

ranges = []

# Consolidate needed blocks.
# Algorithm adapted from Python 2.x itertools documentation.
# We are grouping an enumerated sequence of blocks. By comparing when the difference
Expand All @@ -185,13 +210,27 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
logger.debug(
f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
)
self.cache[sstart:send] = self.fetcher(sstart, send)
ranges.append((sstart, send))

# Update set of cached blocks
self.blocks.update(_blocks)
# Update cache statistics with number of blocks we had to cache
self.miss_count += len(_blocks)

if not ranges:
return self.cache[start:end]

if self.multi_fetcher:
logger.debug(f"MMap get blocks {ranges}")
for idx, r in enumerate(self.multi_fetcher(ranges)):
(sstart, send) = ranges[idx]
logger.debug(f"MMap copy block ({sstart}-{send}")
self.cache[sstart:send] = r
else:
for sstart, send in ranges:
logger.debug(f"MMap get block ({sstart}-{send}")
self.cache[sstart:send] = self.fetcher(sstart, send)

return self.cache[start:end]

def __getstate__(self) -> dict[str, Any]:
Expand Down
14 changes: 13 additions & 1 deletion fsspec/implementations/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,19 @@ def _open(
)
else:
detail["blocksize"] = f.blocksize
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)

def _fetch_ranges(ranges):
return self.fs.cat_ranges(
[path] * len(ranges),
[r[0] for r in ranges],
[r[1] for r in ranges],
**kwargs,
)

multi_fetcher = None if self.compression else _fetch_ranges
f.cache = MMapCache(
f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
)
close = f.close
f.close = lambda: self.close_and_update(f, close)
self.save_cache()
Expand Down
37 changes: 37 additions & 0 deletions fsspec/tests/test_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec.caching import (
BlockCache,
FirstChunkCache,
MMapCache,
ReadAheadCache,
caches,
register_cache,
Expand Down Expand Up @@ -147,6 +148,10 @@ def letters_fetcher(start, end):
return string.ascii_letters[start:end].encode()


def multi_letters_fetcher(ranges):
return [string.ascii_letters[start:end].encode() for start, end in ranges]


not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}


Expand Down Expand Up @@ -174,6 +179,38 @@ def test_cache_pickleable(Cache_imp):
assert unpickled._fetch(0, 10) == b"0" * 10


def test_first_cache():
c = FirstChunkCache(5, letters_fetcher, 52)
assert c.cache is None
assert c._fetch(12, 15) == letters_fetcher(12, 15)
assert c.cache is None
assert c._fetch(3, 10) == letters_fetcher(3, 10)
assert c.cache == letters_fetcher(0, 5)
c.fetcher = None
assert c._fetch(1, 4) == letters_fetcher(1, 4)


def test_mmap_cache(mocker):
fetcher = mocker.Mock(wraps=letters_fetcher)
c = MMapCache(5, fetcher, 52)
assert c._fetch(6, 8) == letters_fetcher(6, 8)
assert fetcher.call_count == 1
assert c._fetch(17, 22) == letters_fetcher(17, 22)
assert fetcher.call_count == 2
assert c._fetch(1, 38) == letters_fetcher(1, 38)
assert fetcher.call_count == 5

multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
assert m._fetch(6, 8) == letters_fetcher(6, 8)
assert multi_fetcher.call_count == 1
assert m._fetch(17, 22) == letters_fetcher(17, 22)
assert multi_fetcher.call_count == 2
assert m._fetch(1, 38) == letters_fetcher(1, 38)
assert multi_fetcher.call_count == 3
assert fetcher.call_count == 5


@pytest.mark.parametrize(
"size_requests",
[[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],
Expand Down
Loading