diff --git a/cloudbuild/run_tests.sh b/cloudbuild/run_tests.sh index a970015d..505b00ee 100644 --- a/cloudbuild/run_tests.sh +++ b/cloudbuild/run_tests.sh @@ -117,6 +117,19 @@ case "$TEST_SUITE" in "--deselect=gcsfs/tests/test_core.py::test_mv_file_cache" ) + # The prefetcher engine is not integrated for zonal in this bucket. + # It will be integrated in a separate PR, after which this will be removed. + ZONAL_DESELECTS+=( + "--deselect=gcsfs/tests/test_core.py::test_cat_file_routing_and_thresholds" + "--deselect=gcsfs/tests/test_core.py::test_cat_file_concurrent_data_integrity" + "--deselect=gcsfs/tests/test_core.py::test_cat_file_concurrent_exception_cancellation" + "--deselect=gcsfs/tests/test_core.py::test_gcsfile_prefetch_disabled_fallback" + "--deselect=gcsfs/tests/test_core.py::test_gcsfile_prefetch_sequential_integrity" + "--deselect=gcsfs/tests/test_core.py::test_gcsfile_prefetch_random_seek_integrity" + "--deselect=gcsfs/tests/test_core.py::test_gcsfile_multithreaded_read_integrity" + "--deselect=gcsfs/tests/test_core.py::test_gcsfile_not_satisfiable_range" + ) + pytest "${ARGS[@]}" "${ZONAL_DESELECTS[@]}" gcsfs/tests/test_core.py ;; esac diff --git a/gcsfs/core.py b/gcsfs/core.py index 916437e4..5ed147b2 100644 --- a/gcsfs/core.py +++ b/gcsfs/core.py @@ -30,6 +30,7 @@ from .credentials import GoogleCredentials from .inventory_report import InventoryReport from .retry import errs, retry_request, validate_response +from .zb_hns_utils import DEFAULT_CONCURRENCY logger = logging.getLogger("gcsfs") @@ -299,6 +300,7 @@ class GCSFileSystem(asyn.AsyncFileSystem): default_block_size = DEFAULT_BLOCK_SIZE protocol = "gs", "gcs" async_impl = True + MIN_CHUNK_SIZE_FOR_CONCURRENCY = 5 * 1024 * 1024 def __init__( self, @@ -1166,22 +1168,75 @@ def url(self, path): f"&generation={generation}" if generation else "", ) - async def _cat_file(self, path, start=None, end=None, **kwargs): + async def _cat_file_sequential(self, path, start=None, end=None, **kwargs): """Simple one-shot get of file data""" # if start and end are both provided and valid, but start >= end, return empty bytes # Otherwise, _process_limits would generate an invalid HTTP range (e.g. "bytes=5-4" # for start=5, end=5), causing the server to return the whole file instead of nothing. if start is not None and end is not None and start >= end >= 0: return b"" + u2 = self.url(path) - # 'if start or end' fails when start=0 or end=0 because 0 is Falsey. if start is not None or end is not None: head = {"Range": await self._process_limits(path, start, end)} else: head = {} + headers, out = await self._call("GET", u2, headers=head) return out + async def _cat_file_concurrent( + self, path, start=None, end=None, concurrency=DEFAULT_CONCURRENCY, **kwargs + ): + """Concurrent fetch of file data""" + if start is None: + start = 0 + if end is None: + end = (await self._info(path))["size"] + if start >= end: + return b"" + + if concurrency <= 1 or end - start < self.MIN_CHUNK_SIZE_FOR_CONCURRENCY: + return await self._cat_file_sequential(path, start=start, end=end, **kwargs) + + total_size = end - start + part_size = total_size // concurrency + tasks = [] + + for i in range(concurrency): + offset = start + (i * part_size) + actual_size = ( + part_size if i < concurrency - 1 else total_size - (i * part_size) + ) + tasks.append( + asyncio.create_task( + self._cat_file_sequential( + path, start=offset, end=offset + actual_size, **kwargs + ) + ) + ) + + try: + results = await asyncio.gather(*tasks) + return b"".join(results) + except BaseException as e: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise e + + async def _cat_file( + self, path, start=None, end=None, concurrency=DEFAULT_CONCURRENCY, **kwargs + ): + """Simple one-shot, or concurrent get of file data""" + if concurrency > 1: + return await self._cat_file_concurrent( + path, start=start, end=end, concurrency=concurrency, **kwargs + ) + + return await self._cat_file_sequential(path, start=start, end=end, **kwargs) + async def _getxattr(self, path, attr): """Get user-defined metadata attribute""" meta = (await self._info(path)).get("metadata", {}) @@ -2020,6 +2075,38 @@ def __init__( self.acl = acl self.consistency = consistency self.checker = get_consistency_checker(consistency) + + # Ideally, all of these fields should be part of `cache_options`. Because current + # `fsspec` caches do not accept arbitrary `*args` and `**kwargs`, passing them + # there currently causes instantiation errors. We are holding off on introducing + # them as explicit keyword arguments to ensure existing user workloads are not + # disrupted. This will be refactored once the upstream `fsspec` changes are merged. + use_prefetch_reader = kwargs.get( + "use_experimental_adaptive_prefetching", False + ) or os.environ.get( + "use_experimental_adaptive_prefetching", "false" + ).lower() in ( + "true", + "1", + "t", + "y", + "yes", + ) + self.concurrency = kwargs.get("concurrency", DEFAULT_CONCURRENCY) + + if use_prefetch_reader: + max_prefetch_size = kwargs.get("max_prefetch_size", None) + from .prefetcher import BackgroundPrefetcher + + self._prefetch_engine = BackgroundPrefetcher( + self._async_fetch_range, + self.size, + max_prefetch_size=max_prefetch_size, + concurrency=self.concurrency, + ) + else: + self._prefetch_engine = None + # _supports_append is an internal argument not meant to be used directly. # If True, allows opening file in append mode. This is generally not supported # by GCS, but may be supported by subclasses (e.g. ZonalFile). This flag should @@ -2202,12 +2289,30 @@ def _fetch_range(self, start=None, end=None): if not both None, fetch only given range """ try: - return self.gcsfs.cat_file(self.path, start=start, end=end) + if self._prefetch_engine: + return self._prefetch_engine._fetch(start=start, end=end) + return self.gcsfs.cat_file( + self.path, start=start, end=end, concurrency=self.concurrency + ) except RuntimeError as e: if "not satisfiable" in str(e): return b"" raise + async def _async_fetch_range(self, start_offset, total_size, split_factor=1): + """Async fetcher mapped to the Prefetcher engine for regional buckets.""" + return await self.gcsfs._cat_file_concurrent( + self.path, + start=start_offset, + end=start_offset + total_size, + concurrency=split_factor, + ) + + def close(self): + if hasattr(self, "_prefetch_engine") and self._prefetch_engine: + self._prefetch_engine.close() + super().close() + def _convert_fixed_key_metadata(metadata, *, from_google=False): """ diff --git a/gcsfs/prefetcher.py b/gcsfs/prefetcher.py new file mode 100644 index 00000000..74617d96 --- /dev/null +++ b/gcsfs/prefetcher.py @@ -0,0 +1,676 @@ +import asyncio +import ctypes +import logging +import threading +from collections import deque + +import fsspec.asyn + +logger = logging.getLogger(__name__) + +PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize +PyBytes_FromStringAndSize.restype = ctypes.py_object +PyBytes_FromStringAndSize.argtypes = [ctypes.c_void_p, ctypes.c_ssize_t] + +PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString +PyBytes_AsString.restype = ctypes.c_void_p +PyBytes_AsString.argtypes = [ctypes.py_object] + + +def _fast_slice(src_bytes, offset, read_size): + if read_size == 0: + return b"" + dest_bytes = PyBytes_FromStringAndSize(None, read_size) + src_ptr = PyBytes_AsString(src_bytes) + dest_ptr = PyBytes_AsString(dest_bytes) + + ctypes.memmove(dest_ptr, src_ptr + offset, read_size) + return dest_bytes + + +class RunningAverageTracker: + """Tracks a running average of values over a sliding window. + + This is used to monitor read sizes and adaptively scale the + prefetching strategy based on recent user behavior. + """ + + def __init__(self, maxlen=10): + """Initializes the tracker with a specific window size. + + Args: + maxlen (int): The maximum number of historical values to keep. + """ + logger.debug("Initializing RunningAverageTracker with maxlen: %d", maxlen) + self._history = deque(maxlen=maxlen) + self._sum = 0 + + def add(self, value: int): + """Adds a new value to the sliding window and updates the rolling sum. + + Args: + value (int): The integer value to add to the history. + """ + if value <= 0: + return + if len(self._history) == self._history.maxlen: + self._sum -= self._history[0] + + self._history.append(value) + self._sum += value + logger.debug( + "RunningAverageTracker added value: %d, new sum: %d", value, self._sum + ) + + @property + def average(self) -> int: + """Calculates and returns the current running average. + + Returns: + int: The integer average of the current history. + """ + count = len(self._history) + if count == 0: + return 1024 * 1024 # 1MB + return self._sum // count + + def clear(self): + """Clears the history and resets the sum to zero.""" + logger.debug("Clearing RunningAverageTracker history.") + self._history.clear() + self._sum = 0 + + +class PrefetchProducer: + """Background worker that fetches sequential blocks of data. + + This class handles the network requests. It spawns asynchronous tasks + to fetch data ahead of the user's current reading position and + places those task promises into a queue for the consumer. + """ + + # If the request is too small, and prefetch window is expanded till 5MB + # we then make request in 5MB blocks. + MIN_CHUNK_SIZE = 5 * 1024 * 1024 + + # If user doesn't specify any max_prefetch_size, the prefetcher defaults + # to maximum of 2 * io_size and 128MB + MIN_PREFETCH_SIZE = 128 * 1024 * 1024 + + def __init__( + self, + fetcher, + size: int, + concurrency: int, + queue: asyncio.Queue, + wakeup_event: asyncio.Event, + get_user_offset, + get_io_size, + get_sequential_streak, + on_error, + user_max_prefetch_size=None, + ): + """Initializes the background producer. + + Args: + fetcher (Callable): A coroutine function to fetch bytes from a remote source. + size (int): Total size of the file being fetched. + concurrency (int): Maximum number of concurrent fetch tasks. + queue (asyncio.Queue): The shared queue to push download tasks into. + wakeup_event (asyncio.Event): Event used to wake the producer from an idle state. + get_user_offset (Callable): Function returning the user's current read offset. + get_io_size (Callable): Function returning the adaptive IO size. + get_sequential_streak (Callable): Function returning the current sequential read streak. + on_error (Callable): Callback triggered when a background error occurs. + user_max_prefetch_size (int, optional): A hard limit for prefetch size overrides. + """ + logger.debug( + "Initializing PrefetchProducer: size=%d, concurrency=%d, user_max_prefetch_size=%s", + size, + concurrency, + user_max_prefetch_size, + ) + self.fetcher = fetcher + self.size = size + self.concurrency = concurrency + self.queue = queue + self.wakeup_event = wakeup_event + + self.get_user_offset = get_user_offset + self.get_io_size = get_io_size + self.get_sequential_streak = get_sequential_streak + self.on_error = on_error + self._user_max_prefetch_size = user_max_prefetch_size + + self.current_offset = 0 + self.is_stopped = False + self._active_tasks = set() + self._producer_task = None + + @property + def max_prefetch_size(self) -> int: + """Calculates the maximum prefetch size based on user intent or io size. + + Returns: + int: The maximum number of bytes to prefetch ahead. + """ + if self._user_max_prefetch_size is not None: + return self._user_max_prefetch_size + return max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE) + + def start(self): + """Starts the background producer loop. + + This clears any previous wakeup events and spawns the main loop task. + """ + logger.info("Starting PrefetchProducer loop.") + self.is_stopped = False + self.wakeup_event.clear() + self._producer_task = asyncio.create_task(self._loop()) + + async def stop(self): + """Cancels all active fetch tasks and shuts down the producer loop. + + This method ensures the queue is flushed and waits for cancelled + tasks to finish cleaning up. + """ + logger.info( + "Stopping PrefetchProducer. Active fetch tasks: %d", len(self._active_tasks) + ) + self.is_stopped = True + self.wakeup_event.set() + + tasks_to_wait = [] + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + tasks_to_wait.append(self._producer_task) + + for task in list(self._active_tasks): + if not task.done(): + tasks_to_wait.append(task) + self._active_tasks.clear() + + # Clear out any leftover items in the queue + cleared_items = 0 + while not self.queue.empty(): + try: + item = self.queue.get_nowait() + if ( + isinstance(item, asyncio.Task) + and item.done() + and not item.cancelled() + ): + item.exception() + cleared_items += 1 + except asyncio.QueueEmpty: + break + + if cleared_items > 0: + logger.debug( + "Cleared %d leftover items from the queue during stop.", cleared_items + ) + + if tasks_to_wait: + logger.debug( + "Waiting for %d cancelled tasks to finish their teardown.", + len(tasks_to_wait), + ) + await asyncio.gather(*tasks_to_wait, return_exceptions=True) + + async def restart(self, new_offset: int): + """Stops current tasks and restarts the background loop at a new byte offset. + + Args: + new_offset (int): The new byte position to start prefetching from. + """ + logger.info("Restarting PrefetchProducer at new offset: %d", new_offset) + await self.stop() + self.current_offset = new_offset + self.start() + + async def _loop(self): + """The main background loop that calculates sizes and spawns fetch tasks.""" + logger.debug("PrefetchProducer internal loop is now running.") + try: + while not self.is_stopped: + await self.wakeup_event.wait() + self.wakeup_event.clear() + + if self.is_stopped: + break + + io_size = self.get_io_size() + streak = self.get_sequential_streak() + prefetch_size = min((streak + 1) * io_size, self.max_prefetch_size) + + logger.debug( + "Producer awake. Current offset: %d, User offset: %d, Prefetch size: %d", + self.current_offset, + self.get_user_offset(), + prefetch_size, + ) + + while ( + not self.is_stopped + and (self.current_offset - self.get_user_offset()) < prefetch_size + and self.current_offset < self.size + ): + user_offset = self.get_user_offset() + space_remaining = self.size - self.current_offset + prefetch_space_available = prefetch_size - ( + self.current_offset - user_offset + ) + + if ( + space_remaining >= io_size + and prefetch_space_available < io_size + ): + break + + if prefetch_size >= self.MIN_CHUNK_SIZE: + if prefetch_space_available >= self.MIN_CHUNK_SIZE: + actual_size = min( + max(self.MIN_CHUNK_SIZE, io_size), space_remaining + ) + else: + break + else: + actual_size = min(io_size, space_remaining) + + if streak < 2: + sfactor = self.concurrency + else: + sfactor = min( + self.concurrency, + max( + 1, + actual_size * self.concurrency // (prefetch_size or 1), + ), + ) + + logger.debug( + "Spawning fetch task. Offset: %d, Size: %d, Split Factor: %d", + self.current_offset, + actual_size, + sfactor, + ) + + download_task = asyncio.create_task( + self.fetcher( + self.current_offset, actual_size, split_factor=sfactor + ) + ) + self._active_tasks.add(download_task) + download_task.add_done_callback(self._active_tasks.discard) + + await self.queue.put(download_task) + self.current_offset += actual_size + + except asyncio.CancelledError: + logger.debug("PrefetchProducer loop was cancelled.") + pass + except Exception as e: + logger.error( + "PrefetchProducer loop encountered an unexpected error: %s", + e, + exc_info=True, + ) + self.is_stopped = True + self.on_error(e) + await self.queue.put(e) + + +class PrefetchConsumer: + """Consumes prefetched chunks from the queue and manages byte slicing. + + This class pulls data out of the shared queue and slices it into the + exact byte sizes requested by the user. It also manages the local block buffer. + """ + + def __init__( + self, + queue: asyncio.Queue, + wakeup_event: asyncio.Event, + is_producer_stopped, + on_error, + ): + """Initializes the consumer. + + Args: + queue (asyncio.Queue): The shared queue containing fetch tasks. + wakeup_event (asyncio.Event): Event used to wake the producer when more data is needed. + is_producer_stopped (Callable): Function returning whether the producer has been halted. + on_error (Callable): Callback triggered when a fetch error is encountered. + """ + logger.debug("Initializing PrefetchConsumer.") + self.queue = queue + self.wakeup_event = wakeup_event + self.is_producer_stopped = is_producer_stopped + self.on_error = on_error + self.sequential_streak = 0 + self.offset = 0 + self._current_block = b"" + self._current_block_idx = 0 + + def seek(self, new_offset: int): + """Clears the buffer and resets the internal offset for a hard seek. + + Args: + new_offset (int): The byte position the consumer is jumping to. + """ + logger.info( + "Consumer executing hard seek to offset %d. Clearing internal buffer.", + new_offset, + ) + self.offset = new_offset + self.sequential_streak = 0 + self._current_block = b"" + self._current_block_idx = 0 + + def clear_buffer(self): + """Discards the local byte buffer. Useful during shutdown or resets.""" + logger.debug("Consumer local block buffer cleared.") + self._current_block = b"" + self._current_block_idx = 0 + + async def _advance(self, size: int, save_data: bool) -> list[bytes]: + """Internal method to advance the offset and optionally extract data. + + Handles queue exhaustion, producer wakeups, and streak tracking. + """ + if size <= 0: + return [] + + chunks = [] + processed = 0 + + while processed < size: + available = len(self._current_block) - self._current_block_idx + + if not available: + if self.is_producer_stopped() and self.queue.empty(): + logger.debug("Consumer reached EOF.") + break + + if self.queue.empty(): + logger.debug("Queue is empty. Waking up producer.") + self.wakeup_event.set() + + task = await self.queue.get() + + if isinstance(task, Exception): + logger.error("Consumer retrieved an exception: %s", task) + self.on_error(task) + raise task + + try: + block = await task + + self.sequential_streak += 1 + if self.sequential_streak >= 2: + self.wakeup_event.set() + + self._current_block = block + self._current_block_idx = 0 + available = len(self._current_block) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("Consumer caught an error: %s", e, exc_info=True) + self.on_error(e) + raise e + + if not self._current_block: + break + + needed = size - processed + take = min(needed, available) + + if save_data: + if take == len(self._current_block) and self._current_block_idx == 0: + chunk = self._current_block + else: + # Native Python slicing was GIL bound in my experiments. + chunk = await asyncio.to_thread( + _fast_slice, self._current_block, self._current_block_idx, take + ) + chunks.append(chunk) + + self._current_block_idx += take + processed += take + self.offset += take + + return chunks + + async def consume(self, size: int) -> bytes: + """Pulls exactly 'size' bytes from the local block or the task queue. + + If the local block is exhausted, this will wait on the queue for the next + available chunk of data. + + Args: + size (int): The exact number of bytes to retrieve. + + Returns: + bytes: The requested bytes. This may be shorter than 'size' if EOF is reached. + + Raises: + Exception: Re-raises any exceptions encountered by the producer fetch tasks. + """ + if size <= 0: + return b"" + + chunks = await self._advance(size, save_data=True) + + if not chunks: + return b"" + + if len(chunks) == 1: + return chunks[0] + + return b"".join(chunks) + + async def skip(self, size: int) -> None: + """Advances the consumer offset without allocating memory.""" + await self._advance(size, save_data=False) + + +class BackgroundPrefetcher: + """Orchestrator that manages reading behavior and coordinates background work. + + This acts as the main public interface for the file reader. It tracks the + user's reading history, routes seek operations, and links the producer's + network tasks with the consumer's data slicing logic. + """ + + def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None): + """Initializes the background prefetcher. + + Args: + fetcher (Callable): A coroutine of the form `f(start, end)` which gets bytes from the remote. + size (int): Total byte size of the file being read. + concurrency (int): Number of concurrent network requests to use for large chunks. + max_prefetch_size (int, optional): Maximum bytes to prefetch ahead of the current user offset. + + Raises: + ValueError: If max_prefetch_size is provided but is not a positive integer. + """ + logger.info( + "Starting BackgroundPrefetcher. Size: %d, Concurrency: %d, Max Prefetch: %s", + size, + concurrency, + max_prefetch_size, + ) + self.size = size + self.concurrency = concurrency + + if max_prefetch_size is not None and max_prefetch_size <= 0: + logger.error("Invalid max_prefetch_size provided: %s", max_prefetch_size) + raise ValueError( + "max_prefetch_size should be a positive integer to use adaptive prefetching!" + ) + + self.loop = fsspec.asyn.get_loop() + self._lock = threading.Lock() + self._error = None + self.is_stopped = False + self.queue = asyncio.Queue() + self.wakeup_event = asyncio.Event() + self.user_offset = 0 + self.read_tracker = RunningAverageTracker(maxlen=10) + + self.consumer = PrefetchConsumer( + queue=self.queue, + wakeup_event=self.wakeup_event, + is_producer_stopped=self._is_producer_stopped, + on_error=self._set_error, + ) + + self.producer = PrefetchProducer( + fetcher=fetcher, + size=self.size, + concurrency=self.concurrency, + queue=self.queue, + wakeup_event=self.wakeup_event, + get_user_offset=lambda: self.consumer.offset, + get_io_size=self._get_adaptive_io_size, + get_sequential_streak=lambda: self.consumer.sequential_streak, + on_error=self._set_error, + user_max_prefetch_size=max_prefetch_size, + ) + + async def _start(): + self.producer.start() + + fsspec.asyn.sync(self.loop, _start) + logger.debug("BackgroundPrefetcher initialization complete.") + + def __enter__(self): + """Context manager entry point.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit point. Ensures the prefetcher is cleanly closed.""" + self.close() + + def _get_adaptive_io_size(self) -> int: + return self.read_tracker.average + + def _is_producer_stopped(self) -> bool: + return self.producer.is_stopped if hasattr(self, "producer") else True + + def _set_error(self, e: Exception): + logger.error("Global error state set in BackgroundPrefetcher: %s", e) + self._error = e + + async def _restart_producer(self, new_offset: int): + logger.debug( + "Handling seek request. Restarting producer at offset: %d", new_offset + ) + self._error = None + await self.producer.restart(new_offset) + self.consumer.seek(new_offset) + self.read_tracker.clear() + + async def _async_fetch(self, start, end): + logger.debug("Executing _async_fetch for range %d - %d.", start, end) + + if start != self.user_offset: + if self.user_offset < start <= self.producer.current_offset: + logger.info( + "Soft seek detected. Skipping ahead from %d to %d.", + self.user_offset, + start, + ) + skip_amount = start - self.user_offset + await self.consumer.skip(skip_amount) + self.user_offset = start + else: + logger.info( + "Hard seek detected. Moving user offset from %d to %d.", + self.user_offset, + start, + ) + self.user_offset = start + await self._restart_producer(start) + + requested_size = end - start + self.read_tracker.add(requested_size) + + chunk = await self.consumer.consume(requested_size) + self.user_offset += len(chunk) + + logger.debug("Completed _async_fetch. Returned %d bytes.", len(chunk)) + return chunk + + def _fetch(self, start: int | None, end: int | None) -> bytes: + if start is None: + start = 0 + if end is None: + end = self.size + + end = min(end, self.size) + logger.debug( + "Synchronous _fetch called for bounds start=%s, end=%s.", start, end + ) + + if start >= self.size or start >= end: + logger.warning( + "Invalid bounds or EOF reached in _fetch. Start: %d, End: %d, Size: %d", + start, + end, + self.size, + ) + return b"" + + with self._lock: + if self._error: + logger.error("Cannot fetch data: instance has an active error state.") + raise self._error + + if self.is_stopped: + logger.error( + "Cannot fetch data: BackgroundPrefetcher is stopped or closed." + ) + raise RuntimeError( + "The file instance has been closed. This can occur if a close operation " + "is executed concurrently while a read operation is still in progress." + ) + + try: + result = fsspec.asyn.sync(self.loop, self._async_fetch, start, end) + except Exception as e: + logger.error( + "Exception raised during synchronous fetch: %s", e, exc_info=True + ) + self.is_stopped = True + self._error = e + raise + + if self.is_stopped: + logger.error("Instance was stopped mid-fetch operation.") + raise RuntimeError( + "The file instance has been closed. This can occur if a close operation " + "is executed concurrently while a read operation is still in progress." + ) + + return result + + def close(self): + """Safely shuts down the prefetcher. + + This cancels all background network tasks and blocks until everything + is completely cleaned up. It also clears the internal consumer buffer. + """ + logger.info("Closing BackgroundPrefetcher and cleaning up resources.") + if self.is_stopped: + logger.debug( + "BackgroundPrefetcher is already stopped. Skipping close operation." + ) + return + + self.is_stopped = True + with self._lock: + fsspec.asyn.sync(self.loop, self.producer.stop) + self.consumer.clear_buffer() + logger.info("BackgroundPrefetcher closed successfully.") diff --git a/gcsfs/tests/test_core.py b/gcsfs/tests/test_core.py index 285ee03f..99387e8d 100644 --- a/gcsfs/tests/test_core.py +++ b/gcsfs/tests/test_core.py @@ -1,3 +1,4 @@ +import concurrent.futures import io import os import uuid @@ -8,6 +9,7 @@ from urllib.parse import parse_qs, unquote, urlparse from uuid import uuid4 +import fsspec.asyn import fsspec.core import pytest import requests @@ -1921,3 +1923,178 @@ def test_mv_file_raises_error_for_specific_generation(gcs): gcs.mv_file(src, dest) finally: gcs.version_aware = original_version_aware + + +def test_cat_file_routing_and_thresholds(gcs): + fn = f"{TEST_BUCKET}/core_routing.txt" + # Create an 8MB file + data = os.urandom(8 * 1024 * 1024) + gcs.pipe(fn, data) + + # 1. Concurrency = 1 (Should route to sequential) + with mock.patch.object( + gcs, "_cat_file_sequential", wraps=gcs._cat_file_sequential + ) as mock_seq: + with mock.patch.object( + gcs, "_cat_file_concurrent", wraps=gcs._cat_file_concurrent + ) as mock_conc: + res = fsspec.asyn.sync( + gcs.loop, gcs._cat_file, fn, start=0, end=1024, concurrency=1 + ) + assert res == data[:1024] + assert mock_seq.call_count == 1 + assert mock_conc.call_count == 0 + + # 2. Concurrency = 4, but read size (1MB) is < MIN_CHUNK_SIZE_FOR_CONCURRENCY (5MB) + with mock.patch.object( + gcs, "_cat_file_sequential", wraps=gcs._cat_file_sequential + ) as mock_seq: + with mock.patch.object( + gcs, "_cat_file_concurrent", wraps=gcs._cat_file_concurrent + ) as mock_conc: + res = fsspec.asyn.sync( + gcs.loop, gcs._cat_file, fn, start=0, end=1024 * 1024, concurrency=4 + ) + assert res == data[: 1024 * 1024] + # It hits the concurrent wrapper, but bails out to sequential internally + assert mock_conc.call_count == 1 + assert mock_seq.call_count == 1 + + # 3. Concurrency = 4, and read size (8MB) >= MIN_CHUNK_SIZE_FOR_CONCURRENCY (5MB) + with mock.patch.object( + gcs, "_cat_file_sequential", wraps=gcs._cat_file_sequential + ) as mock_seq: + res = fsspec.asyn.sync( + gcs.loop, gcs._cat_file, fn, start=0, end=8 * 1024 * 1024, concurrency=4 + ) + assert res == data + # Should call sequential 4 times (once for each concurrent chunk) + assert mock_seq.call_count == 4 + + +def test_cat_file_concurrent_data_integrity(gcs): + fn = f"{TEST_BUCKET}/core_integrity.txt" + file_size = 20 * 1024 * 1024 # 20MB + data = os.urandom(file_size) + gcs.pipe(fn, data) + + res = fsspec.asyn.sync( + gcs.loop, gcs._cat_file_concurrent, fn, start=0, end=file_size, concurrency=7 + ) + assert len(res) == file_size + assert res == data + + +def test_cat_file_concurrent_exception_cancellation(gcs): + fn = f"{TEST_BUCKET}/core_exception.txt" + data = b"0123456789" * 6000000 # ~6MB + gcs.pipe(fn, data) + + original_seq = gcs._cat_file_sequential + + async def mock_fail_seq(path, start, end, **kwargs): + if start > 0: # Force failure on the 2nd chunk + raise OSError("Simulated HTTP Timeout") + return await original_seq(path, start, end, **kwargs) + + with mock.patch.object(gcs, "_cat_file_sequential", side_effect=mock_fail_seq): + with pytest.raises(OSError, match="Simulated HTTP Timeout"): + fsspec.asyn.sync( + gcs.loop, + gcs._cat_file_concurrent, + fn, + start=0, + end=len(data), + concurrency=4, + ) + + +def test_gcsfile_prefetch_disabled_fallback(gcs): + """Verify that omitting the flag entirely skips the prefetcher initialization.""" + fn = f"{TEST_BUCKET}/no_prefetch.txt" + gcs.pipe(fn, b"HelloWorld") + + with gcs.open(fn, "rb", use_experimental_adaptive_prefetching=False) as f: + assert getattr(f, "_prefetch_engine", None) is None + assert f.read() == b"HelloWorld" + + +def test_gcsfile_prefetch_sequential_integrity(gcs): + fn = f"{TEST_BUCKET}/integrated_seq.txt" + file_size = 10 * 1024 * 1024 + data = os.urandom(file_size) + gcs.pipe(fn, data) + + with gcs.open( + fn, "rb", use_experimental_adaptive_prefetching=True, block_size=2 * 1024 * 1024 + ) as f: + assert f._prefetch_engine is not None + + chunks = [] + while True: + chunk = f.read(1024 * 1024) # Read 1MB at a time + if not chunk: + break + chunks.append(chunk) + + assert b"".join(chunks) == data + + +def test_gcsfile_prefetch_random_seek_integrity(gcs): + fn = f"{TEST_BUCKET}/integrated_random.txt" + file_size = 5 * 1024 * 1024 + data = os.urandom(file_size) + gcs.pipe(fn, data) + + import random + + random.seed(42) + + with gcs.open( + fn, "rb", use_experimental_adaptive_prefetching=True, block_size=1024 * 1024 + ) as f: + for _ in range(50): + start = random.randint(0, file_size - 1000) + length = random.randint(1, 1000) + + f.seek(start) + chunk = f.read(length) + + assert len(chunk) == length + assert chunk == data[start : start + length] + + +def test_gcsfile_multithreaded_read_integrity(gcs): + fn = f"{TEST_BUCKET}/integrated_mt.txt" + file_size = 15 * 1024 * 1024 + data = os.urandom(file_size) + gcs.pipe(fn, data) + + with gcs.open( + fn, "rb", use_experimental_adaptive_prefetching=True, block_size=2 * 1024 * 1024 + ) as f: + + def thread_worker(start, size): + return f._fetch_range(start, start + size) + + chunk_size = 3 * 1024 * 1024 + futures = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + for i in range(5): + start_offset = i * chunk_size + futures.append(executor.submit(thread_worker, start_offset, chunk_size)) + + results = [fut.result() for fut in futures] + stitched_data = b"".join(results) + + assert len(stitched_data) == file_size + assert stitched_data == data + + +def test_gcsfile_not_satisfiable_range(gcs): + fn = f"{TEST_BUCKET}/integrated_eof.txt" + gcs.pipe(fn, b"12345") + + with gcs.open(fn, "rb", use_experimental_adaptive_prefetching=True) as f: + res = f._fetch_range(100, 200) + assert res == b"" diff --git a/gcsfs/tests/test_prefetcher.py b/gcsfs/tests/test_prefetcher.py new file mode 100644 index 00000000..48aa7f6e --- /dev/null +++ b/gcsfs/tests/test_prefetcher.py @@ -0,0 +1,555 @@ +import asyncio +from unittest import mock + +import fsspec.asyn +import pytest + +from gcsfs.prefetcher import BackgroundPrefetcher, RunningAverageTracker, _fast_slice + + +class MockFetcher: + def __init__(self, data, fail_at_call=None, hang_at_call=None): + self.data = data + self.calls = [] + self.fail_at_call = fail_at_call + self.hang_at_call = hang_at_call + self.call_count = 0 + + async def __call__(self, start, size, split_factor=1): + self.call_count += 1 + self.calls.append({"start": start, "size": size, "split_factor": split_factor}) + + await asyncio.sleep(0.001) + + if self.hang_at_call is not None and self.call_count >= self.hang_at_call: + await asyncio.sleep(1000) + + if self.fail_at_call is not None and self.call_count >= self.fail_at_call: + raise OSError("Simulated Network Timeout") + + return self.data[start : start + size] + + +def test_fast_slice_direct(): + src = b"0123456789" + assert _fast_slice(src, 2, 4) == b"2345" + assert _fast_slice(src, 5, 0) == b"" + assert _fast_slice(src, 0, 10) == b"0123456789" + + +def test_running_average_tracker(): + tracker = RunningAverageTracker(maxlen=3) + assert tracker.average == 1024 * 1024 # Default 1MB fallback + + tracker.add(512) + tracker.add(512) + assert tracker.average == 512 + + tracker.add(2048) + assert tracker.average == 1024 # (512 + 512 + 2048) // 3 + + tracker.clear() + assert tracker.average == 1024 * 1024 + + +def test_max_prefetch_size_property(): + bp1 = BackgroundPrefetcher(fetcher=MockFetcher(b""), size=10000, concurrency=4) + assert bp1.producer.max_prefetch_size == bp1.producer.MIN_PREFETCH_SIZE + bp1.close() + + bp2 = BackgroundPrefetcher(fetcher=MockFetcher(b""), size=1000000000, concurrency=4) + # Give it a history so it calculates 2x the io_size + bp2.read_tracker.add(100 * 1024 * 1024) + assert bp2.producer.max_prefetch_size == 200 * 1024 * 1024 + bp2.close() + + +def test_sequential_read_spanning_blocks(): + data = b"A" * 100 + b"B" * 100 + b"C" * 100 + fetcher = MockFetcher(data) + bp = BackgroundPrefetcher(fetcher=fetcher, size=300, concurrency=4) + bp.read_tracker.add(100) # Seed the adaptive tracker + + assert bp._fetch(0, 100) == b"A" * 100 + assert bp._fetch(100, 150) == b"B" * 50 + assert bp.consumer._current_block_idx == 50 + assert bp._fetch(150, 250) == b"B" * 50 + b"C" * 50 + assert bp._fetch(250, 300) == b"C" * 50 + assert bp._fetch(300, 310) == b"" + + bp.close() + + +def test_fetch_default_args_and_out_of_bounds(): + fetcher = MockFetcher(b"12345") + bp = BackgroundPrefetcher(fetcher=fetcher, size=5, concurrency=4) + + assert bp._fetch(None, None) == b"12345" + assert bp._fetch(None, 2) == b"12" + assert bp._fetch(5, 10) == b"" + assert bp._fetch(10, 20) == b"" + assert bp._fetch(2, 2) == b"" + assert bp._fetch(4, 2) == b"" + + bp.close() + + +def test_seek_logic(): + data = b"0123456789" * 10 + fetcher = MockFetcher(data) + bp = BackgroundPrefetcher(fetcher=fetcher, size=100, concurrency=4) + + assert bp._fetch(0, 10) == data[0:10] + assert bp._fetch(10, 20) == data[10:20] + assert bp.user_offset == 20 + assert bp._fetch(50, 60) == data[50:60] + assert bp.user_offset == 60 + assert bp._fetch(10, 20) == data[10:20] + assert bp.user_offset == 20 + + bp.close() + + +def test_exception_placed_in_queue(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + async def inject_error(): + await bp.queue.put(ValueError("Injected Producer Error")) + + fsspec.asyn.sync(bp.loop, inject_error) + + with pytest.raises(ValueError, match="Injected Producer Error"): + bp._fetch(0, 50) + + assert isinstance(bp._error, ValueError) + bp.close() + + +def test_producer_concurrency_streak_and_min_chunk(): + data = b"X" * 1000 + fetcher = MockFetcher(data) + + bp = BackgroundPrefetcher(fetcher=fetcher, size=1000, concurrency=4) + bp.read_tracker.add(50) + + # Temporarily lower chunk limit for test + original_min_chunk = bp.producer.MIN_CHUNK_SIZE + bp.producer.MIN_CHUNK_SIZE = 10 + + bp._fetch(0, 50) + bp._fetch(50, 100) + bp._fetch(100, 150) + + fsspec.asyn.sync(bp.loop, asyncio.sleep, 0.1) + + split_factors = [call["split_factor"] for call in fetcher.calls] + assert split_factors[0] == 4 + assert max(split_factors) > 1 + assert max(split_factors) <= 4 + + bp.producer.MIN_CHUNK_SIZE = original_min_chunk + bp.close() + + +def test_producer_loop_space_constraints(): + data = b"Y" * 100 + fetcher = MockFetcher(data) + + bp = BackgroundPrefetcher(fetcher=fetcher, size=100, concurrency=4) + bp.read_tracker.add(60) + + original_min_chunk = bp.producer.MIN_CHUNK_SIZE + bp.producer.MIN_CHUNK_SIZE = 200 + + assert bp._fetch(0, 10) == b"Y" * 10 + + fsspec.asyn.sync(bp.loop, asyncio.sleep, 0.1) + sizes = [call["size"] for call in fetcher.calls] + assert all(s <= 100 for s in sizes) + + bp.producer.MIN_CHUNK_SIZE = original_min_chunk + bp.close() + + +def test_producer_error_propagation(): + fetcher = MockFetcher(b"A" * 1000, fail_at_call=3) + bp = BackgroundPrefetcher(fetcher=fetcher, size=1000, concurrency=4) + bp.read_tracker.add(100) + + assert bp._fetch(0, 100) == b"A" * 100 + + with pytest.raises(OSError, match="Simulated Network Timeout"): + bp._fetch(100, 500) + + assert bp.is_stopped is True + bp.close() + + +def test_read_after_close_or_error(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + bp.close() + + assert bp.is_stopped is True + with pytest.raises(RuntimeError, match="The file instance has been closed"): + bp._fetch(0, 10) + + bp2 = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + bp2._error = ValueError("Pre-existing error") + with pytest.raises(ValueError, match="Pre-existing error"): + bp2._fetch(0, 10) + bp2.close() + + +def test_empty_queue_when_stopped(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 500), size=500, concurrency=4) + bp.is_stopped = True + + with pytest.raises(RuntimeError, match="The file instance has been closed"): + bp._fetch(0, 100) + + bp.close() + + +def test_cancel_all_tasks_cleans_queue_with_exceptions(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + async def inject_task(): + async def dummy_exception_task(): + raise ValueError("Hidden error") + + task = asyncio.create_task(dummy_exception_task()) + await bp.queue.put(task) + await asyncio.sleep(0.05) + + fsspec.asyn.sync(bp.loop, inject_task) + bp.close() + assert bp.queue.empty() + + +def test_cleanup_cancels_active_tasks(): + bp = BackgroundPrefetcher( + fetcher=MockFetcher(b"Z" * 1000), size=1000, concurrency=4 + ) + + async def inject_task(): + async def dummy_task(): + await asyncio.sleep(3) + + task = asyncio.create_task(dummy_task()) + bp.producer._active_tasks.add(task) + + fsspec.asyn.sync(bp.loop, inject_task) + + assert len(bp.producer._active_tasks) > 0 + assert bp.is_stopped is False + + bp.close() + + assert bp.is_stopped is True + assert len(bp.producer._active_tasks) == 0 + + +def test_read_task_cancellation(): + bp = BackgroundPrefetcher( + fetcher=MockFetcher(b"X" * 1000), size=1000, concurrency=4 + ) + + async def inject_and_read(): + bp.is_stopped = True + while not bp.queue.empty(): + bp.queue.get_nowait() + + cancel_task = asyncio.create_task(asyncio.sleep(10)) + cancel_task.cancel() + await bp.queue.put(cancel_task) + + with pytest.raises(asyncio.CancelledError): + await bp.consumer.consume(10) + + fsspec.asyn.sync(bp.loop, inject_and_read) + bp.close() + + +def test_async_fetch_exception_trapping(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + def bad_sync(*args, **kwargs): + raise RuntimeError("Simulated sync crash") + + with mock.patch("fsspec.asyn.sync", side_effect=bad_sync): + with pytest.raises(RuntimeError, match="Simulated sync crash"): + bp._fetch(0, 10) + + assert bp.is_stopped is True + assert isinstance(bp._error, RuntimeError) + bp.close() + + +def test_read_past_eof_internal(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 50), size=50, concurrency=4) + bp.user_offset = 50 + res = bp._fetch(50, 60) + assert res == b"" + bp.close() + + +def test_fetch_with_exact_block_matches(): + data = b"X" * 100 + bp = BackgroundPrefetcher(fetcher=MockFetcher(data), size=100, concurrency=4) + bp.read_tracker.add(50) + + assert bp._fetch(0, 50) == b"X" * 50 + assert bp.consumer._current_block_idx == 50 + assert bp._fetch(50, 100) == b"X" * 50 + + bp.close() + + +def test_queue_empty_race_condition(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + async def inject(): + bp.queue.put_nowait(asyncio.create_task(asyncio.sleep(0))) + with mock.patch.object(bp.queue, "get_nowait", side_effect=asyncio.QueueEmpty): + await bp.producer.stop() + + fsspec.asyn.sync(bp.loop, inject) + bp.close() + + +def test_producer_space_remaining_break(): + bp = BackgroundPrefetcher( + fetcher=MockFetcher(b"X" * 1000), + size=1000, + concurrency=4, + max_prefetch_size=150, + ) + bp._fetch(0, 10) + fsspec.asyn.sync(bp.loop, asyncio.sleep, 0.1) + bp.close() + + +def test_producer_min_chunk_logic(): + bp1 = BackgroundPrefetcher( + fetcher=MockFetcher(b"X" * 1000), + size=1000, + concurrency=4, + max_prefetch_size=300, + ) + bp1.producer.MIN_CHUNK_SIZE = 100 + + fsspec.asyn.sync(bp1.loop, asyncio.sleep, 0.1) + bp1.close() + + bp2 = BackgroundPrefetcher( + fetcher=MockFetcher(b"X" * 1000), + size=1000, + concurrency=4, + max_prefetch_size=150, + ) + bp2.producer.MIN_CHUNK_SIZE = 100 + fsspec.asyn.sync(bp2.loop, asyncio.sleep, 0.1) + bp2.close() + + +def test_producer_loop_exception(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b""), size=100, concurrency=4) + error_object = ValueError("Producer crash") + bp.producer.get_io_size = mock.Mock(side_effect=error_object) + + with pytest.raises(ValueError, match="Producer crash"): + bp._fetch(0, 10) + + assert bp.is_stopped is True + assert bp._error == error_object + + with pytest.raises(ValueError, match="Producer crash"): + bp._fetch(0, 10) + bp.close() + + +def test_seek_same_offset(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b""), size=100, concurrency=4) + fsspec.asyn.sync(bp.loop, bp._async_fetch, 0, 10) + bp.close() + + +def test_read_history_maxlen(): + bp = BackgroundPrefetcher( + fetcher=MockFetcher(b"X" * 2000), size=2000, concurrency=4 + ) + for i in range(12): + bp._fetch(i * 10, (i + 1) * 10) + assert len(bp.read_tracker._history) == 10 + bp.close() + + +def test_fast_slice_branch(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 200), size=200, concurrency=4) + assert bp._fetch(0, 10) == b"X" * 10 + assert bp._fetch(10, 20) == b"X" * 10 + bp.close() + + +def test_fetch_stopped_during_execution(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + async def fake_async_fetch(start, end): + bp.is_stopped = True + return b"fake" + + with mock.patch.object(bp, "_async_fetch", new=fake_async_fetch): + with pytest.raises(RuntimeError, match="The file instance has been closed"): + bp._fetch(0, 10) + bp.close() + + +def test_producer_space_remaining_break_exact(): + fetcher = MockFetcher(b"X" * 1000) + bp = BackgroundPrefetcher( + fetcher=fetcher, size=1000, concurrency=4, max_prefetch_size=150 + ) + bp.read_tracker.add(100) + + async def trigger_loop(): + bp.producer.current_offset = 100 + bp.consumer.offset = 0 + bp.consumer.sequential_streak = 5 + + bp.wakeup_event.set() + await asyncio.sleep(0.05) + + fsspec.asyn.sync(bp.loop, trigger_loop) + assert fetcher.call_count == 0 + bp.close() + + +def test_async_fetch_not_block_break(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b""), size=100, concurrency=4) + + async def fake_consume(size): + return b"" + + bp.consumer.consume = fake_consume + bp.user_offset = 0 + + res = fsspec.asyn.sync(bp.loop, bp._async_fetch, 0, 50) + assert res == b"" + bp.close() + + +def test_fetch_stopped_before_execution(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + bp.is_stopped = True + bp._error = None + + with pytest.raises(RuntimeError, match="The file instance has been closed"): + bp._fetch(0, 10) + bp.close() + + +def test_async_fetch_zero_copy_remainder(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X"), size=100, concurrency=4) + bp.consumer._current_block = b"ABCDE" + bp.consumer._current_block_idx = 0 + bp.user_offset = 0 + res = fsspec.asyn.sync(bp.loop, bp._async_fetch, 0, 5) + assert res == b"ABCDE" + assert bp.consumer._current_block_idx == 5 + bp.close() + + +def test_read_runtime_error_on_stopped_empty(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X"), size=100, concurrency=4) + bp.is_stopped = True + bp.producer.is_stopped = True + + while not bp.queue.empty(): + bp.queue.get_nowait() + + res = fsspec.asyn.sync(bp.loop, bp.consumer.consume, 10) + assert res == b"" + bp.close() + + +def test_init_invalid_max_prefetch_size(): + with pytest.raises( + ValueError, + match=r"max_prefetch_size should be a positive integer", + ): + BackgroundPrefetcher( + fetcher=MockFetcher(b""), size=1000, concurrency=4, max_prefetch_size=0 + ) + + +def test_init_valid_max_prefetch_size_edge_case(): + bp = BackgroundPrefetcher( + fetcher=MockFetcher(b""), size=1000, concurrency=4, max_prefetch_size=100 + ) + assert bp.producer._user_max_prefetch_size == 100 + bp.close() + + +def test_consumer_zero_size_checks(): + bp = BackgroundPrefetcher(fetcher=MockFetcher(b"X" * 100), size=100, concurrency=4) + + # 1. Test consume size <= 0 + res_consume_zero = fsspec.asyn.sync(bp.loop, bp.consumer.consume, 0) + assert res_consume_zero == b"" + res_consume_neg = fsspec.asyn.sync(bp.loop, bp.consumer.consume, -5) + assert res_consume_neg == b"" + + # 2. Test _advance size <= 0 directly + # (consume catches it early, so we call _advance directly to hit its internal check) + res_advance_zero = fsspec.asyn.sync( + bp.loop, bp.consumer._advance, 0, save_data=True + ) + assert res_advance_zero == [] + res_advance_neg = fsspec.asyn.sync( + bp.loop, bp.consumer._advance, -10, save_data=False + ) + assert res_advance_neg == [] + + bp.close() + + +def test_producer_min_chunk_inner_break(): + fetcher = MockFetcher(b"X" * 1000) + bp = BackgroundPrefetcher( + fetcher=fetcher, size=1000, concurrency=4, max_prefetch_size=400 + ) + + bp.read_tracker.add(100) + + original_min_chunk = bp.producer.MIN_CHUNK_SIZE + bp.producer.MIN_CHUNK_SIZE = 200 + + async def trigger_loop(): + bp.producer.current_offset = 250 + bp.consumer.offset = 0 + bp.consumer.sequential_streak = 3 # makes prefetch_size = (3+1) * 100 = 400 + bp.wakeup_event.set() + await asyncio.sleep(0.05) + + fsspec.asyn.sync(bp.loop, trigger_loop) + + assert fetcher.call_count == 0 + + bp.producer.MIN_CHUNK_SIZE = original_min_chunk + bp.close() + + +def test_producer_loop_break_on_stopped_after_wakeup(): + fetcher = MockFetcher(b"X" * 1000) + bp = BackgroundPrefetcher(fetcher=fetcher, size=1000, concurrency=4) + + async def trigger_stop_and_wake(): + bp.producer.is_stopped = True + bp.wakeup_event.set() + await asyncio.sleep(0.05) + + fsspec.asyn.sync(bp.loop, trigger_stop_and_wake) + + # Verify the producer gracefully exited without doing work + assert fetcher.call_count == 0 + bp.close() diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py index e7fd7dfb..768a68fc 100644 --- a/gcsfs/zb_hns_utils.py +++ b/gcsfs/zb_hns_utils.py @@ -11,6 +11,7 @@ ) MRD_MAX_RANGES = 1000 # MRD supports up to 1000 ranges per request +DEFAULT_CONCURRENCY = 4 logger = logging.getLogger("gcsfs")