Skip to content

Commit e133748

Browse files
committed
Migrate to using aiohttp
1 parent 8dd6ba0 commit e133748

File tree

2 files changed

+101
-87
lines changed

2 files changed

+101
-87
lines changed

onvif/client.py

Lines changed: 94 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66
import datetime as dt
77
import logging
88
import os.path
9-
from typing import Any
109
from collections.abc import Callable
11-
import httpx
12-
from httpx import AsyncClient, BasicAuth, DigestAuth
10+
from typing import Any
11+
12+
import zeep.helpers
1313
from zeep.cache import SqliteCache
1414
from zeep.client import AsyncClient as BaseZeepAsyncClient
15-
import zeep.helpers
1615
from zeep.proxy import AsyncServiceProxy
17-
from zeep.transports import AsyncTransport
1816
from zeep.wsdl import Document
1917
from zeep.wsse.username import UsernameToken
2018

19+
import aiohttp
20+
import httpx
21+
from aiohttp import BasicAuth, ClientSession, DigestAuthMiddleware, TCPConnector
2122
from onvif.definition import SERVICES
2223
from onvif.exceptions import ONVIFAuthError, ONVIFError, ONVIFTimeoutError
24+
from requests import Response
2325

2426
from .const import KEEPALIVE_EXPIRY
2527
from .managers import NotificationManager, PullPointManager
@@ -29,13 +31,14 @@
2931
from .util import (
3032
create_no_verify_ssl_context,
3133
normalize_url,
34+
obscure_user_pass_url,
3235
path_isfile,
33-
utcnow,
3436
strip_user_pass_url,
35-
obscure_user_pass_url,
37+
utcnow,
3638
)
37-
from .wrappers import retry_connection_error # noqa: F401
39+
from .wrappers import retry_connection_error
3840
from .wsa import WsAddressingIfMissingPlugin
41+
from .zeep_aiohttp import AIOHTTPTransport
3942

4043
logger = logging.getLogger("onvif")
4144
logging.basicConfig(level=logging.INFO)
@@ -48,7 +51,7 @@
4851
_CONNECT_TIMEOUT = 30
4952
_READ_TIMEOUT = 90
5053
_WRITE_TIMEOUT = 90
51-
_HTTPX_LIMITS = httpx.Limits(keepalive_expiry=KEEPALIVE_EXPIRY)
54+
_KEEPALIVE_TIMEOUT = aiohttp.ClientTimeout(sock_keepalive=KEEPALIVE_EXPIRY)
5255
_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context()
5356

5457

@@ -59,7 +62,7 @@ def wrapped(*args, **kwargs):
5962
try:
6063
return func(*args, **kwargs)
6164
except Exception as err:
62-
raise ONVIFError(err)
65+
raise ONVIFError(err) from err
6366

6467
return wrapped
6568

@@ -80,7 +83,7 @@ def __init__(self, user, passw, dt_diff=None, **kwargs):
8083
def apply(self, envelope, headers):
8184
old_created = self.created
8285
if self.created is None:
83-
self.created = dt.datetime.now(tz=dt.timezone.utc).replace(tzinfo=None)
86+
self.created = dt.datetime.now(tz=dt.UTC).replace(tzinfo=None)
8487
if self.dt_diff is not None:
8588
self.created += self.dt_diff
8689
result = super().apply(envelope, headers)
@@ -102,20 +105,28 @@ def original_load(self, *args: Any, **kwargs: Any) -> None:
102105
return original_load(self, *args, **kwargs)
103106

104107

105-
class AsyncTransportProtocolErrorHandler(AsyncTransport):
106-
"""Retry on remote protocol error.
108+
class AsyncTransportProtocolErrorHandler(AIOHTTPTransport):
109+
"""
110+
Retry on remote protocol error.
107111
108112
http://datatracker.ietf.org/doc/html/rfc2616#section-8.1.4 allows the server
109113
# to close the connection at any time, we treat this as a normal and try again
110114
# once since
111115
"""
112116

113-
@retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError)
114-
async def post(self, address, message, headers):
117+
@retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError)
118+
async def post(
119+
self, address: str, message: str, headers: dict[str, str]
120+
) -> httpx.Response:
115121
return await super().post(address, message, headers)
116122

117-
@retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError)
118-
async def get(self, address, params, headers):
123+
@retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError)
124+
async def get(
125+
self,
126+
address: str,
127+
params: dict[str, Any] | None = None,
128+
headers: dict[str, str] | None = None,
129+
) -> Response:
119130
return await super().get(address, params, headers)
120131

121132

@@ -162,17 +173,18 @@ def __init__(self, *args, **kwargs):
162173
self.set_ns_prefix("wsa", "http://www.w3.org/2005/08/addressing")
163174

164175
def create_service(self, binding_name, address):
165-
"""Create a new ServiceProxy for the given binding name and address.
176+
"""
177+
Create a new ServiceProxy for the given binding name and address.
166178
:param binding_name: The QName of the binding
167179
:param address: The address of the endpoint
168180
"""
169181
try:
170182
binding = self.wsdl.bindings[binding_name]
171183
except KeyError:
172184
raise ValueError(
173-
"No binding found with the given QName. Available bindings "
174-
"are: %s" % (", ".join(self.wsdl.bindings.keys()))
175-
)
185+
f"No binding found with the given QName. Available bindings "
186+
f"are: {', '.join(self.wsdl.bindings.keys())}"
187+
) from None
176188
return AsyncServiceProxy(self, binding, address=address)
177189

