diff --git a/HISTORY.md b/HISTORY.md
index 64bba9d8..2a718f08 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -2,6 +2,7 @@
 
 ## Unreleased
 
+- Added support for `timeout` and `retry` kwargs for `GSClient`. (Issue [#484](https://github.com/drivendataorg/cloudpathlib/issues/484), PR [#485](https://github.com/drivendataorg/cloudpathlib/pull/485), thanks @Mchristos)
 - Fixed `CloudPath(...) / other` to correctly attempt to fall back on `other`'s `__rtruediv__` implementation, in order to support classes that explicitly support the `/` with a `CloudPath` instance. Previously, this would always raise a `TypeError` if `other` were not a `str` or `PurePosixPath`. (PR [#479](https://github.com/drivendataorg/cloudpathlib/pull/479))
 - Add `md5` property to `GSPath`, updated LocalGSPath to include `md5` property, updated mock_gs.MockBlob to include `md5_hash` property.
 - Fixed an uncaught exception on Azure Gen2 storage accounts with HNS enabled when used with `DefaultAzureCredential`. (Issue [#486](https://github.com/drivendataorg/cloudpathlib/issues/486))
diff --git a/cloudpathlib/gs/gsclient.py b/cloudpathlib/gs/gsclient.py
index c66e661d..2816d88d 100644
--- a/cloudpathlib/gs/gsclient.py
+++ b/cloudpathlib/gs/gsclient.py
@@ -13,6 +13,7 @@ try:
     if TYPE_CHECKING:
         from google.auth.credentials import Credentials
+        from google.api_core.retry import Retry
 
     from google.auth.exceptions import DefaultCredentialsError
     from google.cloud.storage import Client as StorageClient
@@ -45,6 +46,8 @@ def __init__(
         local_cache_dir: Optional[Union[str, os.PathLike]] = None,
         content_type_method: Optional[Callable] = mimetypes.guess_type,
         download_chunks_concurrently_kwargs: Optional[Dict[str, Any]] = None,
+        timeout: Optional[float] = None,
+        retry: Optional["Retry"] = None,
     ):
         """Class constructor. Sets up a [`Storage Client`](https://googleapis.dev/python/storage/latest/client.html).
@@ -85,6 +88,8 @@ def __init__(
             download_chunks_concurrently_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to [`download_chunks_concurrently`](https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.transfer_manager#google_cloud_storage_transfer_manager_download_chunks_concurrently) for sliced parallel downloads; Only available in `google-cloud-storage` version 2.7.0 or later, otherwise ignored and a warning is emitted.
+            timeout (Optional[float]): Cloud Storage [timeout value](https://cloud.google.com/python/docs/reference/storage/1.39.0/retry_timeout) to use for blob operations.
+            retry (Optional[google.api_core.retry.Retry]): Cloud Storage [retry configuration](https://cloud.google.com/python/docs/reference/storage/1.39.0/retry_timeout#configuring-retries) to use for blob operations.
         """
         if application_credentials is None:
             application_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
@@ -102,6 +107,13 @@ def __init__(
             self.client = StorageClient.create_anonymous_client()
 
         self.download_chunks_concurrently_kwargs = download_chunks_concurrently_kwargs
+        self.blob_kwargs: Dict[str, Any] = {}
+        if timeout is not None:
+            self.timeout: float = timeout
+            self.blob_kwargs["timeout"] = self.timeout
+        if retry is not None:
+            self.retry: "Retry" = retry
+            self.blob_kwargs["retry"] = self.retry
 
         super().__init__(
             local_cache_dir=local_cache_dir,
@@ -129,7 +141,6 @@ def _download_file(self, cloud_path: GSPath, local_path: Union[str, os.PathLike]
         blob = bucket.get_blob(cloud_path.blob)
 
         local_path = Path(local_path)
-
         if transfer_manager is not None and self.download_chunks_concurrently_kwargs is not None:
             transfer_manager.download_chunks_concurrently(
                 blob, local_path, **self.download_chunks_concurrently_kwargs
@@ -140,7 +151,7 @@
                     "Ignoring `download_chunks_concurrently_kwargs` for version of google-cloud-storage that does not support them (<2.7.0)."
                 )
 
-            blob.download_to_filename(local_path)
+            blob.download_to_filename(local_path, **self.blob_kwargs)
 
         return local_path
@@ -247,7 +258,7 @@ def _move_file(self, src: GSPath, dst: GSPath, remove_src: bool = True) -> GSPat
         dst_bucket = self.client.bucket(dst.bucket)
 
         src_blob = src_bucket.get_blob(src.blob)
-        src_bucket.copy_blob(src_blob, dst_bucket, dst.blob)
+        src_bucket.copy_blob(src_blob, dst_bucket, dst.blob, **self.blob_kwargs)
 
         if remove_src:
             src_blob.delete()
@@ -280,7 +291,7 @@ def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: GSPath)
             content_type, _ = self.content_type_method(str(local_path))
             extra_args["content_type"] = content_type
 
-        blob.upload_from_filename(str(local_path), **extra_args)
+        blob.upload_from_filename(str(local_path), **extra_args, **self.blob_kwargs)
 
         return cloud_path
 
     def _get_public_url(self, cloud_path: GSPath) -> str:
diff --git a/tests/mock_clients/mock_gs.py b/tests/mock_clients/mock_gs.py
index e5763652..4ecdee1e 100644
--- a/tests/mock_clients/mock_gs.py
+++ b/tests/mock_clients/mock_gs.py
@@ -57,12 +57,19 @@ def delete(self):
         path.unlink()
         delete_empty_parents_up_to_root(path=path, root=self.bucket)
 
-    def download_to_filename(self, filename):
+    def download_to_filename(self, filename, timeout=None, retry=None):
+        # a non-None timeout means the test expects a timeout, so raise one
+        if timeout is not None:
+            raise TimeoutError("Download timed out")
+
+        # indicate that the retry object made it through to the GS library
+        if retry is not None:
+            retry.mocked_retries = 1
+
         from_path = self.bucket / self.name
-        to_path = Path(filename)
+        to_path = Path(filename)
         to_path.parent.mkdir(exist_ok=True, parents=True)
-
         to_path.write_bytes(from_path.read_bytes())
 
     def patch(self):
@@ -84,7 +91,15 @@ def reload(
     ):
         pass
 
-    def upload_from_filename(self, filename, content_type=None):
+    def upload_from_filename(self, filename, content_type=None, timeout=None, retry=None):
+        # a non-None timeout means the test expects a timeout, so raise one
+        if timeout is not None:
+            raise TimeoutError("Upload timed out")
timed out") + + # indicate that retry object made it through to the GS lib + if retry is not None: + retry.mocked_retries = 1 + data = Path(filename).read_bytes() path = self.bucket / self.name path.parent.mkdir(parents=True, exist_ok=True) @@ -131,7 +146,15 @@ def __init__(self, name, bucket_name, client=None): def blob(self, blob): return MockBlob(self.name, blob, client=self.client) - def copy_blob(self, blob, destination_bucket, new_name): + def copy_blob(self, blob, destination_bucket, new_name, timeout=None, retry=None): + # if timeout is not None, assume that the test wants a timeout and throw it + if timeout is not None: + raise TimeoutError("Copy timed out") + + # indicate that retry object made it through to the GS lib + if retry is not None: + retry.mocked_retries = 1 + data = (self.name / blob.name).read_bytes() dst = destination_bucket.name / new_name dst.parent.mkdir(exist_ok=True, parents=True) diff --git a/tests/test_gs_specific.py b/tests/test_gs_specific.py index 6af6c464..f17d0898 100644 --- a/tests/test_gs_specific.py +++ b/tests/test_gs_specific.py @@ -1,6 +1,9 @@ +from urllib.parse import urlparse, parse_qs + +from google.api_core import retry +from google.api_core import exceptions import pytest -from urllib.parse import urlparse, parse_qs from cloudpathlib import GSPath from cloudpathlib.local import LocalGSPath @@ -75,3 +78,34 @@ def _calculate_b64_wrapped_md5_hash(contents: str) -> str: p: GSPath = gs_rig.create_cloud_path("dir_0/file0_0.txt") p.write_text(contents) assert p.md5 == expected_hash + + +def test_timeout_and_retry(gs_rig): + custom_retry = retry.Retry( + timeout=0.50, + predicate=retry.if_exception_type(exceptions.ServerError), + ) + + fast_timeout_client = gs_rig.client_class(timeout=0.00001, retry=custom_retry) + + with pytest.raises(Exception) as exc_info: + p = gs_rig.create_cloud_path("dir_0/file0_0.txt", client=fast_timeout_client) + p.write_text("hello world " * 10000) + + assert "timed out" in str(exc_info.value) + + # can't force retries to happen in live cloud tests, so skip + if not gs_rig.live_server: + custom_retry = retry.Retry( + initial=1.0, + multiplier=1.0, + timeout=15.0, + predicate=retry.if_exception_type(exceptions.ServerError), + ) + + p = gs_rig.create_cloud_path( + "dir_0/file0_0.txt", client=gs_rig.client_class(retry=custom_retry) + ) + p.write_text("hello world") + + assert custom_retry.mocked_retries == 1