diff --git a/fsspec/caching.py b/fsspec/caching.py index a3f7a1c9f..bc74ad241 100644 --- a/fsspec/caching.py +++ b/fsspec/caching.py @@ -7,6 +7,8 @@ import os import threading import warnings +from itertools import groupby +from operator import itemgetter from concurrent.futures import Future, ThreadPoolExecutor from typing import ( TYPE_CHECKING, @@ -161,21 +163,39 @@ def _fetch(self, start: int | None, end: int | None) -> bytes: return b"" start_block = start // self.blocksize end_block = end // self.blocksize - need = [i for i in range(start_block, end_block + 1) if i not in self.blocks] - hits = [i for i in range(start_block, end_block + 1) if i in self.blocks] - self.miss_count += len(need) - self.hit_count += len(hits) - while need: - # TODO: not a for loop so we can consolidate blocks later to - # make fewer fetch calls; this could be parallel - i = need.pop(0) - - sstart = i * self.blocksize - send = min(sstart + self.blocksize, self.size) + block_range = range(start_block, end_block + 1) + # Determine which blocks need to be fetched. This sequence is sorted by construction. + need = (i for i in block_range if i not in self.blocks) + # Count the number of blocks already cached + self.hit_count += sum(1 for i in block_range if i in self.blocks) + + # Consolidate needed blocks. + # Algorithm adapted from Python 2.x itertools documentation. + # We are grouping an enumerated sequence of blocks. By tracking the difference + # between an ascending range (provided by enumerate) and the needed block numbers, + # we can detect when the block number skips values. The key computes this difference. + # Whenever the difference changes, we know that we have previously cached block(s), + # and a new group is started. In other words, this algorithm neatly groups + # runs of consecutive block numbers so they can be fetched together. 
+ for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]): + # Extract the blocks from the enumerated sequence + _blocks = tuple(map(itemgetter(1), _blocks)) + # Compute start of first block + sstart = _blocks[0] * self.blocksize + # Compute the end of the last block. Last block may not be full size. + send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size) + + # Fetch bytes (could be multiple consecutive blocks) self.total_requested_bytes += send - sstart - logger.debug(f"MMap get block #{i} ({sstart}-{send})") + logger.debug( + f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})" + ) self.cache[sstart:send] = self.fetcher(sstart, send) - self.blocks.add(i) + + # Update set of cached blocks + self.blocks.update(_blocks) + # Update cache statistics with number of blocks we had to cache + self.miss_count += len(_blocks) return self.cache[start:end]