178190

@@ -223,7 +235,7 @@ def __init__(
223235
write_timeout: int | None = None,
224236
) -> None:
225237
if not path_isfile(url):
226-
raise ONVIFError("%s doesn`t exist!" % url)
238+
raise ONVIFError(f"{url} doesn`t exist!")
227239

228240
self.url = url
229241
self.xaddr = xaddr
@@ -236,28 +248,28 @@ def __init__(
236248
self.dt_diff = dt_diff
237249
self.binding_name = binding_name
238250
# Create soap client
239-
timeouts = httpx.Timeout(
240-
_DEFAULT_TIMEOUT,
241-
connect=_CONNECT_TIMEOUT,
242-
read=read_timeout or _READ_TIMEOUT,
243-
write=write_timeout or _WRITE_TIMEOUT,
251+
timeout_seconds = _DEFAULT_TIMEOUT
252+
operation_timeout = read_timeout or _READ_TIMEOUT
253+
connector = TCPConnector(
254+
ssl=_NO_VERIFY_SSL_CONTEXT,
255+
keepalive_timeout=KEEPALIVE_EXPIRY,
244256
)
245-
client = AsyncClient(
246-
verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS
257+
session = ClientSession(
258+
connector=connector,
259+
timeout=aiohttp.ClientTimeout(
260+
total=timeout_seconds,
261+
connect=_CONNECT_TIMEOUT,
262+
sock_read=operation_timeout,
263+
),
247264
)
248-
# The wsdl client should never actually be used, but it is required
249-
# to avoid creating another ssl context since the underlying code
250-
# will try to create a new one if it doesn't exist.
251-
wsdl_client = httpx.Client(
252-
verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS
253-
)
254-
self.transport = (
255-
AsyncTransportProtocolErrorHandler(client=client, wsdl_client=wsdl_client)
256-
if no_cache
257-
else AsyncTransportProtocolErrorHandler(
258-
client=client, wsdl_client=wsdl_client, cache=SqliteCache()
259-
)
265+
self.transport = AsyncTransportProtocolErrorHandler(
266+
session=session,
267+
timeout=timeout_seconds,
268+
operation_timeout=operation_timeout,
269+
verify_ssl=False,
260270
)
271+
if not no_cache:
272+
self.transport.cache = SqliteCache()
261273
self.document: Document | None = None
262274
self.zeep_client_authless: ZeepAsyncClient | None = None
263275
self.ws_client_authless: AsyncServiceProxy | None = None
@@ -399,7 +411,8 @@ def __init__(
399411
self.to_dict = ONVIFService.to_dict
400412

401413
self._snapshot_uris = {}
402-
self._snapshot_client = AsyncClient(verify=_NO_VERIFY_SSL_CONTEXT)
414+
self._snapshot_connector = TCPConnector(ssl=_NO_VERIFY_SSL_CONTEXT)
415+
self._snapshot_client = ClientSession(connector=self._snapshot_connector)
403416

404417
async def get_capabilities(self) -> dict[str, Any]:
405418
"""Get device capabilities."""
@@ -531,7 +544,7 @@ async def create_notification_manager(
531544

532545
async def close(self) -> None:
533546
"""Close all transports."""
534-
await self._snapshot_client.aclose()
547+
await self._snapshot_client.close()
535548
for service in self.services.values():
536549
await service.close()
537550

@@ -572,42 +585,54 @@ async def get_snapshot(
572585
if uri is None:
573586
return None
574587

588+
# Create a new session with appropriate auth
589+
connector = TCPConnector(ssl=_NO_VERIFY_SSL_CONTEXT)
590+
middlewares = []
575591
auth = None
592+
576593
if self.user and self.passwd:
577594
if basic_auth:
578595
auth = BasicAuth(self.user, self.passwd)
579596
else:
580-
auth = DigestAuth(self.user, self.passwd)
581-
582-
response = await self._try_snapshot_uri(uri, auth)
583-
584-
# If the request fails with a 401, make sure to strip any
585-
# sample user/pass from the URL and try again
586-
if (
587-
response.status_code == 401
588-
and (stripped_uri := strip_user_pass_url(uri))
589-
and stripped_uri != uri
590-
):
591-
response = await self._try_snapshot_uri(stripped_uri, auth)
597+
# Use DigestAuthMiddleware for digest auth
598+
middlewares.append(DigestAuthMiddleware(self.user, self.passwd))
599+
600+
async with ClientSession(
601+
connector=connector, auth=auth, middlewares=middlewares
602+
) as session:
603+
response = await self._try_snapshot_uri_with_session(session, uri)
604+
content = await response.read()
605+
606+
# If the request fails with a 401, make sure to strip any
607+
# sample user/pass from the URL and try again
608+
if (
609+
response.status == 401
610+
and (stripped_uri := strip_user_pass_url(uri))
611+
and stripped_uri != uri
612+
):
613+
response = await self._try_snapshot_uri_with_session(
614+
session, stripped_uri
615+
)
616+
content = await response.read()
592617

593-
if response.status_code == 401:
594-
raise ONVIFAuthError(f"Failed to authenticate to {uri}")
618+
if response.status == 401:
619+
raise ONVIFAuthError(f"Failed to authenticate to {uri}")
595620

596-
if response.status_code < 300:
597-
return response.content
621+
if response.status < 300:
622+
return content
598623

599-
return None
624+
return None
600625

601-
async def _try_snapshot_uri(
602-
self, uri: str, auth: BasicAuth | DigestAuth | None
603-
) -> httpx.Response:
626+
async def _try_snapshot_uri_with_session(
627+
self, session: ClientSession, uri: str
628+
) -> aiohttp.ClientResponse:
604629
try:
605-
return await self._snapshot_client.get(uri, auth=auth)
606-
except httpx.TimeoutException as error:
630+
return await session.get(uri)
631+
except TimeoutError as error:
607632
raise ONVIFTimeoutError(
608633
f"Timed out fetching {obscure_user_pass_url(uri)}: {error}"
609634
) from error
610-
except httpx.RequestError as error:
635+
except aiohttp.ClientError as error:
611636
raise ONVIFError(
612637
f"Error fetching {obscure_user_pass_url(uri)}: {error}"
613638
) from error
@@ -618,7 +643,7 @@ def get_definition(
618643
"""Returns xaddr and wsdl of specified service"""
619644
# Check if the service is supported
620645
if name not in SERVICES:
621-
raise ONVIFError("Unknown service %s" % name)
646+
raise ONVIFError(f"Unknown service {name}")
622647
wsdl_file = SERVICES[name]["wsdl"]
623648
namespace = SERVICES[name]["ns"]
624649

@@ -629,14 +654,14 @@ def get_definition(
629654

630655
wsdlpath = os.path.join(self.wsdl_dir, wsdl_file)
631656
if not path_isfile(wsdlpath):
632-
raise ONVIFError("No such file: %s" % wsdlpath)
657+
raise ONVIFError(f"No such file: {wsdlpath}")
633658

634659
# XAddr for devicemgmt is fixed:
635660
if name == "devicemgmt":
636661
xaddr = "{}:{}/onvif/device_service".format(
637662
self.host
638663
if (self.host.startswith("http://") or self.host.startswith("https://"))
639-
else "http://%s" % self.host,
664+
else f"http://{self.host}",
640665
self.port,
641666
)
642667
return xaddr, wsdlpath, binding_name

onvif/zeep_aiohttp.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,13 @@ def _aiohttp_to_httpx_response(
9898
# Store cookies if any
9999
if aiohttp_response.cookies:
100100
for cookie in aiohttp_response.cookies.values():
101-
# Extract all cookie attributes
102-
cookie_attrs = {}
103-
if cookie.get("domain"):
104-
cookie_attrs["domain"] = cookie.get("domain")
105-
if cookie.get("path"):
106-
cookie_attrs["path"] = cookie.get("path")
107-
if cookie.get("secure"):
108-
cookie_attrs["secure"] = True
109-
if cookie.get("httponly"):
110-
cookie_attrs["httpOnly"] = True
111-
if cookie.get("max-age"):
112-
cookie_attrs["max_age"] = int(cookie.get("max-age"))
113-
if cookie.get("expires"):
114-
cookie_attrs["expires"] = cookie.get("expires")
115-
if cookie.get("samesite"):
116-
cookie_attrs["samesite"] = cookie.get("samesite")
117-
118-
httpx_response.cookies.set(cookie.key, cookie.value, **cookie_attrs)
101+
# httpx.Cookies.set only accepts name, value, domain, and path
102+
httpx_response.cookies.set(
103+
cookie.key,
104+
cookie.value,
105+
domain=cookie.get("domain", ""),
106+
path=cookie.get("path", "/"),
107+
)
119108

120109
return httpx_response
121110

0 commit comments

Comments
 (0)