6
6
import datetime as dt
7
7
import logging
8
8
import os .path
9
- from typing import Any
10
9
from collections .abc import Callable
11
- import httpx
12
- from httpx import AsyncClient , BasicAuth , DigestAuth
10
+ from typing import Any
11
+
12
+ import zeep .helpers
13
13
from zeep .cache import SqliteCache
14
14
from zeep .client import AsyncClient as BaseZeepAsyncClient
15
- import zeep .helpers
16
15
from zeep .proxy import AsyncServiceProxy
17
- from zeep .transports import AsyncTransport
18
16
from zeep .wsdl import Document
19
17
from zeep .wsse .username import UsernameToken
20
18
19
+ import aiohttp
20
+ import httpx
21
+ from aiohttp import BasicAuth , ClientSession , DigestAuthMiddleware , TCPConnector
21
22
from onvif .definition import SERVICES
22
23
from onvif .exceptions import ONVIFAuthError , ONVIFError , ONVIFTimeoutError
24
+ from requests import Response
23
25
24
26
from .const import KEEPALIVE_EXPIRY
25
27
from .managers import NotificationManager , PullPointManager
29
31
from .util import (
30
32
create_no_verify_ssl_context ,
31
33
normalize_url ,
34
+ obscure_user_pass_url ,
32
35
path_isfile ,
33
- utcnow ,
34
36
strip_user_pass_url ,
35
- obscure_user_pass_url ,
37
+ utcnow ,
36
38
)
37
- from .wrappers import retry_connection_error # noqa: F401
39
+ from .wrappers import retry_connection_error
38
40
from .wsa import WsAddressingIfMissingPlugin
41
+ from .zeep_aiohttp import AIOHTTPTransport
39
42
40
43
logger = logging .getLogger ("onvif" )
41
44
logging .basicConfig (level = logging .INFO )
48
51
_CONNECT_TIMEOUT = 30
49
52
_READ_TIMEOUT = 90
50
53
_WRITE_TIMEOUT = 90
51
- _HTTPX_LIMITS = httpx . Limits ( keepalive_expiry = KEEPALIVE_EXPIRY )
54
+ # Keepalive is set on the connector, not in ClientTimeout
52
55
_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context ()
53
56
54
57
@@ -59,7 +62,7 @@ def wrapped(*args, **kwargs):
59
62
try :
60
63
return func (* args , ** kwargs )
61
64
except Exception as err :
62
- raise ONVIFError (err )
65
+ raise ONVIFError (err ) from err
63
66
64
67
return wrapped
65
68
@@ -102,20 +105,28 @@ def original_load(self, *args: Any, **kwargs: Any) -> None:
102
105
return original_load (self , * args , ** kwargs )
103
106
104
107
105
- class AsyncTransportProtocolErrorHandler (AsyncTransport ):
106
- """Retry on remote protocol error.
108
+ class AsyncTransportProtocolErrorHandler (AIOHTTPTransport ):
109
+ """
110
+ Retry on remote protocol error.
107
111
108
112
http://datatracker.ietf.org/doc/html/rfc2616#section-8.1.4 allows the server
109
113
# to close the connection at any time, we treat this as a normal and try again
110
114
# once since
111
115
"""
112
116
113
- @retry_connection_error (attempts = 2 , exception = httpx .RemoteProtocolError )
114
- async def post (self , address , message , headers ):
117
+ @retry_connection_error (attempts = 2 , exception = aiohttp .ServerDisconnectedError )
118
+ async def post (
119
+ self , address : str , message : str , headers : dict [str , str ]
120
+ ) -> httpx .Response :
115
121
return await super ().post (address , message , headers )
116
122
117
- @retry_connection_error (attempts = 2 , exception = httpx .RemoteProtocolError )
118
- async def get (self , address , params , headers ):
123
+ @retry_connection_error (attempts = 2 , exception = aiohttp .ServerDisconnectedError )
124
+ async def get (
125
+ self ,
126
+ address : str ,
127
+ params : dict [str , Any ] | None = None ,
128
+ headers : dict [str , str ] | None = None ,
129
+ ) -> Response :
119
130
return await super ().get (address , params , headers )
120
131
121
132
@@ -162,17 +173,18 @@ def __init__(self, *args, **kwargs):
162
173
self .set_ns_prefix ("wsa" , "http://www.w3.org/2005/08/addressing" )
163
174
164
175
def create_service (self , binding_name , address ):
165
- """Create a new ServiceProxy for the given binding name and address.
176
+ """
177
+ Create a new ServiceProxy for the given binding name and address.
166
178
:param binding_name: The QName of the binding
167
179
:param address: The address of the endpoint
168
180
"""
169
181
try :
170
182
binding = self .wsdl .bindings [binding_name ]
171
183
except KeyError :
172
184
raise ValueError (
173
- "No binding found with the given QName. Available bindings "
174
- "are: %s" % ( ", " .join (self .wsdl .bindings .keys ()))
175
- )
185
+ f "No binding found with the given QName. Available bindings "
186
+ f "are: { ', ' .join (self .wsdl .bindings .keys ())} "
187
+ ) from None
176
188
return AsyncServiceProxy (self , binding , address = address )
177
189
178
190
@@ -223,7 +235,7 @@ def __init__(
223
235
write_timeout : int | None = None ,
224
236
) -> None :
225
237
if not path_isfile (url ):
226
- raise ONVIFError ("%s doesn`t exist!" % url )
238
+ raise ONVIFError (f" { url } doesn`t exist!" )
227
239
228
240
self .url = url
229
241
self .xaddr = xaddr
@@ -236,26 +248,28 @@ def __init__(
236
248
self .dt_diff = dt_diff
237
249
self .binding_name = binding_name
238
250
# Create soap client
239
- timeouts = httpx .Timeout (
240
- _DEFAULT_TIMEOUT ,
241
- connect = _CONNECT_TIMEOUT ,
242
- read = read_timeout or _READ_TIMEOUT ,
243
- write = write_timeout or _WRITE_TIMEOUT ,
244
- )
245
- client = AsyncClient (
246
- verify = _NO_VERIFY_SSL_CONTEXT , timeout = timeouts , limits = _HTTPX_LIMITS
251
+ connector = TCPConnector (
252
+ ssl = _NO_VERIFY_SSL_CONTEXT ,
253
+ keepalive_timeout = KEEPALIVE_EXPIRY ,
247
254
)
248
- # The wsdl client should never actually be used, but it is required
249
- # to avoid creating another ssl context since the underlying code
250
- # will try to create a new one if it doesn't exist.
251
- wsdl_client = httpx .Client (
252
- verify = _NO_VERIFY_SSL_CONTEXT , timeout = timeouts , limits = _HTTPX_LIMITS
255
+ session = ClientSession (
256
+ connector = connector ,
257
+ timeout = aiohttp .ClientTimeout (
258
+ total = _DEFAULT_TIMEOUT ,
259
+ connect = _CONNECT_TIMEOUT ,
260
+ sock_read = read_timeout or _READ_TIMEOUT ,
261
+ ),
253
262
)
254
263
self .transport = (
255
- AsyncTransportProtocolErrorHandler (client = client , wsdl_client = wsdl_client )
264
+ AsyncTransportProtocolErrorHandler (
265
+ session = session ,
266
+ verify_ssl = False ,
267
+ )
256
268
if no_cache
257
- else AsyncTransportProtocolErrorHandler (
258
- client = client , wsdl_client = wsdl_client , cache = SqliteCache ()
269
+ else AIOHTTPTransport (
270
+ session = session ,
271
+ verify_ssl = False ,
272
+ cache = SqliteCache (),
259
273
)
260
274
)
261
275
self .document : Document | None = None
@@ -399,7 +413,8 @@ def __init__(
399
413
self .to_dict = ONVIFService .to_dict
400
414
401
415
self ._snapshot_uris = {}
402
- self ._snapshot_client = AsyncClient (verify = _NO_VERIFY_SSL_CONTEXT )
416
+ self ._snapshot_connector = TCPConnector (ssl = _NO_VERIFY_SSL_CONTEXT )
417
+ self ._snapshot_client = ClientSession (connector = self ._snapshot_connector )
403
418
404
419
async def get_capabilities (self ) -> dict [str , Any ]:
405
420
"""Get device capabilities."""
@@ -531,7 +546,8 @@ async def create_notification_manager(
531
546
532
547
async def close (self ) -> None :
533
548
"""Close all transports."""
534
- await self ._snapshot_client .aclose ()
549
+ await self ._snapshot_client .close ()
550
+ await self ._snapshot_connector .close ()
535
551
for service in self .services .values ():
536
552
await service .close ()
537
553
@@ -572,42 +588,53 @@ async def get_snapshot(
572
588
if uri is None :
573
589
return None
574
590
575
- auth = None
591
+ auth : BasicAuth | None = None
592
+ middlewares : tuple [DigestAuthMiddleware , ...] | None = None
593
+
576
594
if self .user and self .passwd :
577
595
if basic_auth :
578
596
auth = BasicAuth (self .user , self .passwd )
579
597
else :
580
- auth = DigestAuth (self .user , self .passwd )
598
+ # Use DigestAuthMiddleware for digest auth
599
+ middlewares = (DigestAuthMiddleware (self .user , self .passwd ),)
581
600
582
- response = await self ._try_snapshot_uri (uri , auth )
601
+ response = await self ._try_snapshot_uri (uri , auth = auth , middlewares = middlewares )
602
+ content = await response .read ()
583
603
584
- # If the request fails with a 401, make sure to strip any
585
- # sample user/pass from the URL and try again
604
+ # If the request fails with a 401, strip user/pass from URL and retry
586
605
if (
587
- response .status_code == 401
606
+ response .status == 401
588
607
and (stripped_uri := strip_user_pass_url (uri ))
589
608
and stripped_uri != uri
590
609
):
591
- response = await self ._try_snapshot_uri (stripped_uri , auth )
610
+ response = await self ._try_snapshot_uri (
611
+ stripped_uri , auth = auth , middlewares = middlewares
612
+ )
613
+ content = await response .read ()
592
614
593
- if response .status_code == 401 :
615
+ if response .status == 401 :
594
616
raise ONVIFAuthError (f"Failed to authenticate to { uri } " )
595
617
596
- if response .status_code < 300 :
597
- return response . content
618
+ if response .status < 300 :
619
+ return content
598
620
599
621
return None
600
622
601
623
async def _try_snapshot_uri (
602
- self , uri : str , auth : BasicAuth | DigestAuth | None
603
- ) -> httpx .Response :
624
+ self ,
625
+ uri : str ,
626
+ auth : BasicAuth | None = None ,
627
+ middlewares : tuple [DigestAuthMiddleware , ...] | None = None ,
628
+ ) -> aiohttp .ClientResponse :
604
629
try :
605
- return await self ._snapshot_client .get (uri , auth = auth )
606
- except httpx .TimeoutException as error :
630
+ return await self ._snapshot_client .get (
631
+ uri , auth = auth , middlewares = middlewares
632
+ )
633
+ except TimeoutError as error :
607
634
raise ONVIFTimeoutError (
608
635
f"Timed out fetching { obscure_user_pass_url (uri )} : { error } "
609
636
) from error
610
- except httpx . RequestError as error :
637
+ except aiohttp . ClientError as error :
611
638
raise ONVIFError (
612
639
f"Error fetching { obscure_user_pass_url (uri )} : { error } "
613
640
) from error
@@ -618,7 +645,7 @@ def get_definition(
618
645
"""Returns xaddr and wsdl of specified service"""
619
646
# Check if the service is supported
620
647
if name not in SERVICES :
621
- raise ONVIFError ("Unknown service %s" % name )
648
+ raise ONVIFError (f "Unknown service { name } " )
622
649
wsdl_file = SERVICES [name ]["wsdl" ]
623
650
namespace = SERVICES [name ]["ns" ]
624
651
@@ -629,14 +656,14 @@ def get_definition(
629
656
630
657
wsdlpath = os .path .join (self .wsdl_dir , wsdl_file )
631
658
if not path_isfile (wsdlpath ):
632
- raise ONVIFError ("No such file: %s" % wsdlpath )
659
+ raise ONVIFError (f "No such file: { wsdlpath } " )
633
660
634
661
# XAddr for devicemgmt is fixed:
635
662
if name == "devicemgmt" :
636
663
xaddr = "{}:{}/onvif/device_service" .format (
637
664
self .host
638
665
if (self .host .startswith ("http://" ) or self .host .startswith ("https://" ))
639
- else "http://%s" % self .host ,
666
+ else f "http://{ self .host } " ,
640
667
self .port ,
641
668
)
642
669
return xaddr , wsdlpath , binding_name
0 commit comments