8
8
import os .path
9
9
from typing import Any
10
10
from collections .abc import Callable
11
-
11
+ from yarl import URL
12
+ from multidict import CIMultiDict
12
13
import httpx
13
14
from httpx import AsyncClient , BasicAuth , DigestAuth
14
15
from zeep .cache import SqliteCache
44
45
_WRITE_TIMEOUT = 90
45
46
_HTTPX_LIMITS = httpx .Limits (keepalive_expiry = KEEPALIVE_EXPIRY )
46
47
_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context ()
48
+ _CREDENTIAL_KEYS = ("username" , "user" , "pass" , "password" )
49
+
50
+
51
+ def strip_user_pass_url (url : str ) -> str :
52
+ """Strip password from URL."""
53
+ parsed_url = URL (url )
54
+ query = parsed_url .query
55
+ new_query : CIMultiDict | None = None
56
+ for key in _CREDENTIAL_KEYS :
57
+ if key in query :
58
+ if new_query is None :
59
+ new_query = CIMultiDict (parsed_url .query )
60
+ new_query .popall (key )
61
+ parsed_url = parsed_url .with_query (new_query )
62
+ return str (parsed_url )
63
+
64
+
65
+ def obscure_user_pass_url (url : str ) -> str :
66
+ """Obscure user and password from URL."""
67
+ parsed_url = URL (url )
68
+ query = parsed_url .query
69
+ new_query : CIMultiDict | None = None
70
+ for key in _CREDENTIAL_KEYS :
71
+ if key in query :
72
+ if new_query is None :
73
+ new_query = CIMultiDict (parsed_url .query )
74
+ new_query .popall (key )
75
+ new_query [key ] = "********"
76
+ parsed_url = parsed_url .with_query (new_query )
77
+ return str (parsed_url )
47
78
48
79
49
80
def safe_func (func ):
@@ -573,12 +604,16 @@ async def get_snapshot(
573
604
else :
574
605
auth = DigestAuth (self .user , self .passwd )
575
606
576
- try :
577
- response = await self ._snapshot_client .get (uri , auth = auth )
578
- except httpx .TimeoutException as error :
579
- raise ONVIFTimeoutError (f"Timed out fetching { uri } : { error } " ) from error
580
- except httpx .RequestError as error :
581
- raise ONVIFError (f"Error fetching { uri } : { error } " ) from error
607
+ response = await self ._try_snapshot_uri (uri , auth )
608
+
609
+ # If the request fails with a 401, make sure to strip any
610
+ # sample user/pass from the URL and try again
611
+ if (
612
+ response .status_code == 401
613
+ and (stripped_uri := strip_user_pass_url (uri ))
614
+ and stripped_uri != uri
615
+ ):
616
+ response = await self ._try_snapshot_uri (stripped_uri , auth )
582
617
583
618
if response .status_code == 401 :
584
619
raise ONVIFAuthError (f"Failed to authenticate to { uri } " )
@@ -588,6 +623,20 @@ async def get_snapshot(
588
623
589
624
return None
590
625
626
+ async def _try_snapshot_uri (
627
+ self , uri : str , auth : BasicAuth | DigestAuth | None
628
+ ) -> httpx .Response :
629
+ try :
630
+ return await self ._snapshot_client .get (uri , auth = auth )
631
+ except httpx .TimeoutException as error :
632
+ raise ONVIFTimeoutError (
633
+ f"Timed out fetching { obscure_user_pass_url (uri )} : { error } "
634
+ ) from error
635
+ except httpx .RequestError as error :
636
+ raise ONVIFError (
637
+ f"Error fetching { obscure_user_pass_url (uri )} : { error } "
638
+ ) from error
639
+
591
640
def get_definition (
592
641
self , name : str , port_type : str | None = None
593
642
) -> tuple [str , str , str ]:
0 commit comments