Skip to content

Commit 4dc87bd

Browse files
authored
Migrate to using aiohttp (#115)
1 parent 4340e61 commit 4dc87bd

12 files changed

+1721
-86
lines changed

onvif/client.py

Lines changed: 84 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66
import datetime as dt
77
import logging
88
import os.path
9-
from typing import Any
109
from collections.abc import Callable
11-
import httpx
12-
from httpx import AsyncClient, BasicAuth, DigestAuth
10+
from typing import Any
11+
12+
import zeep.helpers
1313
from zeep.cache import SqliteCache
1414
from zeep.client import AsyncClient as BaseZeepAsyncClient
15-
import zeep.helpers
1615
from zeep.proxy import AsyncServiceProxy
17-
from zeep.transports import AsyncTransport
1816
from zeep.wsdl import Document
1917
from zeep.wsse.username import UsernameToken
2018

19+
import aiohttp
20+
import httpx
21+
from aiohttp import BasicAuth, ClientSession, DigestAuthMiddleware, TCPConnector
2122
from onvif.definition import SERVICES
2223
from onvif.exceptions import ONVIFAuthError, ONVIFError, ONVIFTimeoutError
24+
from requests import Response
2325

2426
from .const import KEEPALIVE_EXPIRY
2527
from .managers import NotificationManager, PullPointManager
@@ -29,13 +31,14 @@
2931
from .util import (
3032
create_no_verify_ssl_context,
3133
normalize_url,
34+
obscure_user_pass_url,
3235
path_isfile,
33-
utcnow,
3436
strip_user_pass_url,
35-
obscure_user_pass_url,
37+
utcnow,
3638
)
37-
from .wrappers import retry_connection_error # noqa: F401
39+
from .wrappers import retry_connection_error
3840
from .wsa import WsAddressingIfMissingPlugin
41+
from .zeep_aiohttp import AIOHTTPTransport
3942

4043
logger = logging.getLogger("onvif")
4144
logging.basicConfig(level=logging.INFO)
@@ -48,7 +51,7 @@
4851
_CONNECT_TIMEOUT = 30
4952
_READ_TIMEOUT = 90
5053
_WRITE_TIMEOUT = 90
51-
_HTTPX_LIMITS = httpx.Limits(keepalive_expiry=KEEPALIVE_EXPIRY)
54+
# Keepalive is set on the connector, not in ClientTimeout
5255
_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context()
5356

5457

@@ -59,7 +62,7 @@ def wrapped(*args, **kwargs):
5962
try:
6063
return func(*args, **kwargs)
6164
except Exception as err:
62-
raise ONVIFError(err)
65+
raise ONVIFError(err) from err
6366

6467
return wrapped
6568

@@ -102,20 +105,28 @@ def original_load(self, *args: Any, **kwargs: Any) -> None:
102105
return original_load(self, *args, **kwargs)
103106

104107

105-
class AsyncTransportProtocolErrorHandler(AsyncTransport):
106-
"""Retry on remote protocol error.
108+
class AsyncTransportProtocolErrorHandler(AIOHTTPTransport):
109+
"""
110+
Retry on remote protocol error.
107111
108112
http://datatracker.ietf.org/doc/html/rfc2616#section-8.1.4 allows the server
109113
# to close the connection at any time, we treat this as a normal and try again
110114
# once since
111115
"""
112116

113-
@retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError)
114-
async def post(self, address, message, headers):
117+
@retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError)
118+
async def post(
119+
self, address: str, message: str, headers: dict[str, str]
120+
) -> httpx.Response:
115121
return await super().post(address, message, headers)
116122

117-
@retry_connection_error(attempts=2, exception=httpx.RemoteProtocolError)
118-
async def get(self, address, params, headers):
123+
@retry_connection_error(attempts=2, exception=aiohttp.ServerDisconnectedError)
124+
async def get(
125+
self,
126+
address: str,
127+
params: dict[str, Any] | None = None,
128+
headers: dict[str, str] | None = None,
129+
) -> Response:
119130
return await super().get(address, params, headers)
120131

121132

@@ -162,17 +173,18 @@ def __init__(self, *args, **kwargs):
162173
self.set_ns_prefix("wsa", "http://www.w3.org/2005/08/addressing")
163174

164175
def create_service(self, binding_name, address):
165-
"""Create a new ServiceProxy for the given binding name and address.
176+
"""
177+
Create a new ServiceProxy for the given binding name and address.
166178
:param binding_name: The QName of the binding
167179
:param address: The address of the endpoint
168180
"""
169181
try:
170182
binding = self.wsdl.bindings[binding_name]
171183
except KeyError:
172184
raise ValueError(
173-
"No binding found with the given QName. Available bindings "
174-
"are: %s" % (", ".join(self.wsdl.bindings.keys()))
175-
)
185+
f"No binding found with the given QName. Available bindings "
186+
f"are: {', '.join(self.wsdl.bindings.keys())}"
187+
) from None
176188
return AsyncServiceProxy(self, binding, address=address)
177189

178190

