diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b159fc6f..f9247e11 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,13 +35,22 @@ jobs:
       - name: install
         run: |
           pip install -e .
-      - name: Run tests
+      - name: Run Standard Tests
         run: |
           export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/gcsfs/tests/fake-secret.json
           pytest -vv -s \
             --log-format="%(asctime)s %(levelname)s %(message)s" \
             --log-date-format="%H:%M:%S" \
-            gcsfs/
+            gcsfs/ \
+            --ignore=gcsfs/tests/test_extended_gcsfs.py
+      - name: Run Extended Tests
+        run: |
+          export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/gcsfs/tests/fake-secret.json
+          export GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT="true"
+          pytest -vv -s \
+            --log-format="%(asctime)s %(levelname)s %(message)s" \
+            --log-date-format="%H:%M:%S" \
+            gcsfs/tests/test_extended_gcsfs.py
 
   lint:
     name: lint
diff --git a/.isort.cfg b/.isort.cfg
index 1fff95db..1eab763a 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,3 @@
 [settings]
+profile = black
 known_third_party = aiohttp,click,decorator,fsspec,fuse,google,google_auth_oauthlib,pytest,requests,setuptools
diff --git a/environment_gcsfs.yaml b/environment_gcsfs.yaml
index 4cd6ff8b..39db5120 100644
--- a/environment_gcsfs.yaml
+++ b/environment_gcsfs.yaml
@@ -13,9 +13,11 @@ dependencies:
   - google-auth-oauthlib
   - google-cloud-core
   - google-cloud-storage
+  - grpcio
   - pytest
   - pytest-timeout
   - pytest-asyncio
+  - pytest-subtests
   - requests
   - ujson
   - pip:
diff --git a/gcsfs/__init__.py b/gcsfs/__init__.py
index fffbca44..f22d6ba1 100644
--- a/gcsfs/__init__.py
+++ b/gcsfs/__init__.py
@@ -1,10 +1,29 @@
+import logging
+import os
+
 from ._version import get_versions
 
+logger = logging.getLogger(__name__)
 __version__ = get_versions()["version"]
 del get_versions
 
 from .core import GCSFileSystem
 from .mapping import GCSMap
 
+if os.getenv("GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT", "false").lower() in ("true", "1"):
+    try:
+        from .extended_gcsfs import ExtendedGcsFileSystem as GCSFileSystem
+
+        logger.info(
+            "gcsfs experimental features enabled via GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT."
+        )
+    except ImportError as e:
+        logger.warning(
+            f"GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT is set, but failed to import experimental features: {e}"
+        )
+        # Fall back to the core GCSFileSystem; do not register here
+
+# TODO: GCSMap still refers to the original GCSFileSystem. This will be
+# addressed in a future update.
 __all__ = ["GCSFileSystem", "GCSMap"]
 
 from . import _version
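Because the class swap happens at import time, the flag only takes effect if it is set before gcsfs is first imported (the new test_init.py below exercises exactly this). A minimal opt-in sketch, assuming the experimental dependencies are installed:

    import os

    os.environ["GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT"] = "true"  # must precede the first gcsfs import

    import gcsfs

    fs = gcsfs.GCSFileSystem()
    print(type(fs).__name__)  # "ExtendedGcsFileSystem" if the experimental import succeeded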
diff --git a/gcsfs/extended_gcsfs.py b/gcsfs/extended_gcsfs.py
new file mode 100644
index 00000000..f4c643a4
--- /dev/null
+++ b/gcsfs/extended_gcsfs.py
@@ -0,0 +1,245 @@
+import logging
+from enum import Enum
+
+from fsspec import asyn
+from google.api_core import exceptions as api_exceptions
+from google.api_core import gapic_v1
+from google.api_core.client_info import ClientInfo
+from google.cloud import storage_control_v2
+from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
+from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
+    AsyncMultiRangeDownloader,
+)
+
+from gcsfs import __version__ as version
+from gcsfs import zb_hns_utils
+from gcsfs.core import GCSFile, GCSFileSystem
+from gcsfs.zonal_file import ZonalFile
+
+logger = logging.getLogger("gcsfs")
+
+USER_AGENT = "python-gcsfs"
+
+
+class BucketType(Enum):
+    ZONAL_HIERARCHICAL = "ZONAL_HIERARCHICAL"
+    HIERARCHICAL = "HIERARCHICAL"
+    NON_HIERARCHICAL = "NON_HIERARCHICAL"
+    UNKNOWN = "UNKNOWN"
+
+
+gcs_file_types = {
+    BucketType.ZONAL_HIERARCHICAL: ZonalFile,
+    BucketType.NON_HIERARCHICAL: GCSFile,
+    BucketType.HIERARCHICAL: GCSFile,
+    BucketType.UNKNOWN: GCSFile,
+}
+
+
+class ExtendedGcsFileSystem(GCSFileSystem):
+    """
+    This class is used when the GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT environment variable is set to true.
+    ExtendedGcsFileSystem is a subclass of GCSFileSystem that adds logic for special bucket types,
+    including zonal and hierarchical. For buckets without special properties, it forwards requests
+    to the parent GCSFileSystem for default processing.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.grpc_client = None
+        self._storage_control_client = None
+        # Initialize the gRPC and storage control clients for hierarchical and
+        # zonal bucket operations
+        self.grpc_client = asyn.sync(self.loop, self._create_grpc_client)
+        self._storage_control_client = asyn.sync(
+            self.loop, self._create_control_plane_client
+        )
+        self._storage_layout_cache = {}
+
+    async def _create_grpc_client(self):
+        if self.grpc_client is None:
+            return AsyncGrpcClient(
+                client_info=ClientInfo(user_agent=f"{USER_AGENT}/{version}"),
+            ).grpc_client
+        else:
+            return self.grpc_client
+
+    async def _create_control_plane_client(self):
+        # Initialize the storage control plane client for bucket
+        # metadata operations
+        client_info = gapic_v1.client_info.ClientInfo(
+            user_agent=f"{USER_AGENT}/{version}"
+        )
+        return storage_control_v2.StorageControlAsyncClient(
+            credentials=self.credentials.credentials, client_info=client_info
+        )
+
+    async def _lookup_bucket_type(self, bucket):
+        if bucket in self._storage_layout_cache:
+            return self._storage_layout_cache[bucket]
+        bucket_type = await self._get_bucket_type(bucket)
+        # Don't cache the UNKNOWN type
+        if bucket_type == BucketType.UNKNOWN:
+            return BucketType.UNKNOWN
+        self._storage_layout_cache[bucket] = bucket_type
+        return self._storage_layout_cache[bucket]
+
+    _sync_lookup_bucket_type = asyn.sync_wrapper(_lookup_bucket_type)
+
+    async def _get_bucket_type(self, bucket):
+        try:
+            bucket_name_value = f"projects/_/buckets/{bucket}/storageLayout"
+            response = await self._storage_control_client.get_storage_layout(
+                name=bucket_name_value
+            )
+
+            if response.location_type == "zone":
+                return BucketType.ZONAL_HIERARCHICAL
+            else:
+                # This should be updated to include HNS in the future
+                return BucketType.NON_HIERARCHICAL
+        except api_exceptions.NotFound:
+            logger.warning(f"Bucket {bucket} not found or you lack permissions.")
+            return BucketType.UNKNOWN
+        except Exception as e:
+            logger.error(
+                f"Could not determine bucket type for bucket name {bucket}: {e}"
+            )
+            # Default to UNKNOWN if the bucket type cannot be determined
+            return BucketType.UNKNOWN
+
+    def _open(
+        self,
+        path,
+        mode="rb",
+        block_size=None,
+        cache_options=None,
+        acl=None,
+        consistency=None,
+        metadata=None,
+        autocommit=True,
+        fixed_key_metadata=None,
+        generation=None,
+        **kwargs,
+    ):
+        """
+        Open a file, dispatching to the file class that matches the bucket type.
+        """
+        bucket, _, _ = self.split_path(path)
+        bucket_type = self._sync_lookup_bucket_type(bucket)
+        return gcs_file_types[bucket_type](
+            self,
+            path,
+            mode,
+            block_size,
+            cache_options=cache_options,
+            consistency=consistency,
+            metadata=metadata,
+            acl=acl,
+            autocommit=autocommit,
+            fixed_key_metadata=fixed_key_metadata,
+            generation=generation,
+            **kwargs,
+        )
+
+    # Replacement for _process_limits that supports the new parameters (offset and length) needed by the MRD.
+    async def _process_limits_to_offset_and_length(self, path, start, end):
+        """
+        Calculate the read offset and length from the start and end parameters.
+
+        Args:
+            path (str): The path to the file.
+            start (int | None): The starting byte position.
+            end (int | None): The ending byte position.
+
+        Returns:
+            tuple: A tuple containing (offset, length).
+
+        Raises:
+            ValueError: If the calculated range is invalid.
+        """
+        size = None
+
+        if start is None:
+            offset = 0
+        elif start < 0:
+            size = (await self._info(path))["size"] if size is None else size
+            offset = size + start
+        else:
+            offset = start
+
+        if end is None:
+            size = (await self._info(path))["size"] if size is None else size
+            effective_end = size
+        elif end < 0:
+            size = (await self._info(path))["size"] if size is None else size
+            effective_end = size + end
+        else:
+            effective_end = end
+
+        if offset < 0:
+            raise ValueError(f"Calculated start offset ({offset}) cannot be negative.")
+        if effective_end < offset:
+            raise ValueError(
+                f"Calculated end position ({effective_end}) cannot be before start offset ({offset})."
+            )
+        elif effective_end == offset:
+            length = 0  # Handle zero-length slice
+        else:
+            length = effective_end - offset  # Normal case
+            size = (await self._info(path))["size"] if size is None else size
+            if effective_end > size:
+                length = max(0, size - offset)  # Clamp to file size and keep non-negative
+
+        return offset, length
+
+    sync_process_limits_to_offset_and_length = asyn.sync_wrapper(
+        _process_limits_to_offset_and_length
+    )
+
+    async def _is_zonal_bucket(self, bucket):
+        bucket_type = await self._lookup_bucket_type(bucket)
+        return bucket_type == BucketType.ZONAL_HIERARCHICAL
+
+    async def _cat_file(self, path, start=None, end=None, **kwargs):
+        """Fetch a file's contents as bytes, with an optimized path for Zonal buckets.
+
+        This method overrides the parent `_cat_file` to read objects in Zonal buckets using gRPC.
+
+        Args:
+            path (str): The full GCS path to the file (e.g., "bucket/object").
+            start (int, optional): The starting byte position to read from.
+            end (int, optional): The ending byte position to read to.
+            mrd (AsyncMultiRangeDownloader, optional): An existing multi-range
+                downloader instance. If not provided, a new one is created for Zonal buckets.
+
+        Returns:
+            bytes: The content of the file or file range.
+        """
+        mrd = kwargs.pop("mrd", None)
+        mrd_created = False
+
+        # A new MRD is required when the read is performed directly by the
+        # GCSFileSystem class without creating a GCSFile object first.
+        if mrd is None:
+            bucket, object_name, generation = self.split_path(path)
+            # Fall back to the default implementation if this is not a zonal bucket
+            if not await self._is_zonal_bucket(bucket):
+                return await super()._cat_file(path, start=start, end=end, **kwargs)
+
+            mrd = await AsyncMultiRangeDownloader.create_mrd(
+                self.grpc_client, bucket, object_name, generation
+            )
+            mrd_created = True
+
+        offset, length = await self._process_limits_to_offset_and_length(
+            path, start, end
+        )
+        try:
+            return await zb_hns_utils.download_range(
+                offset=offset, length=length, mrd=mrd
+            )
+        finally:
+            # Explicit cleanup if we created the MRD here
+            if mrd_created:
+                await mrd.close()
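For review of `_process_limits_to_offset_and_length`: the mapping from fsspec-style `start`/`end` arguments to the MRD's `offset`/`length` pair is summarised below for a hypothetical 100-byte object (the same cases are asserted by `test_process_limits_parametrized` further down in this patch):

    # For an object of 100 bytes:
    #   start=None, end=None -> offset=0,   length=100  (whole object)
    #   start=-10,  end=None -> offset=90,  length=10   (negative start counts back from EOF)
    #   start=10,   end=-10  -> offset=10,  length=80   (negative end counts back from EOF)
    #   start=20,   end=20   -> offset=20,  length=0    (zero-length slice, no gRPC call)
    #   start=90,   end=200  -> offset=90,  length=10   (end clamped to the object size)
    #   start=110,  end=120  -> offset=110, length=0    (offset past EOF reads nothing)
    #   start=50,   end=40   -> ValueError              (end before start)
    #   start=-200, end=None -> ValueError              (computed offset would be negative)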
diff --git a/gcsfs/tests/conftest.py b/gcsfs/tests/conftest.py
index e1a73732..5d5e0179 100644
--- a/gcsfs/tests/conftest.py
+++ b/gcsfs/tests/conftest.py
@@ -1,7 +1,10 @@
+import logging
 import os
 import shlex
 import subprocess
 import time
+from contextlib import nullcontext
+from unittest.mock import patch
 
 import fsspec
 import pytest
@@ -91,10 +94,9 @@ def docker_gcs():
 def gcs_factory(docker_gcs):
     params["endpoint_url"] = docker_gcs
 
-    def factory(default_location=None):
+    def factory(**kwargs):
         GCSFileSystem.clear_instance_cache()
-        params["default_location"] = default_location
-        return fsspec.filesystem("gcs", **params)
+        return fsspec.filesystem("gcs", **params, **kwargs)
 
     return factory
 
@@ -125,6 +127,51 @@ def gcs(gcs_factory, populate=True):
         pass
 
 
+def _cleanup_gcs(gcs, is_real_gcs):
+    """Only remove the bucket contents if we are NOT using real GCS, logging a warning on failure."""
+    if is_real_gcs:
+        return
+    try:
+        gcs.rm(TEST_BUCKET, recursive=True)
+    except Exception as e:
+        logging.warning(f"Failed to clean up GCS bucket {TEST_BUCKET}: {e}")
+
+
+@pytest.fixture
+def extended_gcsfs(gcs_factory, populate=True):
+    # Check whether we are running against a real GCS endpoint
+    is_real_gcs = (
+        os.environ.get("STORAGE_EMULATOR_HOST") == "https://storage.googleapis.com"
+    )
+
+    # Mock authentication when not using a real GCS endpoint,
+    # since the gRPC client in extended_gcsfs does not work with anonymous access
+    mock_authentication_manager = (
+        patch("google.auth.default", return_value=(None, "fake-project"))
+        if not is_real_gcs
+        else nullcontext()
+    )
+
+    with mock_authentication_manager:
+        extended_gcsfs = gcs_factory()
+        try:
+            # Only create/delete/populate the bucket if we are NOT using the real GCS endpoint
+            if not is_real_gcs:
+                try:
+                    extended_gcsfs.rm(TEST_BUCKET, recursive=True)
+                except FileNotFoundError:
+                    pass
+                extended_gcsfs.mkdir(TEST_BUCKET)
+                if populate:
+                    extended_gcsfs.pipe(
+                        {TEST_BUCKET + "/" + k: v for k, v in allfiles.items()}
+                    )
+            extended_gcsfs.invalidate_cache()
+            yield extended_gcsfs
+        finally:
+            _cleanup_gcs(extended_gcsfs, is_real_gcs)
+
+
 @pytest.fixture
 def gcs_versioned(gcs_factory):
     gcs = gcs_factory()
diff --git a/gcsfs/tests/test_extended_gcsfs.py b/gcsfs/tests/test_extended_gcsfs.py
new file mode 100644
index 00000000..1ff2b653
--- /dev/null
+++ b/gcsfs/tests/test_extended_gcsfs.py
@@ -0,0 +1,310 @@
+import contextlib
+import io
+import os
+from itertools import chain
+from unittest import mock
+
+import pytest
+from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
+    AsyncMultiRangeDownloader,
+)
+from google.cloud.storage.exceptions import DataCorruption
+
+from gcsfs.extended_gcsfs import BucketType
+from gcsfs.tests.conftest import a, b, c, csv_files, files, text_files
+from gcsfs.tests.settings import TEST_BUCKET
+
+file = "test/accounts.1.json"
+file_path = f"{TEST_BUCKET}/{file}"
+json_data = files[file]
+lines = io.BytesIO(json_data).readlines()
+file_size = len(json_data)
+
+
+@pytest.fixture
+def zonal_mocks():
+    """A factory fixture for mocking Zonal bucket functionality."""
+
+    @contextlib.contextmanager
+    def _zonal_mocks_factory(file_data):
+        """Create the mocks for the given file content."""
+        is_real_gcs = (
+            os.environ.get("STORAGE_EMULATOR_HOST") == "https://storage.googleapis.com"
+        )
+        if is_real_gcs:
+            yield None
+            return
+        patch_target_lookup_bucket_type = (
+            "gcsfs.extended_gcsfs.ExtendedGcsFileSystem._lookup_bucket_type"
+        )
+        patch_target_sync_lookup_bucket_type = (
+            "gcsfs.extended_gcsfs.ExtendedGcsFileSystem._sync_lookup_bucket_type"
+        )
+        patch_target_create_mrd = (
+            "google.cloud.storage._experimental.asyncio.async_multi_range_downloader"
+            ".AsyncMultiRangeDownloader.create_mrd"
+        )
+        patch_target_gcsfs_cat_file = "gcsfs.core.GCSFileSystem._cat_file"
+
+        async def download_side_effect(read_requests, **kwargs):
+            if read_requests and len(read_requests) == 1:
+                param_offset, param_length, buffer_arg = read_requests[0]
+                if hasattr(buffer_arg, "write"):
+                    buffer_arg.write(
+                        file_data[param_offset : param_offset + param_length]
+                    )
+            return [mock.Mock(error=None)]
+
+        mock_downloader = mock.Mock(spec=AsyncMultiRangeDownloader)
+        mock_downloader.download_ranges = mock.AsyncMock(
+            side_effect=download_side_effect
+        )
+
+        mock_create_mrd = mock.AsyncMock(return_value=mock_downloader)
+        with (
+            mock.patch(
+                patch_target_sync_lookup_bucket_type,
+                return_value=BucketType.ZONAL_HIERARCHICAL,
+            ) as mock_sync_lookup_bucket_type,
+            mock.patch(
+                patch_target_lookup_bucket_type,
+                return_value=BucketType.ZONAL_HIERARCHICAL,
+            ),
+            mock.patch(patch_target_create_mrd, mock_create_mrd),
+            mock.patch(
+                patch_target_gcsfs_cat_file, new_callable=mock.AsyncMock
+            ) as mock_cat_file,
+        ):
+            mocks = {
+                "sync_lookup_bucket_type": mock_sync_lookup_bucket_type,
+                "create_mrd": mock_create_mrd,
+                "downloader": mock_downloader,
+                "cat_file": mock_cat_file,
+            }
+            yield mocks
+            # Common assertion for all tests that use these mocks
+            mock_cat_file.assert_not_called()
+
+    yield _zonal_mocks_factory
+
+
+read_block_params = [
+    # Read a specific chunk
+    pytest.param(3, 10, None, json_data[3 : 3 + 10], id="offset=3, length=10"),
+    # Read from the beginning up to length
+    pytest.param(0, 5, None, json_data[0:5], id="offset=0, length=5"),
+    # Read from offset to the end (simulate a large length)
+    pytest.param(15, 5000, None, json_data[15:], id="offset=15, length=large"),
+    # Read beyond the end of the file (should return empty bytes)
+    pytest.param(file_size + 10, 5, None, b"", id="offset>size, length=5"),
+    # Read exactly at the end (zero length)
+    pytest.param(file_size, 10, None, b"", id="offset=size, length=10"),
+    # Read with a delimiter
+    pytest.param(1, 35, b"\n", lines[1], id="offset=1, length=35, delimiter=newline"),
+    pytest.param(0, 30, b"\n", lines[0], id="offset=0, length=30, delimiter=newline"),
+    pytest.param(
+        0, 35, b"\n", lines[0] + lines[1], id="offset=0, length=35, delimiter=newline"
+    ),
+]
+
+
+def test_read_block_zb(extended_gcsfs, zonal_mocks, subtests):
+    for param in read_block_params:
+        with subtests.test(id=param.id):
+            offset, length, delimiter, expected_data = param.values
+            path = file_path
+
+            with zonal_mocks(json_data) as mocks:
+                result = extended_gcsfs.read_block(path, offset, length, delimiter)
+
+                assert result == expected_data
+                if mocks:
+                    mocks["sync_lookup_bucket_type"].assert_called_once_with(
+                        TEST_BUCKET
+                    )
+                    if expected_data:
+                        mocks["downloader"].download_ranges.assert_called_with(
+                            [(offset, mock.ANY, mock.ANY)]
+                        )
+                    else:
+                        mocks["downloader"].download_ranges.assert_not_called()
+
+
+def test_read_small_zb(extended_gcsfs, zonal_mocks):
+    csv_file = "2014-01-01.csv"
+    csv_file_path = f"{TEST_BUCKET}/{csv_file}"
+    csv_data = csv_files[csv_file]
+
+    with zonal_mocks(csv_data) as mocks:
+        with extended_gcsfs.open(csv_file_path, "rb", block_size=10) as f:
+            out = []
+            i = 1
+            while True:
+                i += 1
+                data = f.read(3)
+                if data == b"":
+                    break
+                out.append(data)
+            assert extended_gcsfs.cat(csv_file_path) == b"".join(out)
+            # cache drop
+            assert len(f.cache.cache) < len(out)
+        if mocks:
+            mocks["sync_lookup_bucket_type"].assert_called_once_with(TEST_BUCKET)
+
+
+def test_readline_zb(extended_gcsfs, zonal_mocks):
+    all_items = chain.from_iterable(
+        [files.items(), csv_files.items(), text_files.items()]
+    )
+    for k, data in all_items:
+        with zonal_mocks(data):
+            with extended_gcsfs.open("/".join([TEST_BUCKET, k]), "rb") as f:
+                result = f.readline()
+                expected = data.split(b"\n")[0] + (b"\n" if data.count(b"\n") else b"")
+                assert result == expected
+
+
+def test_readline_from_cache_zb(extended_gcsfs, zonal_mocks):
+    data = b"a,b\n11,22\n3,4"
+    if not extended_gcsfs.on_google:
+        with mock.patch.object(
+            extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN
+        ):
+            with extended_gcsfs.open(a, "wb") as f:
+                f.write(data)
+    with zonal_mocks(data):
+        with extended_gcsfs.open(a, "rb") as f:
+            result = f.readline()
+            assert result == b"a,b\n"
+            assert f.loc == 4
+            assert f.cache.cache == data
+
+            result = f.readline()
+            assert result == b"11,22\n"
+            assert f.loc == 10
+            assert f.cache.cache == data
+
+            result = f.readline()
+            assert result == b"3,4"
+            assert f.loc == 13
+            assert f.cache.cache == data
+
+
+def test_readline_empty_zb(extended_gcsfs, zonal_mocks):
+    data = b""
+    if not extended_gcsfs.on_google:
+        with mock.patch.object(
+            extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN
+        ):
+            with extended_gcsfs.open(b, "wb") as f:
+                f.write(data)
+    with zonal_mocks(data):
+        with extended_gcsfs.open(b, "rb") as f:
+            result = f.readline()
+            assert result == data
+
+
+def test_readline_blocksize_zb(extended_gcsfs, zonal_mocks):
+    data = b"ab\n" + b"a" * (2**18) + b"\nab"
+    if not extended_gcsfs.on_google:
+        with mock.patch.object(
+            extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN
+        ):
+            with extended_gcsfs.open(c, "wb") as f:
+                f.write(data)
+    with zonal_mocks(data):
+        with extended_gcsfs.open(c, "rb", block_size=2**18) as f:
+            result = f.readline()
+            expected = b"ab\n"
+            assert result == expected
+
+            result = f.readline()
+            expected = b"a" * (2**18) + b"\n"
+            assert result == expected
+
+            result = f.readline()
+            expected = b"ab"
+            assert result == expected
+
+
+@pytest.mark.parametrize(
+    "start,end,exp_offset,exp_length,exp_exc",
+    [
+        (None, None, 0, file_size, None),  # full file
+        (-10, None, file_size - 10, 10, None),  # start negative
+        (10, -10, 10, file_size - 20, None),  # end negative
+        (20, 20, 20, 0, None),  # zero-length slice
+        (50, 40, None, None, ValueError),  # end before start -> raises
+        (-200, None, None, None, ValueError),  # offset negative -> raises
+        (file_size - 10, 200, file_size - 10, 10, None),  # end > size clamps
+        (
+            file_size + 10,
+            file_size + 20,
+            file_size + 10,
+            0,
+            None,
+        ),  # offset > size -> empty
+    ],
+)
+def test_process_limits_parametrized(
+    extended_gcsfs, start, end, exp_offset, exp_length, exp_exc
+):
+    if exp_exc is not None:
+        with pytest.raises(exp_exc):
+            extended_gcsfs.sync_process_limits_to_offset_and_length(
+                file_path, start, end
+            )
+    else:
+        offset, length = extended_gcsfs.sync_process_limits_to_offset_and_length(
+            file_path, start, end
+        )
+        assert offset == exp_offset
+        assert length == exp_length
+
+
+@pytest.mark.parametrize(
+    "exception_to_raise",
+    [ValueError, DataCorruption, Exception],
+)
+def test_mrd_exception_handling(extended_gcsfs, zonal_mocks, exception_to_raise):
+    """
+    Tests that _cat_file correctly propagates exceptions from mrd.download_ranges.
+    """
+    with zonal_mocks(json_data) as mocks:
+        if extended_gcsfs.on_google:
+            pytest.skip("Cannot mock exceptions on real GCS")
+
+        # Configure the mock to raise the specified exception
+        if exception_to_raise is DataCorruption:
+            # The first argument is 'response'; the message goes into '*args'
+            mocks["downloader"].download_ranges.side_effect = exception_to_raise(
+                None, "Test exception raised"
+            )
+        else:
+            mocks["downloader"].download_ranges.side_effect = exception_to_raise(
+                "Test exception raised"
+            )
+
+        with pytest.raises(exception_to_raise, match="Test exception raised"):
+            extended_gcsfs.read_block(file_path, 0, 10)
+
+        mocks["downloader"].download_ranges.assert_called_once()
+
+
+def test_mrd_stream_cleanup(extended_gcsfs, zonal_mocks):
+    """
+    Tests that the MRD stream is properly closed when the file is closed.
+    """
+    with zonal_mocks(json_data) as mocks:
+        if not extended_gcsfs.on_google:
+
+            def close_side_effect():
+                mocks["downloader"].is_stream_open = False
+
+            mocks["downloader"].close.side_effect = close_side_effect
+
+        with extended_gcsfs.open(file_path, "rb") as f:
+            assert f.mrd is not None
+
+        assert f.closed is True
+        assert f.mrd.is_stream_open is False
f"{TEST_BUCKET}/{file}" +json_data = files[file] +lines = io.BytesIO(json_data).readlines() +file_size = len(json_data) + + +@pytest.fixture +def zonal_mocks(): + """A factory fixture for mocking Zonal bucket functionality.""" + + @contextlib.contextmanager + def _zonal_mocks_factory(file_data): + """Creates mocks for a given file content.""" + is_real_gcs = ( + os.environ.get("STORAGE_EMULATOR_HOST") == "https://storage.googleapis.com" + ) + if is_real_gcs: + yield None + return + patch_target_lookup_bucket_type = ( + "gcsfs.extended_gcsfs.ExtendedGcsFileSystem._lookup_bucket_type" + ) + patch_target_sync_lookup_bucket_type = ( + "gcsfs.extended_gcsfs.ExtendedGcsFileSystem._sync_lookup_bucket_type" + ) + patch_target_create_mrd = ( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader" + ".AsyncMultiRangeDownloader.create_mrd" + ) + patch_target_gcsfs_cat_file = "gcsfs.core.GCSFileSystem._cat_file" + + async def download_side_effect(read_requests, **kwargs): + if read_requests and len(read_requests) == 1: + param_offset, param_length, buffer_arg = read_requests[0] + if hasattr(buffer_arg, "write"): + buffer_arg.write( + file_data[param_offset : param_offset + param_length] + ) + return [mock.Mock(error=None)] + + mock_downloader = mock.Mock(spec=AsyncMultiRangeDownloader) + mock_downloader.download_ranges = mock.AsyncMock( + side_effect=download_side_effect + ) + + mock_create_mrd = mock.AsyncMock(return_value=mock_downloader) + with ( + mock.patch( + patch_target_sync_lookup_bucket_type, + return_value=BucketType.ZONAL_HIERARCHICAL, + ) as mock_sync_lookup_bucket_type, + mock.patch( + patch_target_lookup_bucket_type, + return_value=BucketType.ZONAL_HIERARCHICAL, + ), + mock.patch(patch_target_create_mrd, mock_create_mrd), + mock.patch( + patch_target_gcsfs_cat_file, new_callable=mock.AsyncMock + ) as mock_cat_file, + ): + mocks = { + "sync_lookup_bucket_type": mock_sync_lookup_bucket_type, + "create_mrd": mock_create_mrd, + "downloader": mock_downloader, + "cat_file": mock_cat_file, + } + yield mocks + # Common assertion for all tests using this mock + mock_cat_file.assert_not_called() + + yield _zonal_mocks_factory + + +read_block_params = [ + # Read specific chunk + pytest.param(3, 10, None, json_data[3 : 3 + 10], id="offset=3, length=10"), + # Read from beginning up to length + pytest.param(0, 5, None, json_data[0:5], id="offset=0, length=5"), + # Read from offset to end (simulate large length) + pytest.param(15, 5000, None, json_data[15:], id="offset=15, length=large"), + # Read beyond end of file (should return empty bytes) + pytest.param(file_size + 10, 5, None, b"", id="offset>size, length=5"), + # Read exactly at the end (zero length) + pytest.param(file_size, 10, None, b"", id="offset=size, length=10"), + # Read with delimiter + pytest.param(1, 35, b"\n", lines[1], id="offset=1, length=35, delimiter=newline"), + pytest.param(0, 30, b"\n", lines[0], id="offset=0, length=35, delimiter=newline"), + pytest.param( + 0, 35, b"\n", lines[0] + lines[1], id="offset=0, length=35, delimiter=newline" + ), +] + + +def test_read_block_zb(extended_gcsfs, zonal_mocks, subtests): + for param in read_block_params: + with subtests.test(id=param.id): + offset, length, delimiter, expected_data = param.values + path = file_path + + with zonal_mocks(json_data) as mocks: + result = extended_gcsfs.read_block(path, offset, length, delimiter) + + assert result == expected_data + if mocks: + mocks["sync_lookup_bucket_type"].assert_called_once_with( + TEST_BUCKET + ) + if expected_data: 
+ mocks["downloader"].download_ranges.assert_called_with( + [(offset, mock.ANY, mock.ANY)] + ) + else: + mocks["downloader"].download_ranges.assert_not_called() + + +def test_read_small_zb(extended_gcsfs, zonal_mocks): + csv_file = "2014-01-01.csv" + csv_file_path = f"{TEST_BUCKET}/{csv_file}" + csv_data = csv_files[csv_file] + + with zonal_mocks(csv_data) as mocks: + with extended_gcsfs.open(csv_file_path, "rb", block_size=10) as f: + out = [] + i = 1 + while True: + i += 1 + data = f.read(3) + if data == b"": + break + out.append(data) + assert extended_gcsfs.cat(csv_file_path) == b"".join(out) + # cache drop + assert len(f.cache.cache) < len(out) + if mocks: + mocks["sync_lookup_bucket_type"].assert_called_once_with(TEST_BUCKET) + + +def test_readline_zb(extended_gcsfs, zonal_mocks): + all_items = chain.from_iterable( + [files.items(), csv_files.items(), text_files.items()] + ) + for k, data in all_items: + with zonal_mocks(data): + with extended_gcsfs.open("/".join([TEST_BUCKET, k]), "rb") as f: + result = f.readline() + expected = data.split(b"\n")[0] + (b"\n" if data.count(b"\n") else b"") + assert result == expected + + +def test_readline_from_cache_zb(extended_gcsfs, zonal_mocks): + data = b"a,b\n11,22\n3,4" + if not extended_gcsfs.on_google: + with mock.patch.object( + extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN + ): + with extended_gcsfs.open(a, "wb") as f: + f.write(data) + with zonal_mocks(data): + with extended_gcsfs.open(a, "rb") as f: + result = f.readline() + assert result == b"a,b\n" + assert f.loc == 4 + assert f.cache.cache == data + + result = f.readline() + assert result == b"11,22\n" + assert f.loc == 10 + assert f.cache.cache == data + + result = f.readline() + assert result == b"3,4" + assert f.loc == 13 + assert f.cache.cache == data + + +def test_readline_empty_zb(extended_gcsfs, zonal_mocks): + data = b"" + if not extended_gcsfs.on_google: + with mock.patch.object( + extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN + ): + with extended_gcsfs.open(b, "wb") as f: + f.write(data) + with zonal_mocks(data): + with extended_gcsfs.open(b, "rb") as f: + result = f.readline() + assert result == data + + +def test_readline_blocksize_zb(extended_gcsfs, zonal_mocks): + data = b"ab\n" + b"a" * (2**18) + b"\nab" + if not extended_gcsfs.on_google: + with mock.patch.object( + extended_gcsfs, "_sync_lookup_bucket_type", return_value=BucketType.UNKNOWN + ): + with extended_gcsfs.open(c, "wb") as f: + f.write(data) + with zonal_mocks(data): + with extended_gcsfs.open(c, "rb", block_size=2**18) as f: + result = f.readline() + expected = b"ab\n" + assert result == expected + + result = f.readline() + expected = b"a" * (2**18) + b"\n" + assert result == expected + + result = f.readline() + expected = b"ab" + assert result == expected + + +@pytest.mark.parametrize( + "start,end,exp_offset,exp_length,exp_exc", + [ + (None, None, 0, file_size, None), # full file + (-10, None, file_size - 10, 10, None), # start negative + (10, -10, 10, file_size - 20, None), # end negative + (20, 20, 20, 0, None), # zero-length slice + (50, 40, None, None, ValueError), # end before start -> raises + (-200, None, None, None, ValueError), # offset negative -> raises + (file_size - 10, 200, file_size - 10, 10, None), # end > size clamps + ( + file_size + 10, + file_size + 20, + file_size + 10, + 0, + None, + ), # offset > size -> empty + ], +) +def test_process_limits_parametrized( + extended_gcsfs, start, end, exp_offset, exp_length, exp_exc +): + 
diff --git a/gcsfs/tests/test_zb_hns_utils.py b/gcsfs/tests/test_zb_hns_utils.py
new file mode 100644
index 00000000..a64e6793
--- /dev/null
+++ b/gcsfs/tests/test_zb_hns_utils.py
@@ -0,0 +1,29 @@
+from unittest import mock
+
+import pytest
+
+from gcsfs import zb_hns_utils
+
+
+@pytest.mark.asyncio
+async def test_download_range():
+    """
+    Tests that download_range calls mrd.download_ranges with the correct
+    parameters and returns the data written to the buffer.
+    """
+    offset = 10
+    length = 20
+    mock_mrd = mock.AsyncMock()
+    expected_data = b"test data from download"
+
+    # Simulate the download_ranges method writing data to the buffer
+    async def mock_download_ranges(ranges):
+        _offset, _length, buffer = ranges[0]
+        buffer.write(expected_data)
+
+    mock_mrd.download_ranges.side_effect = mock_download_ranges
+
+    result = await zb_hns_utils.download_range(offset, length, mock_mrd)
+
+    mock_mrd.download_ranges.assert_called_once_with([(offset, length, mock.ANY)])
+    assert result == expected_data
diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py
new file mode 100644
index 00000000..648974e2
--- /dev/null
+++ b/gcsfs/zb_hns_utils.py
@@ -0,0 +1,13 @@
+from io import BytesIO
+
+
+async def download_range(offset, length, mrd):
+    """
+    Downloads a byte range from the file asynchronously.
+    """
+    # If length == 0, the MRD would read to the end of the file, so handle that case here
+    if length == 0:
+        return b""
+    buffer = BytesIO()
+    await mrd.download_ranges([(offset, length, buffer)])
+    return buffer.getvalue()
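A rough sketch of the helper's contract, with hypothetical values and an `mrd` that has already been created for the object: a zero length short-circuits without any gRPC traffic, while a positive length results in exactly one `download_ranges` call that fills an in-memory buffer.

    # Inside an async caller that already holds an open MRD:
    empty = await zb_hns_utils.download_range(offset=20, length=0, mrd=mrd)  # -> b"", no network call
    chunk = await zb_hns_utils.download_range(offset=0, length=10, mrd=mrd)  # one download_ranges() call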
diff --git a/gcsfs/zonal_file.py b/gcsfs/zonal_file.py
new file mode 100644
index 00000000..93afb84c
--- /dev/null
+++ b/gcsfs/zonal_file.py
@@ -0,0 +1,56 @@
+from fsspec import asyn
+from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
+    AsyncMultiRangeDownloader,
+)
+
+from gcsfs.core import GCSFile
+
+
+class ZonalFile(GCSFile):
+    """
+    ZonalFile is a subclass of GCSFile that handles data operations on
+    Zonal buckets exclusively via a high-performance gRPC path.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initializes the ZonalFile object.
+        """
+        super().__init__(*args, **kwargs)
+        self.mrd = None
+        if "r" in self.mode:
+            self.mrd = asyn.sync(
+                self.gcsfs.loop, self._init_mrd, self.bucket, self.key, self.generation
+            )
+        else:
+            raise NotImplementedError(
+                "Only read operations are currently supported for Zonal buckets."
+            )
+
+    async def _init_mrd(self, bucket_name, object_name, generation=None):
+        """
+        Initializes the AsyncMultiRangeDownloader.
+        """
+        return await AsyncMultiRangeDownloader.create_mrd(
+            self.gcsfs.grpc_client, bucket_name, object_name, generation
+        )
+
+    def _fetch_range(self, start, end):
+        """
+        Overrides the default _fetch_range to read via the gRPC path,
+        delegating to GCSFileSystem.cat_file with this file's MRD.
+        """
+        try:
+            return self.gcsfs.cat_file(self.path, start=start, end=end, mrd=self.mrd)
+        except RuntimeError as e:
+            if "not satisfiable" in str(e):
+                return b""
+            raise
+
+    def close(self):
+        """
+        Closes the ZonalFile and the underlying AsyncMultiRangeDownloader.
+        """
+        if self.mrd:
+            asyn.sync(self.gcsfs.loop, self.mrd.close)
+        super().close()
diff --git a/requirements.txt b/requirements.txt
index f82b0722..28e2cafa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ fsspec==2025.10.0
 google-auth>=1.2
 google-auth-oauthlib
 google-cloud-storage
+google-cloud-storage-control
 requests
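Taken together, the pieces above give a read-only workflow for zonal buckets. The following is a hedged end-to-end sketch (bucket and object names are hypothetical, and the experimental dependencies must be installed):

    import os

    os.environ["GCSFS_EXPERIMENTAL_ZB_HNS_SUPPORT"] = "true"  # opt in before importing gcsfs

    import gcsfs

    fs = gcsfs.GCSFileSystem()  # ExtendedGcsFileSystem under the hood

    # open() dispatches on the cached bucket layout; a zonal bucket yields a ZonalFile
    with fs.open("my-zonal-bucket/data/accounts.json", "rb") as f:
        header = f.read(64)  # ranges are fetched over gRPC via the file's MRD
    # closing the file also closes the underlying multi-range downloader

    # direct range reads create (and clean up) a one-off MRD per call
    chunk = fs.cat_file("my-zonal-bucket/data/accounts.json", start=100, end=200)

    # writes to zonal buckets are not implemented yet:
    # fs.open("my-zonal-bucket/new.txt", "wb") raises NotImplementedError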