11__all__ = [
22 "BaseConnection" ,
33 "BasicConnection" ,
4+ "JwtConnection" ,
5+ "JwtSuperuserConnection" ,
46]
57
68import json
911
1012import jwt
1113
14+ from arangoasync import errno , logger
1215from arangoasync .auth import Auth , JwtToken
1316from arangoasync .compression import CompressionManager , DefaultCompressionManager
1417from arangoasync .exceptions import (
@@ -55,25 +58,45 @@ def db_name(self) -> str:
5558 """Return the database name."""
5659 return self ._db_name
5760
58- def prep_response (self , request : Request , resp : Response ) -> Response :
59- """Prepare response for return.
61+ @staticmethod
62+ def raise_for_status (request : Request , resp : Response ) -> None :
63+ """Raise an exception based on the response.
6064
6165 Args:
6266 request (Request): Request object.
6367 resp (Response): Response object.
6468
65- Returns:
66- Response: Response object
67-
6869 Raises:
6970 ServerConnectionError: If the response status code is not successful.
7071 """
71- # TODO needs refactoring such that it does not throw
72- resp .is_success = 200 <= resp .status_code < 300
7372 if resp .status_code in {401 , 403 }:
7473 raise ServerConnectionError (resp , request , "Authentication failed." )
7574 if not resp .is_success :
7675 raise ServerConnectionError (resp , request , "Bad server response." )
76+
77+ @staticmethod
78+ def prep_response (request : Request , resp : Response ) -> Response :
79+ """Prepare response for return.
80+
81+ Args:
82+ request (Request): Request object.
83+ resp (Response): Response object.
84+
85+ Returns:
86+ Response: Response object
87+ """
88+ resp .is_success = 200 <= resp .status_code < 300
89+ if not resp .is_success :
90+ try :
91+ body = json .loads (resp .raw_body )
92+ except json .JSONDecodeError as e :
93+ logger .debug (
94+ f"Failed to decode response body: { e } (from request { request } )"
95+ )
96+ else :
97+ if body .get ("error" ) is True :
98+ resp .error_code = body .get ("errorNum" )
99+ resp .error_message = body .get ("errorMessage" )
77100 return resp
78101
79102 async def process_request (self , request : Request ) -> Response :
@@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response:
86109 Response: Response object.
87110
88111 Raises:
89- ConnectionAbortedError: If can't connect to host(s) within limit.
112+ ConnectionAbortedError: If it can't connect to host(s) within limit.
90113 """
91114
92115 host_index = self ._host_resolver .get_host_index ()
@@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response:
100123 ex_host_index = host_index
101124 host_index = self ._host_resolver .get_host_index ()
102125 if ex_host_index == host_index :
126+ # Force change host if the same host is selected
103127 self ._host_resolver .change_host ()
104128 host_index = self ._host_resolver .get_host_index ()
105129
@@ -117,8 +141,8 @@ async def ping(self) -> int:
117141 ServerConnectionError: If the response status code is not successful.
118142 """
119143 request = Request (method = Method .GET , endpoint = "/_api/collection" )
120- request .headers = {"abde" : "fghi" }
121144 resp = await self .send_request (request )
145+ self .raise_for_status (request , resp )
122146 return resp .status_code
123147
124148 @abstractmethod
@@ -257,15 +281,15 @@ async def refresh_token(self) -> None:
257281 if self ._auth is None :
258282 raise JWTRefreshError ("Auth must be provided to refresh the token." )
259283
260- data = json .dumps (
284+ auth_data = json .dumps (
261285 dict (username = self ._auth .username , password = self ._auth .password ),
262286 separators = ("," , ":" ),
263287 ensure_ascii = False ,
264288 )
265289 request = Request (
266290 method = Method .POST ,
267291 endpoint = "/_open/auth" ,
268- data = data .encode ("utf-8" ),
292+ data = auth_data .encode ("utf-8" ),
269293 )
270294
271295 try :
@@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response:
310334
311335 request .headers ["authorization" ] = self ._auth_header
312336
313- try :
314- resp = await self .process_request (request )
315- if (
316- resp .status_code == 401 # Unauthorized
317- and self ._token is not None
318- and self ._token .needs_refresh (self ._expire_leeway )
319- ):
320- await self .refresh_token ()
321- return await self .process_request (request ) # Retry with new token
322- except ServerConnectionError :
323- # TODO modify after refactoring of prep_response, so we can inspect response
337+ resp = await self .process_request (request )
338+ if (
339+ resp .status_code == errno .HTTP_UNAUTHORIZED
340+ and self ._token is not None
341+ and self ._token .needs_refresh (self ._expire_leeway )
342+ ):
343+ # If the token has expired, refresh it and retry the request
324344 await self .refresh_token ()
325- return await self .process_request (request ) # Retry with new token
345+ resp = await self .process_request (request )
346+ self .raise_for_status (request , resp )
347+ return resp
348+
349+
350+ class JwtSuperuserConnection (BaseConnection ):
351+ """Connection to a specific ArangoDB database, using superuser JWT.
352+
353+ The JWT token is not refreshed and (username and password) are not required.
354+
355+ Args:
356+ sessions (list): List of client sessions.
357+ host_resolver (HostResolver): Host resolver.
358+ http_client (HTTPClient): HTTP client.
359+ db_name (str): Database name.
360+ compression (CompressionManager | None): Compression manager.
361+ token (JwtToken | None): JWT token.
362+ """
363+
364+ def __init__ (
365+ self ,
366+ sessions : List [Any ],
367+ host_resolver : HostResolver ,
368+ http_client : HTTPClient ,
369+ db_name : str ,
370+ compression : Optional [CompressionManager ] = None ,
371+ token : Optional [JwtToken ] = None ,
372+ ) -> None :
373+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
374+ self ._expire_leeway : int = 0
375+ self ._token : Optional [JwtToken ] = None
376+ self ._auth_header : Optional [str ] = None
377+ self .token = token
378+
379+ @property
380+ def token (self ) -> Optional [JwtToken ]:
381+ """Get the JWT token.
382+
383+ Returns:
384+ JwtToken | None: JWT token.
385+ """
386+ return self ._token
387+
388+ @token .setter
389+ def token (self , token : Optional [JwtToken ]) -> None :
390+ """Set the JWT token.
391+
392+ Args:
393+ token (JwtToken | None): JWT token.
394+ Setting it to None will cause the token to be automatically
395+ refreshed on the next request, if auth information is provided.
396+ """
397+ self ._token = token
398+ self ._auth_header = f"bearer { self ._token .token } " if self ._token else None
399+
400+ async def send_request (self , request : Request ) -> Response :
401+ """Send an HTTP request to the ArangoDB server.
402+
403+ Args:
404+ request (Request): HTTP request.
405+
406+ Returns:
407+ Response: HTTP response
408+
409+ Raises:
410+ ArangoClientError: If an error occurred from the client side.
411+ ArangoServerError: If an error occurred from the server side.
412+ """
413+ if self ._auth_header is None :
414+ raise AuthHeaderError ("Failed to generate authorization header." )
415+ request .headers ["authorization" ] = self ._auth_header
416+
417+ resp = await self .process_request (request )
418+ self .raise_for_status (request , resp )
419+ return resp
0 commit comments