@@ -223,7 +235,7 @@ def __init__(
223235
write_timeout: int | None = None,
224236
) -> None:
225237
if not path_isfile(url):
226-
raise ONVIFError("%s doesn`t exist!" % url)
238+
raise ONVIFError(f"{url} doesn`t exist!")
227239

228240
self.url = url
229241
self.xaddr = xaddr
@@ -236,26 +248,28 @@ def __init__(
236248
self.dt_diff = dt_diff
237249
self.binding_name = binding_name
238250
# Create soap client
239-
timeouts = httpx.Timeout(
240-
_DEFAULT_TIMEOUT,
241-
connect=_CONNECT_TIMEOUT,
242-
read=read_timeout or _READ_TIMEOUT,
243-
write=write_timeout or _WRITE_TIMEOUT,
244-
)
245-
client = AsyncClient(
246-
verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS
251+
connector = TCPConnector(
252+
ssl=_NO_VERIFY_SSL_CONTEXT,
253+
keepalive_timeout=KEEPALIVE_EXPIRY,
247254
)
248-
# The wsdl client should never actually be used, but it is required
249-
# to avoid creating another ssl context since the underlying code
250-
# will try to create a new one if it doesn't exist.
251-
wsdl_client = httpx.Client(
252-
verify=_NO_VERIFY_SSL_CONTEXT, timeout=timeouts, limits=_HTTPX_LIMITS
255+
session = ClientSession(
256+
connector=connector,
257+
timeout=aiohttp.ClientTimeout(
258+
total=_DEFAULT_TIMEOUT,
259+
connect=_CONNECT_TIMEOUT,
260+
sock_read=read_timeout or _READ_TIMEOUT,
261+
),
253262
)
254263
self.transport = (
255-
AsyncTransportProtocolErrorHandler(client=client, wsdl_client=wsdl_client)
264+
AsyncTransportProtocolErrorHandler(
265+
session=session,
266+
verify_ssl=False,
267+
)
256268
if no_cache
257-
else AsyncTransportProtocolErrorHandler(
258-
client=client, wsdl_client=wsdl_client, cache=SqliteCache()
269+
else AIOHTTPTransport(
270+
session=session,
271+
verify_ssl=False,
272+
cache=SqliteCache(),
259273
)
260274
)
261275
self.document: Document | None = None
@@ -399,7 +413,8 @@ def __init__(
399413
self.to_dict = ONVIFService.to_dict
400414

401415
self._snapshot_uris = {}
402-
self._snapshot_client = AsyncClient(verify=_NO_VERIFY_SSL_CONTEXT)
416+
self._snapshot_connector = TCPConnector(ssl=_NO_VERIFY_SSL_CONTEXT)
417+
self._snapshot_client = ClientSession(connector=self._snapshot_connector)
403418

404419
async def get_capabilities(self) -> dict[str, Any]:
405420
"""Get device capabilities."""
@@ -531,7 +546,8 @@ async def create_notification_manager(
531546

532547
async def close(self) -> None:
533548
"""Close all transports."""
534-
await self._snapshot_client.aclose()
549+
await self._snapshot_client.close()
550+
await self._snapshot_connector.close()
535551
for service in self.services.values():
536552
await service.close()
537553

@@ -572,42 +588,53 @@ async def get_snapshot(
572588
if uri is None:
573589
return None
574590

575-
auth = None
591+
auth: BasicAuth | None = None
592+
middlewares: tuple[DigestAuthMiddleware, ...] | None = None
593+
576594
if self.user and self.passwd:
577595
if basic_auth:
578596
auth = BasicAuth(self.user, self.passwd)
579597
else:
580-
auth = DigestAuth(self.user, self.passwd)
598+
# Use DigestAuthMiddleware for digest auth
599+
middlewares = (DigestAuthMiddleware(self.user, self.passwd),)
581600

582-
response = await self._try_snapshot_uri(uri, auth)
601+
response = await self._try_snapshot_uri(uri, auth=auth, middlewares=middlewares)
602+
content = await response.read()
583603

584-
# If the request fails with a 401, make sure to strip any
585-
# sample user/pass from the URL and try again
604+
# If the request fails with a 401, strip user/pass from URL and retry
586605
if (
587-
response.status_code == 401
606+
response.status == 401
588607
and (stripped_uri := strip_user_pass_url(uri))
589608
and stripped_uri != uri
590609
):
591-
response = await self._try_snapshot_uri(stripped_uri, auth)
610+
response = await self._try_snapshot_uri(
611+
stripped_uri, auth=auth, middlewares=middlewares
612+
)
613+
content = await response.read()
592614

593-
if response.status_code == 401:
615+
if response.status == 401:
594616
raise ONVIFAuthError(f"Failed to authenticate to {uri}")
595617

596-
if response.status_code < 300:
597-
return response.content
618+
if response.status < 300:
619+
return content
598620

599621
return None
600622

601623
async def _try_snapshot_uri(
602-
self, uri: str, auth: BasicAuth | DigestAuth | None
603-
) -> httpx.Response:
624+
self,
625+
uri: str,
626+
auth: BasicAuth | None = None,
627+
middlewares: tuple[DigestAuthMiddleware, ...] | None = None,
628+
) -> aiohttp.ClientResponse:
604629
try:
605-
return await self._snapshot_client.get(uri, auth=auth)
606-
except httpx.TimeoutException as error:
630+
return await self._snapshot_client.get(
631+
uri, auth=auth, middlewares=middlewares
632+
)
633+
except TimeoutError as error:
607634
raise ONVIFTimeoutError(
608635
f"Timed out fetching {obscure_user_pass_url(uri)}: {error}"
609636
) from error
610-
except httpx.RequestError as error:
637+
except aiohttp.ClientError as error:
611638
raise ONVIFError(
612639
f"Error fetching {obscure_user_pass_url(uri)}: {error}"
613640
) from error
@@ -618,7 +645,7 @@ def get_definition(
618645
"""Returns xaddr and wsdl of specified service"""
619646
# Check if the service is supported
620647
if name not in SERVICES:
621-
raise ONVIFError("Unknown service %s" % name)
648+
raise ONVIFError(f"Unknown service {name}")
622649
wsdl_file = SERVICES[name]["wsdl"]
623650
namespace = SERVICES[name]["ns"]
624651

@@ -629,14 +656,14 @@ def get_definition(
629656

630657
wsdlpath = os.path.join(self.wsdl_dir, wsdl_file)
631658
if not path_isfile(wsdlpath):
632-
raise ONVIFError("No such file: %s" % wsdlpath)
659+
raise ONVIFError(f"No such file: {wsdlpath}")
633660

634661
# XAddr for devicemgmt is fixed:
635662
if name == "devicemgmt":
636663
xaddr = "{}:{}/onvif/device_service".format(
637664
self.host
638665
if (self.host.startswith("http://") or self.host.startswith("https://"))
639-
else "http://%s" % self.host,
666+
else f"http://{self.host}",
640667
self.port,
641668
)
642669
return xaddr, wsdlpath, binding_name

onvif/managers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22

33
from __future__ import annotations
44

5-
from abc import abstractmethod
65
import asyncio
76
import datetime as dt
87
import logging
9-
from typing import TYPE_CHECKING, Any
8+
from abc import abstractmethod
109
from collections.abc import Callable
10+
from typing import TYPE_CHECKING, Any
1111

12-
import httpx
13-
from httpx import TransportError
1412
from zeep.exceptions import Fault, XMLParseError, XMLSyntaxError
1513
from zeep.loader import parse_xml
1614
from zeep.wsdl.bindings.soap import SoapOperation
1715

16+
import aiohttp
1817
from onvif.exceptions import ONVIFError
1918

2019
from .settings import DEFAULT_SETTINGS
@@ -27,8 +26,8 @@
2726

2827
_RENEWAL_PERCENTAGE = 0.8
2928

30-
SUBSCRIPTION_ERRORS = (Fault, asyncio.TimeoutError, TransportError)
31-
RENEW_ERRORS = (ONVIFError, httpx.RequestError, XMLParseError, *SUBSCRIPTION_ERRORS)
29+
SUBSCRIPTION_ERRORS = (Fault, asyncio.TimeoutError, aiohttp.ClientError)
30+
RENEW_ERRORS = (ONVIFError, aiohttp.ClientError, XMLParseError, *SUBSCRIPTION_ERRORS)
3231
SUBSCRIPTION_RESTART_INTERVAL_ON_ERROR = dt.timedelta(seconds=40)
3332

3433
# If the camera returns a subscription with a termination time that is less than
@@ -87,7 +86,8 @@ async def stop(self) -> None:
8786
await self._subscription.Unsubscribe()
8887

8988
async def shutdown(self) -> None:
90-
"""Shutdown the manager.
89+
"""
90+
Shutdown the manager.
9191
9292
This method is irreversible.
9393
"""
@@ -105,7 +105,7 @@ async def set_synchronization_point(self) -> float:
105105
"""Set the synchronization point."""
106106
try:
107107
await self._service.SetSynchronizationPoint()
108-
except (Fault, asyncio.TimeoutError, TransportError, TypeError):
108+
except (TimeoutError, Fault, aiohttp.ClientError, TypeError):
109109
logger.debug("%s: SetSynchronizationPoint failed", self._service.url)
110110

111111
def _cancel_renewals(self) -> None:
@@ -214,7 +214,8 @@ def __init__(
214214
super().__init__(device, interval, subscription_lost_callback)
215215

216216
async def _start(self) -> float:
217-
"""Start the notification processor.
217+
"""
218+
Start the notification processor.
218219
219220
Returns the next renewal call at time.
220221
"""
@@ -290,7 +291,8 @@ class PullPointManager(BaseManager):
290291
"""Manager for PullPoint."""
291292

292293
async def _start(self) -> float:
293-
"""Start the PullPoint manager.
294+
"""
295+
Start the PullPoint manager.
294296
295297
Returns the next renewal call at time.
296298
"""

0 commit comments

Comments
 (0)