diff --git a/onvif/client.py b/onvif/client.py index 5f4f7b3..176d0df 100644 --- a/onvif/client.py +++ b/onvif/client.py @@ -6,20 +6,22 @@ import datetime as dt import logging import os.path -from typing import Any from collections.abc import Callable -import httpx -from httpx import AsyncClient, BasicAuth, DigestAuth +from typing import Any + +import zeep.helpers from zeep.cache import SqliteCache from zeep.client import AsyncClient as BaseZeepAsyncClient -import zeep.helpers from zeep.proxy import AsyncServiceProxy -from zeep.transports import AsyncTransport from zeep.wsdl import Document from zeep.wsse.username import UsernameToken +import aiohttp +import httpx +from aiohttp import BasicAuth, ClientSession, DigestAuthMiddleware, TCPConnector from onvif.definition import SERVICES from onvif.exceptions import ONVIFAuthError, ONVIFError, ONVIFTimeoutError +from requests import Response from .const import KEEPALIVE_EXPIRY from .managers import NotificationManager, PullPointManager @@ -29,13 +31,14 @@ from .util import ( create_no_verify_ssl_context, normalize_url, + obscure_user_pass_url, path_isfile, - utcnow, strip_user_pass_url, - obscure_user_pass_url, + utcnow, ) -from .wrappers import retry_connection_error # noqa: F401 +from .wrappers import retry_connection_error from .wsa import WsAddressingIfMissingPlugin +from .zeep_aiohttp import AIOHTTPTransport logger = logging.getLogger("onvif") logging.basicConfig(level=logging.INFO) @@ -48,7 +51,7 @@ _CONNECT_TIMEOUT = 30 _READ_TIMEOUT = 90 _WRITE_TIMEOUT = 90 -_HTTPX_LIMITS = httpx.Limits(keepalive_expiry=KEEPALIVE_EXPIRY) +# Keepalive is set on the connector, not in ClientTimeout _NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context() @@ -59,7 +62,7 @@ def wrapped(*args, **kwargs): try: return func(*args, **kwargs) except Exception as err: - raise ONVIFError(err) + raise ONVIFError(err) from err return wrapped @@ -102,20 +105,28 @@ def original_load(self, *args: Any, **kwargs: Any) -> None: return original_load(self, *args, **kwargs) -class AsyncTransportProtocolErrorHandler(AsyncTransport): - """Retry on remote protocol error. +class AsyncTransportProtocolErrorHandler(AIOHTTPTransport): + """ + Retry on remote protocol error. http://datatracker.ietf.org/doc/html/rfc2616#section-8.1.4 allows the server # to close the connection at any time, we treat this as a normal and try again # once since """ - @retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError) - async def post(self, address, message, headers): + @retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError) + async def post( + self, address: str, message: str, headers: dict[str, str] + ) -> httpx.Response: return await super().post(address, message, headers) - @retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError) - async def get(self, address, params, headers): + @retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError) + async def get( + self, + address: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Response: return await super().get(address, params, headers) @@ -162,7 +173,8 @@ def __init__(self, *args, **kwargs): self.set_ns_prefix("wsa", "http://www.w3.org/2005/08/addressing") def create_service(self, binding_name, address): - """Create a new ServiceProxy for the given binding name and address. + """ + Create a new ServiceProxy for the given binding name and address. :param binding_name: The QName of the binding :param address: The address of the endpoint """ @@ -170,9 +182,9 @@ def create_service(self, binding_name, address): binding = self.wsdl.bindings[binding_name] except KeyError: raise ValueError( - "No binding found with the given QName. Available bindings " - "are: %s" % (", ".join(self.wsdl.bindings.keys())) - ) + f"No binding found with the given QName. Available bindings " + f"are: {', '.join(self.wsdl.bindings.keys())}" + ) from None return AsyncServiceProxy(self, binding, address=address) @@ -223,7 +235,7 @@ def __init__( write_timeout: int | None = None, ) -> None: if not path_isfile(url): - raise ONVIFError("%s doesn`t exist!" % url) + raise ONVIFError(f"{url} doesn`t exist!") self.url = url self.xaddr = xaddr @@ -236,26 +248,28 @@ def __init__( self.dt_diff = dt_diff self.binding_name = binding_name # Create soap client - timeouts = httpx.Timeout( - _DEFAULT_TIMEOUT, - connect=_CONNECT_TIMEOUT, - read=read_timeout or _READ_TIMEOUT, - write=write_timeout or _WRITE_TIMEOUT, - ) - client = AsyncClient( - verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS + connector = TCPConnector( + ssl=_NO_VERIFY_SSL_CONTEXT, + keepalive_timeout=KEEPALIVE_EXPIRY, ) - # The wsdl client should never actually be used, but it is required - # to avoid creating another ssl context since the underlying code - # will try to create a new one if it doesn't exist. - wsdl_client = httpx.Client( - verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS + session = ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout( + total=_DEFAULT_TIMEOUT, + connect=_CONNECT_TIMEOUT, + sock_read=read_timeout or _READ_TIMEOUT, + ), ) self.transport = ( - AsyncTransportProtocolErrorHandler(client=client, wsdl_client=wsdl_client) + AsyncTransportProtocolErrorHandler( + session=session, + verify_ssl=False, + ) if no_cache - else AsyncTransportProtocolErrorHandler( - client=client, wsdl_client=wsdl_client, cache=SqliteCache() + else AIOHTTPTransport( + session=session, + verify_ssl=False, + cache=SqliteCache(), ) ) self.document: Document | None = None @@ -399,7 +413,8 @@ def __init__( self.to_dict = ONVIFService.to_dict self._snapshot_uris = {} - self._snapshot_client = AsyncClient(verify=_NO_VERIFY_SSL_CONTEXT) + self._snapshot_connector = TCPConnector(ssl=_NO_VERIFY_SSL_CONTEXT) + self._snapshot_client = ClientSession(connector=self._snapshot_connector) async def get_capabilities(self) -> dict[str, Any]: """Get device capabilities.""" @@ -531,7 +546,8 @@ async def create_notification_manager( async def close(self) -> None: """Close all transports.""" - await self._snapshot_client.aclose() + await self._snapshot_client.close() + await self._snapshot_connector.close() for service in self.services.values(): await service.close() @@ -572,42 +588,53 @@ async def get_snapshot( if uri is None: return None - auth = None + auth: BasicAuth | None = None + middlewares: tuple[DigestAuthMiddleware, ...] | None = None + if self.user and self.passwd: if basic_auth: auth = BasicAuth(self.user, self.passwd) else: - auth = DigestAuth(self.user, self.passwd) + # Use DigestAuthMiddleware for digest auth + middlewares = (DigestAuthMiddleware(self.user, self.passwd),) - response = await self._try_snapshot_uri(uri, auth) + response = await self._try_snapshot_uri(uri, auth=auth, middlewares=middlewares) + content = await response.read() - # If the request fails with a 401, make sure to strip any - # sample user/pass from the URL and try again + # If the request fails with a 401, strip user/pass from URL and retry if ( - response.status_code == 401 + response.status == 401 and (stripped_uri := strip_user_pass_url(uri)) and stripped_uri != uri ): - response = await self._try_snapshot_uri(stripped_uri, auth) + response = await self._try_snapshot_uri( + stripped_uri, auth=auth, middlewares=middlewares + ) + content = await response.read() - if response.status_code == 401: + if response.status == 401: raise ONVIFAuthError(f"Failed to authenticate to {uri}") - if response.status_code < 300: - return response.content + if response.status < 300: + return content return None async def _try_snapshot_uri( - self, uri: str, auth: BasicAuth | DigestAuth | None - ) -> httpx.Response: + self, + uri: str, + auth: BasicAuth | None = None, + middlewares: tuple[DigestAuthMiddleware, ...] | None = None, + ) -> aiohttp.ClientResponse: try: - return await self._snapshot_client.get(uri, auth=auth) - except httpx.TimeoutException as error: + return await self._snapshot_client.get( + uri, auth=auth, middlewares=middlewares + ) + except TimeoutError as error: raise ONVIFTimeoutError( f"Timed out fetching {obscure_user_pass_url(uri)}: {error}" ) from error - except httpx.RequestError as error: + except aiohttp.ClientError as error: raise ONVIFError( f"Error fetching {obscure_user_pass_url(uri)}: {error}" ) from error @@ -618,7 +645,7 @@ def get_definition( """Returns xaddr and wsdl of specified service""" # Check if the service is supported if name not in SERVICES: - raise ONVIFError("Unknown service %s" % name) + raise ONVIFError(f"Unknown service {name}") wsdl_file = SERVICES[name]["wsdl"] namespace = SERVICES[name]["ns"] @@ -629,14 +656,14 @@ def get_definition( wsdlpath = os.path.join(self.wsdl_dir, wsdl_file) if not path_isfile(wsdlpath): - raise ONVIFError("No such file: %s" % wsdlpath) + raise ONVIFError(f"No such file: {wsdlpath}") # XAddr for devicemgmt is fixed: if name == "devicemgmt": xaddr = "{}:{}/onvif/device_service".format( self.host if (self.host.startswith("http://") or self.host.startswith("https://")) - else "http://%s" % self.host, + else f"http://{self.host}", self.port, ) return xaddr, wsdlpath, binding_name diff --git a/onvif/managers.py b/onvif/managers.py index 3a5c419..8f14030 100644 --- a/onvif/managers.py +++ b/onvif/managers.py @@ -2,19 +2,18 @@ from __future__ import annotations -from abc import abstractmethod import asyncio import datetime as dt import logging -from typing import TYPE_CHECKING, Any +from abc import abstractmethod from collections.abc import Callable +from typing import TYPE_CHECKING, Any -import httpx -from httpx import TransportError from zeep.exceptions import Fault, XMLParseError, XMLSyntaxError from zeep.loader import parse_xml from zeep.wsdl.bindings.soap import SoapOperation +import aiohttp from onvif.exceptions import ONVIFError from .settings import DEFAULT_SETTINGS @@ -27,8 +26,8 @@ _RENEWAL_PERCENTAGE = 0.8 -SUBSCRIPTION_ERRORS = (Fault, asyncio.TimeoutError, TransportError) -RENEW_ERRORS = (ONVIFError, httpx.RequestError, XMLParseError, *SUBSCRIPTION_ERRORS) +SUBSCRIPTION_ERRORS = (Fault, asyncio.TimeoutError, aiohttp.ClientError) +RENEW_ERRORS = (ONVIFError, aiohttp.ClientError, XMLParseError, *SUBSCRIPTION_ERRORS) SUBSCRIPTION_RESTART_INTERVAL_ON_ERROR = dt.timedelta(seconds=40) # If the camera returns a subscription with a termination time that is less than @@ -87,7 +86,8 @@ async def stop(self) -> None: await self._subscription.Unsubscribe() async def shutdown(self) -> None: - """Shutdown the manager. + """ + Shutdown the manager. This method is irreversible. """ @@ -105,7 +105,7 @@ async def set_synchronization_point(self) -> float: """Set the synchronization point.""" try: await self._service.SetSynchronizationPoint() - except (Fault, asyncio.TimeoutError, TransportError, TypeError): + except (TimeoutError, Fault, aiohttp.ClientError, TypeError): logger.debug("%s: SetSynchronizationPoint failed", self._service.url) def _cancel_renewals(self) -> None: @@ -214,7 +214,8 @@ def __init__( super().__init__(device, interval, subscription_lost_callback) async def _start(self) -> float: - """Start the notification processor. + """ + Start the notification processor. Returns the next renewal call at time. """ @@ -290,7 +291,8 @@ class PullPointManager(BaseManager): """Manager for PullPoint.""" async def _start(self) -> float: - """Start the PullPoint manager. + """ + Start the PullPoint manager. Returns the next renewal call at time. """ diff --git a/onvif/util.py b/onvif/util.py index c6d0b17..41f0b9c 100644 --- a/onvif/util.py +++ b/onvif/util.py @@ -4,15 +4,17 @@ import contextlib import datetime as dt -from functools import lru_cache, partial import os import ssl +from functools import lru_cache, partial from typing import Any from urllib.parse import ParseResultBytes, urlparse, urlunparse -from yarl import URL -from multidict import CIMultiDict + from zeep.exceptions import Fault +from multidict import CIMultiDict +from yarl import URL + utcnow: partial[dt.datetime] = partial(dt.datetime.now, dt.timezone.utc) # This does blocking I/O (stat) so we cache the result @@ -23,7 +25,8 @@ def normalize_url(url: bytes | str | None) -> str | None: - """Normalize URL. + """ + Normalize URL. Some cameras respond with http://192.168.1.106:8106:8106/onvif/Subscription?Idx=43 https://github.com/home-assistant/core/issues/92603#issuecomment-1537213126 @@ -73,7 +76,8 @@ def stringify_onvif_error(error: Exception) -> str: def is_auth_error(error: Exception) -> bool: - """Return True if error is an authentication error. + """ + Return True if error is an authentication error. Most of the tested cameras do not return a proper error code when authentication fails, so we need to check the error message as well. @@ -90,7 +94,8 @@ def is_auth_error(error: Exception) -> bool: def create_no_verify_ssl_context() -> ssl.SSLContext: - """Return an SSL context that does not verify the server certificate. + """ + Return an SSL context that does not verify the server certificate. This is a copy of aiohttp's create_default_context() function, with the ssl verify turned off and old SSL versions enabled. @@ -113,6 +118,12 @@ def create_no_verify_ssl_context() -> ssl.SSLContext: def strip_user_pass_url(url: str) -> str: """Strip password from URL.""" parsed_url = URL(url) + + # First strip userinfo (user:pass@) from URL + if parsed_url.user or parsed_url.password: + parsed_url = parsed_url.with_user(None) + + # Then strip credentials from query parameters query = parsed_url.query new_query: CIMultiDict | None = None for key in _CREDENTIAL_KEYS: @@ -122,12 +133,23 @@ def strip_user_pass_url(url: str) -> str: new_query.popall(key) if new_query is not None: return str(parsed_url.with_query(new_query)) - return url + return str(parsed_url) def obscure_user_pass_url(url: str) -> str: """Obscure user and password from URL.""" parsed_url = URL(url) + + # First obscure userinfo if present + if parsed_url.user: + # Keep the user but obscure the password + if parsed_url.password: + parsed_url = parsed_url.with_password("********") + else: + # If only user is present, obscure it + parsed_url = parsed_url.with_user("********") + + # Then obscure credentials in query parameters query = parsed_url.query new_query: CIMultiDict | None = None for key in _CREDENTIAL_KEYS: @@ -138,4 +160,4 @@ def obscure_user_pass_url(url: str) -> str: new_query[key] = "********" if new_query is not None: return str(parsed_url.with_query(new_query)) - return url + return str(parsed_url) diff --git a/onvif/wrappers.py b/onvif/wrappers.py index 13578c5..b9cb8ac 100644 --- a/onvif/wrappers.py +++ b/onvif/wrappers.py @@ -3,12 +3,11 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable import logging +from collections.abc import Awaitable, Callable from typing import ParamSpec, TypeVar -from collections.abc import Callable -import httpx +import aiohttp from .const import BACKOFF_TIME, DEFAULT_ATTEMPTS @@ -19,14 +18,15 @@ def retry_connection_error( attempts: int = DEFAULT_ATTEMPTS, - exception: httpx.HTTPError = httpx.RequestError, + exception: type[Exception] = aiohttp.ClientError, ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: """Define a wrapper to retry on connection error.""" def _decorator_retry_connection_error( func: Callable[P, Awaitable[T]], ) -> Callable[P, Awaitable[T]]: - """Define a wrapper to retry on connection error. + """ + Define a wrapper to retry on connection error. The remote server is allowed to disconnect us any time so we need to retry the operation. @@ -40,11 +40,11 @@ async def _async_wrap_connection_error_retry( # type: ignore[return] return await func(*args, **kwargs) except exception as ex: # - # We should only need to retry on RemoteProtocolError but some cameras + # We should only need to retry on ServerDisconnectedError but some cameras # are flakey and sometimes do not respond to the Renew request so we - # retry on RequestError as well. + # retry on ClientError as well. # - # For RemoteProtocolError: + # For ServerDisconnectedError: # http://datatracker.ietf.org/doc/html/rfc2616#section-8.1.4 allows the server # to close the connection at any time, we treat this as a normal and try again # once since we do not want to declare the camera as not supporting PullPoint diff --git a/onvif/zeep_aiohttp.py b/onvif/zeep_aiohttp.py new file mode 100644 index 0000000..23d3f2f --- /dev/null +++ b/onvif/zeep_aiohttp.py @@ -0,0 +1,292 @@ +"""AIOHTTP transport for zeep.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from zeep.cache import SqliteCache +from zeep.transports import Transport +from zeep.utils import get_version +from zeep.wsdl.utils import etree_to_string + +import aiohttp +import httpx +from aiohttp import ClientResponse, ClientSession +from requests import Response + +if TYPE_CHECKING: + from lxml.etree import _Element + +_LOGGER = logging.getLogger(__name__) + + +class AIOHTTPTransport(Transport): + """Async transport using aiohttp.""" + + def __init__( + self, + session: ClientSession, + verify_ssl: bool = True, + proxy: str | None = None, + cache: SqliteCache | None = None, + ) -> None: + """ + Initialize the transport. + + Args: + session: The aiohttp ClientSession to use (required). The session's + timeout configuration will be used for all requests. + verify_ssl: Whether to verify SSL certificates + proxy: Proxy URL to use + + """ + super().__init__( + cache=cache, + timeout=session.timeout.total, + operation_timeout=session.timeout.sock_read, + ) + + # Override parent's session with aiohttp session + self.session = session + self.verify_ssl = verify_ssl + self.proxy = proxy + self._close_session = False # Never close a provided session + # Extract timeout from session + self._client_timeout = session.timeout + + async def __aenter__(self) -> AIOHTTPTransport: + """Enter async context.""" + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit async context.""" + + async def aclose(self) -> None: + """Close the transport session.""" + + def _aiohttp_to_httpx_response( + self, aiohttp_response: ClientResponse, content: bytes + ) -> httpx.Response: + """Convert aiohttp ClientResponse to httpx Response.""" + # Create httpx Response with the content + httpx_response = httpx.Response( + status_code=aiohttp_response.status, + headers=httpx.Headers(aiohttp_response.headers), + content=content, + request=httpx.Request( + method=aiohttp_response.method, + url=str(aiohttp_response.url), + ), + ) + + # Add encoding if available + if aiohttp_response.charset: + httpx_response._encoding = aiohttp_response.charset + + # Store cookies if any + if aiohttp_response.cookies: + for name, cookie in aiohttp_response.cookies.items(): + # httpx.Cookies.set only accepts name, value, domain, and path + httpx_response.cookies.set( + name, + cookie.value, + domain=cookie.get("domain") or "", + path=cookie.get("path") or "/", + ) + + return httpx_response + + def _aiohttp_to_requests_response( + self, aiohttp_response: ClientResponse, content: bytes + ) -> Response: + """Convert aiohttp ClientResponse directly to requests Response.""" + new = Response() + new._content = content + new.status_code = aiohttp_response.status + new.headers = dict(aiohttp_response.headers) + # Convert aiohttp cookies to requests format + if aiohttp_response.cookies: + for name, cookie in aiohttp_response.cookies.items(): + new.cookies.set( + name, + cookie.value, + domain=cookie.get("domain"), + path=cookie.get("path"), + ) + new.encoding = aiohttp_response.charset + return new + + async def post( + self, address: str, message: str, headers: dict[str, str] + ) -> httpx.Response: + """ + Perform async POST request. + + Args: + address: The URL to send the request to + message: The message to send + headers: HTTP headers to include + + Returns: + The httpx response object + + """ + return await self._post(address, message, headers) + + async def _post( + self, address: str, message: str, headers: dict[str, str] + ) -> httpx.Response: + """Internal POST implementation.""" + _LOGGER.debug("HTTP Post to %s:\n%s", address, message) + + # Set default headers + headers = headers or {} + headers.setdefault("User-Agent", f"Zeep/{get_version()}") + headers.setdefault("Content-Type", 'text/xml; charset="utf-8"') + + # Handle both str and bytes + if isinstance(message, str): + data = message.encode("utf-8") + else: + data = message + + try: + response = await self.session.post( + address, + data=data, + headers=headers, + proxy=self.proxy, + timeout=self._client_timeout, + ) + response.raise_for_status() + + # Read the content to log it + content = await response.read() + _LOGGER.debug( + "HTTP Response from %s (status: %d):\n%s", + address, + response.status, + content.decode("utf-8", errors="replace"), + ) + + # Convert to httpx Response + return self._aiohttp_to_httpx_response(response, content) + + except TimeoutError as exc: + raise TimeoutError(f"Request to {address} timed out") from exc + except aiohttp.ClientError as exc: + raise ConnectionError(f"Error connecting to {address}: {exc}") from exc + + async def post_xml( + self, address: str, envelope: _Element, headers: dict[str, str] + ) -> Response: + """ + Post XML envelope and return parsed response. + + Args: + address: The URL to send the request to + envelope: The XML envelope to send + headers: HTTP headers to include + + Returns: + A Response object compatible with zeep + + """ + message = etree_to_string(envelope) + response = await self.post(address, message, headers) + return self._httpx_to_requests_response(response) + + async def get( + self, + address: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Response: + """ + Perform async GET request. + + Args: + address: The URL to send the request to + params: Query parameters + headers: HTTP headers to include + + Returns: + A Response object compatible with zeep + + """ + return await self._get(address, params, headers) + + async def _get( + self, + address: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Response: + """Internal GET implementation.""" + _LOGGER.debug("HTTP Get from %s", address) + + # Set default headers + headers = headers or {} + headers.setdefault("User-Agent", f"Zeep/{get_version()}") + + try: + response = await self.session.get( + address, + params=params, + headers=headers, + proxy=self.proxy, + timeout=self._client_timeout, + ) + response.raise_for_status() + + # Read content + content = await response.read() + + _LOGGER.debug( + "HTTP Response from %s (status: %d)", + address, + response.status, + ) + + # Convert directly to requests.Response + return self._aiohttp_to_requests_response(response, content) + + except TimeoutError as exc: + raise TimeoutError(f"Request to {address} timed out") from exc + except aiohttp.ClientError as exc: + raise ConnectionError(f"Error connecting to {address}: {exc}") from exc + + def _httpx_to_requests_response(self, response: httpx.Response) -> Response: + """Convert an httpx.Response object to a requests.Response object""" + body = response.read() + + new = Response() + new._content = body + new.status_code = response.status_code + new.headers = response.headers + new.cookies = response.cookies + new.encoding = response.encoding + return new + + def load(self, url: str) -> bytes: + """ + Load content from URL synchronously. + + This method runs the async get method in a new event loop. + + Args: + url: The URL to load + + Returns: + The content as bytes + + """ + # Create a new event loop for sync operation + loop = asyncio.new_event_loop() + try: + response = loop.run_until_complete(self.get(url)) + return response.content + finally: + loop.close() diff --git a/pyproject.toml b/pyproject.toml index 0d11e2d..c3533f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,5 +5,8 @@ log_level="NOTSET" plugins = ["covdefaults"] +[tool.ruff] +target-version = "py310" + [build-system] requires = ['setuptools>=65.4.1', 'wheel'] diff --git a/requirements.txt b/requirements.txt index f21d7f2..74923bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # Package +aiohttp==3.12.9 ciso8601==2.3.2 httpx==0.28.1 zeep[async]==4.3.1 diff --git a/requirements_dev.txt b/requirements_dev.txt index a54395a..a23b832 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,14 +1,12 @@ # Package -r requirements.txt -# Examples -aiohttp==3.12.10 - # Dev pytest==8.3.5 pytest-cov==6.1.1 pytest-asyncio==0.26.0 covdefaults==2.3.0 +aioresponses==0.7.6 # pre-commit pre-commit==4.2.0 diff --git a/setup.py b/setup.py index b3563d0..f8113b4 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ version = open(version_path).read().strip() requires = [ + "aiohttp>=3.12.9", "httpx>=0.19.0,<1.0.0", "zeep[async]>=4.2.1,<5.0.0", "ciso8601>=2.1.3", diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py new file mode 100644 index 0000000..42dcdd1 --- /dev/null +++ b/tests/test_snapshot.py @@ -0,0 +1,395 @@ +"""Tests for snapshot functionality using aiohttp.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, Mock, patch + +import pytest_asyncio + +import aiohttp +import pytest +from aioresponses import aioresponses +from onvif import ONVIFCamera +from onvif.exceptions import ONVIFAuthError, ONVIFError, ONVIFTimeoutError + + +@pytest.fixture +def mock_aioresponse(): + """Return aioresponses fixture.""" + # Note: aioresponses will mock all ClientSession instances by default + with aioresponses(passthrough=["http://127.0.0.1:8123"]) as m: + yield m + + +@asynccontextmanager +async def create_test_camera( + host: str = "192.168.1.100", + port: int = 80, + user: str | None = "admin", + passwd: str | None = "password", # noqa: S107 +) -> AsyncGenerator[ONVIFCamera]: + """Create a test camera instance with context manager.""" + cam = ONVIFCamera(host, port, user, passwd) + try: + yield cam + finally: + await cam.close() + + +@pytest_asyncio.fixture +async def camera() -> AsyncGenerator[ONVIFCamera]: + """Create a test camera instance.""" + async with create_test_camera() as cam: + # Mock the device management service to avoid actual WSDL loading + with ( + patch.object(cam, "create_devicemgmt_service", new_callable=AsyncMock), + patch.object( + cam, "create_media_service", new_callable=AsyncMock + ) as mock_media, + ): + # Mock the media service to return snapshot URI + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(Uri="http://192.168.1.100/snapshot") + ) + mock_media.return_value = mock_service + yield cam + + +@pytest.mark.asyncio +async def test_get_snapshot_success_with_digest_auth( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test successful snapshot retrieval with digest authentication.""" + snapshot_data = b"fake_image_data" + + # Mock successful response + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=snapshot_data + ) + + # Get snapshot with digest auth (default) + result = await camera.get_snapshot("Profile1", basic_auth=False) + + assert result == snapshot_data + + # Check that the request was made + assert len(mock_aioresponse.requests) == 1 + request_key = next(iter(mock_aioresponse.requests.keys())) + assert str(request_key[1]).startswith("http://192.168.1.100/snapshot") + + +@pytest.mark.asyncio +async def test_get_snapshot_success_with_basic_auth( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test successful snapshot retrieval with basic authentication.""" + snapshot_data = b"fake_image_data" + + # Mock successful response + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=snapshot_data + ) + + # Get snapshot with basic auth + result = await camera.get_snapshot("Profile1", basic_auth=True) + + assert result == snapshot_data + + # Check that the request was made + assert len(mock_aioresponse.requests) == 1 + request_key = next(iter(mock_aioresponse.requests.keys())) + assert str(request_key[1]).startswith("http://192.168.1.100/snapshot") + + +@pytest.mark.asyncio +async def test_get_snapshot_auth_failure( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test snapshot retrieval with authentication failure.""" + # Mock 401 response + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=401, body=b"Unauthorized" + ) + + # Should raise ONVIFAuthError + with pytest.raises(ONVIFAuthError) as exc_info: + await camera.get_snapshot("Profile1") + + assert "Failed to authenticate" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_get_snapshot_with_user_pass_in_url( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test snapshot retrieval when URI contains credentials.""" + # Mock the media service to return URI with credentials + with patch.object( + camera, "create_media_service", new_callable=AsyncMock + ) as mock_media: + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(Uri="http://admin:password@192.168.1.100/snapshot") + ) + mock_media.return_value = mock_service + + # First request fails with 401 + mock_aioresponse.get( + "http://admin:password@192.168.1.100/snapshot", + status=401, + body=b"Unauthorized", + ) + # Second request succeeds (stripped URL) + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"image_data" + ) + + result = await camera.get_snapshot("Profile1") + + assert result == b"image_data" + # Should have made 2 requests - first with credentials in URL, second without + request_keys = list(mock_aioresponse.requests.keys()) + assert len(request_keys) == 2 + assert str(request_keys[0][1]) == "http://admin:password@192.168.1.100/snapshot" + assert str(request_keys[1][1]) == "http://192.168.1.100/snapshot" + + +@pytest.mark.asyncio +async def test_get_snapshot_timeout( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test snapshot retrieval timeout.""" + # Mock timeout by raising TimeoutError + mock_aioresponse.get( + "http://192.168.1.100/snapshot", exception=TimeoutError("Connection timeout") + ) + + with pytest.raises(ONVIFTimeoutError) as exc_info: + await camera.get_snapshot("Profile1") + + assert "Timed out fetching" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_get_snapshot_client_error( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test snapshot retrieval with client error.""" + # Mock client error + mock_aioresponse.get( + "http://192.168.1.100/snapshot", + exception=aiohttp.ClientError("Connection failed"), + ) + + with pytest.raises(ONVIFError) as exc_info: + await camera.get_snapshot("Profile1") + + assert "Error fetching" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_get_snapshot_no_uri_available(camera: ONVIFCamera) -> None: + """Test snapshot when no URI is available.""" + # Mock the media service to raise fault + with patch.object( + camera, "create_media_service", new_callable=AsyncMock + ) as mock_media: + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + + import zeep.exceptions + + mock_service.GetSnapshotUri = AsyncMock( + side_effect=zeep.exceptions.Fault("Snapshot not supported") + ) + mock_media.return_value = mock_service + + result = await camera.get_snapshot("Profile1") + + assert result is None + + +@pytest.mark.asyncio +async def test_get_snapshot_invalid_uri_response(camera: ONVIFCamera) -> None: + """Test snapshot when device returns invalid URI.""" + # Mock the media service to return invalid response + with patch.object( + camera, "create_media_service", new_callable=AsyncMock + ) as mock_media: + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + # Return response without Uri attribute + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(spec=[]) # No Uri attribute + ) + mock_media.return_value = mock_service + + result = await camera.get_snapshot("Profile1") + + assert result is None + + +@pytest.mark.asyncio +async def test_get_snapshot_404_error( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test snapshot retrieval with 404 error.""" + # Mock 404 response + mock_aioresponse.get("http://192.168.1.100/snapshot", status=404, body=b"Not Found") + + result = await camera.get_snapshot("Profile1") + + # Should return None for non-auth errors + assert result is None + + +@pytest.mark.asyncio +async def test_get_snapshot_uri_caching(camera: ONVIFCamera) -> None: + """Test that snapshot URI is cached after first retrieval.""" + # First call should fetch URI from service + uri = await camera.get_snapshot_uri("Profile1") + assert uri == "http://192.168.1.100/snapshot" + + # Mock the media service to ensure it's not called again + with patch.object( + camera, "create_media_service", new_callable=AsyncMock + ) as mock_media: + mock_media.side_effect = Exception("Should not be called") + + # Second call should use cached URI + uri2 = await camera.get_snapshot_uri("Profile1") + assert uri2 == "http://192.168.1.100/snapshot" + + # Mock media service should not have been called + mock_media.assert_not_called() + + +@pytest.mark.asyncio +async def test_snapshot_client_session_reuse( + camera: ONVIFCamera, mock_aioresponse: aioresponses +) -> None: + """Test that snapshot client session is reused across requests.""" + snapshot_data = b"fake_image_data" + + # Get reference to the snapshot client + snapshot_client = camera._snapshot_client + + # Mock multiple requests + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=snapshot_data + ) + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=snapshot_data + ) + + # Make multiple snapshot requests + result1 = await camera.get_snapshot("Profile1") + result2 = await camera.get_snapshot("Profile1") + + assert result1 == snapshot_data + assert result2 == snapshot_data + + # Verify same client session was used + assert camera._snapshot_client is snapshot_client + + +@pytest.mark.asyncio +async def test_get_snapshot_no_credentials(mock_aioresponse: aioresponses) -> None: + """Test snapshot retrieval when camera has no credentials.""" + async with create_test_camera(user=None, passwd=None) as cam: + with ( + patch.object(cam, "create_devicemgmt_service", new_callable=AsyncMock), + patch.object( + cam, "create_media_service", new_callable=AsyncMock + ) as mock_media, + ): + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(Uri="http://192.168.1.100/snapshot") + ) + mock_media.return_value = mock_service + + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"image_data" + ) + + result = await cam.get_snapshot("Profile1") + assert result == b"image_data" + + +@pytest.mark.asyncio +async def test_get_snapshot_with_digest_auth_multiple_requests( + mock_aioresponse: aioresponses, +) -> None: + """Test that digest auth works correctly across multiple requests.""" + async with create_test_camera() as cam: + with ( + patch.object(cam, "create_devicemgmt_service", new_callable=AsyncMock), + patch.object( + cam, "create_media_service", new_callable=AsyncMock + ) as mock_media, + ): + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(Uri="http://192.168.1.100/snapshot") + ) + mock_media.return_value = mock_service + + # Mock multiple successful responses + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"image1" + ) + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"image2" + ) + + # Get snapshots with digest auth + result1 = await cam.get_snapshot("Profile1", basic_auth=False) + result2 = await cam.get_snapshot("Profile1", basic_auth=False) + + assert result1 == b"image1" + assert result2 == b"image2" + # Check that 2 requests were made (they're grouped by URL in aioresponses) + request_list = next(iter(mock_aioresponse.requests.values())) + assert len(request_list) == 2 + + +@pytest.mark.asyncio +async def test_get_snapshot_mixed_auth_methods(mock_aioresponse: aioresponses) -> None: + """Test switching between basic and digest auth.""" + async with create_test_camera() as cam: + with ( + patch.object(cam, "create_devicemgmt_service", new_callable=AsyncMock), + patch.object( + cam, "create_media_service", new_callable=AsyncMock + ) as mock_media, + ): + mock_service = Mock() + mock_service.create_type = Mock(return_value=Mock()) + mock_service.GetSnapshotUri = AsyncMock( + return_value=Mock(Uri="http://192.168.1.100/snapshot") + ) + mock_media.return_value = mock_service + + # Mock responses + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"basic_auth_image" + ) + mock_aioresponse.get( + "http://192.168.1.100/snapshot", status=200, body=b"digest_auth_image" + ) + + # Test with basic auth + result1 = await cam.get_snapshot("Profile1", basic_auth=True) + assert result1 == b"basic_auth_image" + + # Test with digest auth + result2 = await cam.get_snapshot("Profile1", basic_auth=False) + assert result2 == b"digest_auth_image" diff --git a/tests/test_util.py b/tests/test_util.py index c3d36aa..b951a64 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -46,6 +46,14 @@ async def test_normalize_url_with_missing_url(): def test_strip_user_pass_url(): assert strip_user_pass_url("http://1.2.3.4/?user=foo&pass=bar") == "http://1.2.3.4/" assert strip_user_pass_url("http://1.2.3.4/") == "http://1.2.3.4/" + # Test with userinfo in URL + assert strip_user_pass_url("http://user:pass@1.2.3.4/") == "http://1.2.3.4/" + assert strip_user_pass_url("http://user@1.2.3.4/") == "http://1.2.3.4/" + # Test with both userinfo and query params + assert ( + strip_user_pass_url("http://user:pass@1.2.3.4/?username=foo&password=bar") + == "http://1.2.3.4/" + ) def test_obscure_user_pass_url(): @@ -54,3 +62,18 @@ def test_obscure_user_pass_url(): == "http://1.2.3.4/?user=********&pass=********" ) assert obscure_user_pass_url("http://1.2.3.4/") == "http://1.2.3.4/" + # Test with userinfo in URL + assert ( + obscure_user_pass_url("http://user:pass@1.2.3.4/") + == "http://user:********@1.2.3.4/" + ) + assert obscure_user_pass_url("http://user@1.2.3.4/") == "http://********@1.2.3.4/" + # Test with both userinfo and query params + assert ( + obscure_user_pass_url("http://user:pass@1.2.3.4/?username=foo&password=bar") + == "http://user:********@1.2.3.4/?username=********&password=********" + ) + assert ( + obscure_user_pass_url("http://user@1.2.3.4/?password=bar") + == "http://********@1.2.3.4/?password=********" + ) diff --git a/tests/test_zeep_transport.py b/tests/test_zeep_transport.py new file mode 100644 index 0000000..f01c97b --- /dev/null +++ b/tests/test_zeep_transport.py @@ -0,0 +1,871 @@ +"""Tests for AIOHTTPTransport to ensure compatibility with zeep's AsyncTransport.""" + +from unittest.mock import AsyncMock, Mock, patch + +import aiohttp +import httpx +import pytest +from lxml import etree +from onvif.zeep_aiohttp import AIOHTTPTransport +from requests import Response as RequestsResponse + + +def create_mock_session(timeout=None): + """Create a mock aiohttp session with optional timeout.""" + mock_session = Mock(spec=aiohttp.ClientSession) + if timeout: + mock_session.timeout = timeout + else: + # Create a default timeout object + default_timeout = Mock(total=300, sock_read=None) + mock_session.timeout = default_timeout + return mock_session + + +@pytest.mark.asyncio +async def test_post_returns_httpx_response(): + """Test that post() returns an httpx.Response object.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock aiohttp session and response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml"} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com/service" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + + mock_content = b"test" + mock_aiohttp_response.read = AsyncMock(return_value=mock_content) + + mock_session.post = AsyncMock(return_value=mock_aiohttp_response) + + # Call post + result = await transport.post( + "http://example.com/service", + "test", + {"SOAPAction": "test"}, + ) + + # Verify result is httpx.Response + assert isinstance(result, httpx.Response) + assert result.status_code == 200 + assert result.read() == mock_content + + +@pytest.mark.asyncio +async def test_post_xml_returns_requests_response(): + """Test that post_xml() returns a requests.Response object.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock aiohttp session and response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml"} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com/service" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + + mock_content = b"test" + mock_aiohttp_response.read = AsyncMock(return_value=mock_content) + + mock_session.post = AsyncMock(return_value=mock_aiohttp_response) + + # Create XML envelope + envelope = etree.Element("Envelope") + body = etree.SubElement(envelope, "Body") + etree.SubElement(body, "Request").text = "test" + + # Call post_xml + result = await transport.post_xml( + "http://example.com/service", envelope, {"SOAPAction": "test"} + ) + + # Verify result is requests.Response + assert isinstance(result, RequestsResponse) + assert result.status_code == 200 + assert result.content == mock_content + + +@pytest.mark.asyncio +async def test_get_returns_requests_response(): + """Test that get() returns a requests.Response object.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock aiohttp session and response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml"} + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + + mock_content = b"test" + mock_aiohttp_response.read = AsyncMock(return_value=mock_content) + + mock_session.get = AsyncMock(return_value=mock_aiohttp_response) + + # Call get + result = await transport.get( + "http://example.com/wsdl", + params={"version": "1.0"}, + headers={"Accept": "text/xml"}, + ) + + # Verify result is requests.Response + assert isinstance(result, RequestsResponse) + assert result.status_code == 200 + assert result.content == mock_content + + +@pytest.mark.asyncio +async def test_context_manager(): + """Test async context manager doesn't close provided session.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Session should already be set + assert transport.session == mock_session + + async with transport: + assert transport.session == mock_session + + # Session should still be there after context (not closed) + assert transport.session == mock_session + + +@pytest.mark.asyncio +async def test_aclose(): + """Test aclose() method doesn't close provided session.""" + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.timeout = Mock(total=300, sock_read=None) + mock_session.close = AsyncMock() + transport = AIOHTTPTransport(session=mock_session) + + # Call aclose + await transport.aclose() + + # Verify session.close() was NOT called (we don't close provided sessions) + mock_session.close.assert_not_called() + + +def test_load_sync(): + """Test load() method works synchronously.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock the async get method + mock_response = Mock(spec=RequestsResponse) + mock_response.content = b"test" + + with patch.object(transport, "get", new=AsyncMock(return_value=mock_response)): + result = transport.load("http://example.com/wsdl") + + assert result == b"test" + + +@pytest.mark.asyncio +async def test_timeout_handling(): + """Test timeout errors are properly handled.""" + mock_session = create_mock_session(timeout=aiohttp.ClientTimeout(total=0.1)) + transport = AIOHTTPTransport(session=mock_session) + + # Mock session that times out + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(side_effect=TimeoutError()) + + transport.session = mock_session + + with pytest.raises(TimeoutError, match="Request to .* timed out"): + await transport.post( + "http://example.com/service", "test", {} + ) + + +@pytest.mark.asyncio +async def test_connection_error_handling(): + """Test connection errors are properly handled.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock session that fails + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(side_effect=aiohttp.ClientError("Connection failed")) + + transport.session = mock_session + + with pytest.raises(ConnectionError, match="Error connecting to"): + await transport.get("http://example.com/wsdl") + + +@pytest.mark.asyncio +async def test_constructor_parameters(): + """Test constructor accepts expected parameters.""" + # Test with minimal parameters + mock_session1 = create_mock_session() + transport1 = AIOHTTPTransport(session=mock_session1) + # Session's timeout should be used + assert transport1.session.timeout == mock_session1.timeout + assert transport1.verify_ssl is True + assert transport1.proxy is None + + # Test with all parameters + timeout = aiohttp.ClientTimeout(total=100, connect=20) + mock_session2 = Mock(spec=aiohttp.ClientSession) + mock_session2.timeout = timeout + transport2 = AIOHTTPTransport( + session=mock_session2, + verify_ssl=False, + proxy="http://proxy:8080", + ) + assert transport2.session == mock_session2 + assert transport2.verify_ssl is False + assert transport2.proxy == "http://proxy:8080" + + +@pytest.mark.asyncio +async def test_post_with_bytes_message(): + """Test post() handles bytes message correctly.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml"} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(return_value=mock_aiohttp_response) + transport.session = mock_session + + # Test with bytes message + result = await transport.post( + "http://example.com", b"", {"SOAPAction": "test"} + ) + assert isinstance(result, httpx.Response) + assert result.status_code == 200 + + +@pytest.mark.asyncio +async def test_get_with_none_params(): + """Test get() works with None params and headers.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml"} + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_aiohttp_response) + transport.session = mock_session + + # Test without params/headers (should work with None) + result = await transport.get("http://example.com/wsdl", None, None) + assert isinstance(result, RequestsResponse) + assert result.status_code == 200 + + +@pytest.mark.asyncio +async def test_user_agent_header(): + """Test that User-Agent header is set correctly like AsyncTransport.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + post_mock = AsyncMock(return_value=mock_aiohttp_response) + mock_session.post = post_mock + transport.session = mock_session + + await transport.post("http://example.com", "test", {}) + + # Check User-Agent was set + call_args = post_mock.call_args + headers = call_args[1]["headers"] + assert "User-Agent" in headers + assert headers["User-Agent"].startswith("Zeep/") + + +@pytest.mark.asyncio +async def test_custom_timeout_used(): + """Test custom timeout is used when set.""" + custom_timeout = aiohttp.ClientTimeout(total=10, connect=5) + mock_session = create_mock_session(timeout=custom_timeout) + transport = AIOHTTPTransport(session=mock_session) + + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + post_mock = AsyncMock(return_value=mock_aiohttp_response) + mock_session.post = post_mock + transport.session = mock_session + + await transport.post("http://example.com", "test", {}) + + # Check that custom timeout was used + call_args = post_mock.call_args + timeout = call_args[1]["timeout"] + assert timeout is not None + assert timeout == custom_timeout + + +@pytest.mark.asyncio +async def test_proxy_parameter(): + """Test proxy parameter is passed correctly.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session, proxy="http://proxy:8080") + + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.method = "GET" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + get_mock = AsyncMock(return_value=mock_aiohttp_response) + mock_session.get = get_mock + transport.session = mock_session + + await transport.get("http://example.com") + + # Check proxy was passed + call_args = get_mock.call_args + assert call_args[1]["proxy"] == "http://proxy:8080" + + +@pytest.mark.asyncio +async def test_verify_ssl_false(): + """Test verify_ssl=False is stored correctly.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session, verify_ssl=False) + + # verify_ssl should be stored + assert transport.verify_ssl is False + + +@pytest.mark.asyncio +async def test_verify_ssl_true(): + """Test verify_ssl=True is stored correctly.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session, verify_ssl=True) + + # verify_ssl should be stored + assert transport.verify_ssl is True + + +@pytest.mark.asyncio +async def test_response_encoding(): + """Test response encoding is properly handled.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response with specific encoding + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {"Content-Type": "text/xml; charset=iso-8859-1"} + mock_aiohttp_response.charset = "iso-8859-1" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_aiohttp_response) + transport.session = mock_session + + result = await transport.get("http://example.com") + + # Check encoding was preserved + assert result.encoding == "iso-8859-1" + + +@pytest.mark.asyncio +async def test_cookies_in_httpx_response(): + """Test cookies are properly transferred to httpx response.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock cookies + mock_cookie = Mock() + mock_cookie.value = "abc123" + mock_cookie.get.side_effect = lambda k: {"domain": ".example.com", "path": "/"}.get( + k + ) + + mock_cookies = Mock() + mock_cookies.items.return_value = [("session", mock_cookie)] + + # Mock response with cookies + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = mock_cookies + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(return_value=mock_aiohttp_response) + transport.session = mock_session + + # Test httpx response (from post) + httpx_result = await transport.post("http://example.com", "test", {}) + assert "session" in httpx_result.cookies + + +@pytest.mark.asyncio +async def test_cookies_in_requests_response(): + """Test cookies are properly transferred to requests response.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock cookies using SimpleCookie format + from http.cookies import SimpleCookie + + mock_cookies = SimpleCookie() + mock_cookies["session"] = "abc123" + + # Mock response with cookies + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = mock_cookies + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_aiohttp_response) + transport.session = mock_session + + # Test requests response (from get) + requests_result = await transport.get("http://example.com") + assert "session" in requests_result.cookies + assert requests_result.cookies["session"] == "abc123" + + +@pytest.mark.asyncio +async def test_inherited_transport_attributes(): + """Test that Transport base class attributes are available.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Should have logger attribute from Transport + assert hasattr(transport, "logger") + + # Should have cache attribute (though we set it to None) + assert hasattr(transport, "cache") + assert transport.cache is None + + # Should have operation_timeout attribute from parent + assert hasattr(transport, "operation_timeout") + assert transport.operation_timeout is None + + +@pytest.mark.asyncio +async def test_session_reuse(): + """Test transport reuses provided session.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session.get = AsyncMock(return_value=mock_aiohttp_response) + + # Make multiple requests + result1 = await transport.get("http://example.com") + result2 = await transport.get("http://example.com") + + assert result1.content == b"test" + assert result2.content == b"test" + + # Session should be reused + assert mock_session.get.call_count == 2 + + +def test_sync_load_creates_new_loop(): + """Test load() creates new event loop when called from async context.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response + mock_response = Mock(spec=RequestsResponse) + mock_response.content = b"" + + # This should work even if there's already an event loop + with patch.object(transport, "get", new=AsyncMock(return_value=mock_response)): + with patch("asyncio.new_event_loop") as mock_new_loop: + mock_loop = Mock() + mock_loop.run_until_complete.return_value = mock_response + mock_new_loop.return_value = mock_loop + + result = transport.load("http://example.com/wsdl") + + # Should have created new loop + mock_new_loop.assert_called_once() + mock_loop.close.assert_called_once() + assert result == b"" + + +@pytest.mark.asyncio +async def test_content_type_header_default(): + """Test default Content-Type header is set for POST.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + mock_aiohttp_response = Mock(spec=aiohttp.ClientResponse) + mock_aiohttp_response.status = 200 + mock_aiohttp_response.headers = {} + mock_aiohttp_response.method = "POST" + mock_aiohttp_response.url = "http://example.com" + mock_aiohttp_response.charset = "utf-8" + mock_aiohttp_response.cookies = {} + mock_aiohttp_response.raise_for_status = Mock() + mock_aiohttp_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + post_mock = AsyncMock(return_value=mock_aiohttp_response) + mock_session.post = post_mock + transport.session = mock_session + + await transport.post("http://example.com", "test", {}) + + # Check Content-Type was set + call_args = post_mock.call_args + headers = call_args[1]["headers"] + assert headers["Content-Type"] == 'text/xml; charset="utf-8"' + + +@pytest.mark.asyncio +async def test_provided_session_not_closed(): + """Test that provided session is not closed by context manager.""" + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.close = AsyncMock() + + transport = AIOHTTPTransport(session=mock_session) + + async with transport: + assert transport.session == mock_session + + # Provided session should not be closed + mock_session.close.assert_not_called() + assert transport.session == mock_session + + +@pytest.mark.asyncio +async def test_cookie_conversion_httpx_basic(): + """Test basic cookie conversion from aiohttp to httpx response.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Create aiohttp cookies + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + cookies["session"] = "abc123" + cookies["session"]["domain"] = ".example.com" + cookies["session"]["path"] = "/api" + cookies["session"]["secure"] = True + cookies["session"]["httponly"] = True + cookies["session"]["max-age"] = "3600" + + cookies["user"] = "john_doe" + cookies["user"]["domain"] = "example.com" + cookies["user"]["path"] = "/" + + # Mock aiohttp response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.method = "POST" + mock_response.url = "http://example.com" + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.post("http://example.com", "test", {}) + + # Verify cookies in httpx response + assert "session" in result.cookies + assert result.cookies["session"] == "abc123" + assert "user" in result.cookies + assert result.cookies["user"] == "john_doe" + + +@pytest.mark.asyncio +async def test_cookie_conversion_requests_basic(): + """Test basic cookie conversion from aiohttp to requests response.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Create aiohttp cookies + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + cookies["token"] = "xyz789" + cookies["token"]["domain"] = ".api.example.com" + cookies["token"]["path"] = "/v1" + cookies["token"]["secure"] = True + + # Mock aiohttp response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"test") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.get("http://api.example.com/v1/data") + + # Verify cookies in requests response + assert "token" in result.cookies + assert result.cookies["token"] == "xyz789" + + +@pytest.mark.asyncio +async def test_cookie_attributes_httpx(): + """Test that cookie attributes are properly preserved in httpx response.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Create cookie with all attributes + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + cookies["auth"] = "secret123" + cookies["auth"]["domain"] = ".secure.com" + cookies["auth"]["path"] = "/admin" + cookies["auth"]["secure"] = True + cookies["auth"]["httponly"] = True + cookies["auth"]["samesite"] = "Strict" + cookies["auth"]["max-age"] = "7200" + + # Mock response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.method = "POST" + mock_response.url = "https://secure.com/admin" + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"secure") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.post("https://secure.com/admin", "login", {}) + + # Check cookie exists + assert "auth" in result.cookies + assert result.cookies["auth"] == "secret123" + + # Note: httpx.Cookies doesn't expose all attributes directly, + # but they should be preserved internally for cookie jar operations + + +@pytest.mark.asyncio +async def test_multiple_cookies(): + """Test handling multiple cookies.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Create multiple cookies + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + for i in range(5): + cookie_name = f"cookie{i}" + cookies[cookie_name] = f"value{i}" + cookies[cookie_name]["domain"] = ".example.com" + cookies[cookie_name]["path"] = f"/path{i}" + + # Mock response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.method = "GET" + mock_response.url = "http://example.com" + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"multi") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.get("http://example.com") + + # Verify all cookies + for i in range(5): + cookie_name = f"cookie{i}" + assert cookie_name in result.cookies + assert result.cookies[cookie_name] == f"value{i}" + + +@pytest.mark.asyncio +async def test_empty_cookies(): + """Test handling when no cookies are present.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Mock response without cookies + from http.cookies import SimpleCookie + + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.method = "GET" + mock_response.url = "http://example.com" + mock_response.charset = "utf-8" + mock_response.cookies = SimpleCookie() # Empty cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"nocookies") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.get("http://example.com") + + # Verify empty cookies + assert len(result.cookies) == 0 + + +@pytest.mark.asyncio +async def test_cookie_encoding(): + """Test cookies with special characters.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + # Create cookies with special chars + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + cookies["data"] = "hello%20world%21" # URL encoded + cookies["unicode"] = "café" + + # Mock response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"encoded") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.get = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Make request + result = await transport.get("http://example.com") + + # Verify encoded cookies + assert "data" in result.cookies + assert result.cookies["data"] == "hello%20world%21" + assert "unicode" in result.cookies + assert result.cookies["unicode"] == "café" + + +@pytest.mark.asyncio +async def test_cookie_jar_type(): + """Test that cookies are stored in appropriate jar types.""" + mock_session = create_mock_session() + transport = AIOHTTPTransport(session=mock_session) + + from http.cookies import SimpleCookie + + cookies = SimpleCookie() + cookies["test"] = "value" + + # Mock response + mock_response = Mock(spec=aiohttp.ClientResponse) + mock_response.status = 200 + mock_response.headers = {} + mock_response.method = "POST" + mock_response.url = "http://example.com" + mock_response.charset = "utf-8" + mock_response.cookies = cookies + mock_response.raise_for_status = Mock() + mock_response.read = AsyncMock(return_value=b"jar") + + mock_session = Mock(spec=aiohttp.ClientSession) + mock_session.post = AsyncMock(return_value=mock_response) + transport.session = mock_session + + # Test httpx response + httpx_result = await transport.post("http://example.com", "test", {}) + assert isinstance(httpx_result.cookies, httpx.Cookies) + + # Test requests response + mock_session.get = AsyncMock(return_value=mock_response) + requests_result = await transport.get("http://example.com") + # Verify cookies are accessible in requests response + assert hasattr(requests_result.cookies, "__getitem__") + assert "test" in requests_result.cookies