diff --git a/onvif/client.py b/onvif/client.py index fb072c1..a6aa4a4 100644 --- a/onvif/client.py +++ b/onvif/client.py @@ -7,7 +7,7 @@ import logging import os.path from collections.abc import Callable -from typing import Any +from typing import Any, TypeVar import zeep.helpers from zeep.cache import SqliteCache @@ -164,6 +164,27 @@ def _load_document() -> DocumentWithDeferredLoad: return document +_T = TypeVar("_T") + + +def handle_snapshot_errors(func: Callable[..., _T]) -> Callable[..., _T]: + """Decorator to handle snapshot URI errors.""" + + async def wrapper(self, uri: str, *args: Any, **kwargs: Any) -> _T: + try: + return await func(self, uri, *args, **kwargs) + except TimeoutError as error: + raise ONVIFTimeoutError( + f"Timed out fetching {obscure_user_pass_url(uri)}: {error}" + ) from error + except aiohttp.ClientError as error: + raise ONVIFError( + f"Error fetching {obscure_user_pass_url(uri)}: {error}" + ) from error + + return wrapper + + class ZeepAsyncClient(BaseZeepAsyncClient): """Overwrite create_service method to be async.""" @@ -601,7 +622,7 @@ async def get_snapshot( middlewares = (DigestAuthMiddleware(self.user, self.passwd),) response = await self._try_snapshot_uri(uri, auth=auth, middlewares=middlewares) - content = await response.read() + content = await self._try_read_snapshot_content(uri, response) # If the request fails with a 401, strip user/pass from URL and retry if ( @@ -612,7 +633,7 @@ async def get_snapshot( response = await self._try_snapshot_uri( stripped_uri, auth=auth, middlewares=middlewares ) - content = await response.read() + content = await self._try_read_snapshot_content(uri, response) if response.status == 401: raise ONVIFAuthError(f"Failed to authenticate to {uri}") @@ -622,24 +643,23 @@ async def get_snapshot( return None + @handle_snapshot_errors + async def _try_read_snapshot_content( + self, + uri: str, + response: aiohttp.ClientResponse, + ) -> bytes: + """Try to read the snapshot URI.""" + return await response.read() + + @handle_snapshot_errors async def _try_snapshot_uri( self, uri: str, auth: BasicAuth | None = None, middlewares: tuple[DigestAuthMiddleware, ...] | None = None, ) -> aiohttp.ClientResponse: - try: - return await self._snapshot_client.get( - uri, auth=auth, middlewares=middlewares - ) - except TimeoutError as error: - raise ONVIFTimeoutError( - f"Timed out fetching {obscure_user_pass_url(uri)}: {error}" - ) from error - except aiohttp.ClientError as error: - raise ONVIFError( - f"Error fetching {obscure_user_pass_url(uri)}: {error}" - ) from error + return await self._snapshot_client.get(uri, auth=auth, middlewares=middlewares) def get_definition( self, name: str, port_type: str | None = None diff --git a/onvif/zeep_aiohttp.py b/onvif/zeep_aiohttp.py index 1ec0acf..82f5bd6 100644 --- a/onvif/zeep_aiohttp.py +++ b/onvif/zeep_aiohttp.py @@ -171,6 +171,9 @@ async def _post( # Convert to httpx Response return self._aiohttp_to_httpx_response(response, content) + except RuntimeError as exc: + # Handle RuntimeError which may occur if the session is closed + raise RuntimeError(f"Failed to post to {address}: {exc}") from exc except TimeoutError as exc: raise TimeoutError(f"Request to {address} timed out") from exc @@ -248,6 +251,9 @@ async def _get( # Convert directly to requests.Response return self._aiohttp_to_requests_response(response, content) + except RuntimeError as exc: + # Handle RuntimeError which may occur if the session is closed + raise RuntimeError(f"Failed to get from {address}: {exc}") from exc except TimeoutError as exc: raise TimeoutError(f"Request to {address} timed out") from exc