33 "BasicConnection" ,
44]
55
6+ import json
67from abc import ABC , abstractmethod
78from typing import Any , List , Optional
89
9- from arangoasync .auth import Auth
10+ import jwt
11+
12+ from arangoasync .auth import Auth , JwtToken
1013from arangoasync .compression import CompressionManager , DefaultCompressionManager
1114from arangoasync .exceptions import (
15+ AuthHeaderError ,
16+ ClientConnectionAbortedError ,
1217 ClientConnectionError ,
13- ConnectionAbortedError ,
18+ JWTRefreshError ,
1419 ServerConnectionError ,
1520)
1621from arangoasync .http import HTTPClient
@@ -63,6 +68,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
6368 Raises:
6469 ServerConnectionError: If the response status code is not successful.
6570 """
71+ # TODO needs refactoring such that it does not throw
6672 resp .is_success = 200 <= resp .status_code < 300
6773 if resp .status_code in {401 , 403 }:
6874 raise ServerConnectionError (resp , request , "Authentication failed." )
@@ -97,7 +103,7 @@ async def process_request(self, request: Request) -> Response:
97103 self ._host_resolver .change_host ()
98104 host_index = self ._host_resolver .get_host_index ()
99105
100- raise ConnectionAbortedError (
106+ raise ClientConnectionAbortedError (
101107 f"Can't connect to host(s) within limit ({ self ._host_resolver .max_tries } )"
102108 )
103109
@@ -111,6 +117,7 @@ async def ping(self) -> int:
111117 ServerConnectionError: If the response status code is not successful.
112118 """
113119 request = Request (method = Method .GET , endpoint = "/_api/collection" )
120+ request .headers = {"abde" : "fghi" }
114121 resp = await self .send_request (request )
115122 return resp .status_code
116123
@@ -154,7 +161,18 @@ def __init__(
154161 self ._auth = auth
155162
156163 async def send_request (self , request : Request ) -> Response :
157- """Send an HTTP request to the ArangoDB server."""
164+ """Send an HTTP request to the ArangoDB server.
165+
166+ Args:
167+ request (Request): HTTP request.
168+
169+ Returns:
170+ Response: HTTP response
171+
172+ Raises:
173+ ArangoClientError: If an error occurred from the client side.
174+ ArangoServerError: If an error occurred from the server side.
175+ """
158176 if request .data is not None and self ._compression .needs_compression (
159177 request .data
160178 ):
@@ -169,3 +187,139 @@ async def send_request(self, request: Request) -> Response:
169187 request .auth = self ._auth
170188
171189 return await self .process_request (request )
190+
191+
192+ class JwtConnection (BaseConnection ):
193+ """Connection to a specific ArangoDB database, using JWT authentication.
194+
195+ Providing login information (username and password), allows to refresh the JWT.
196+
197+ Args:
198+ sessions (list): List of client sessions.
199+ host_resolver (HostResolver): Host resolver.
200+ http_client (HTTPClient): HTTP client.
201+ db_name (str): Database name.
202+ compression (CompressionManager | None): Compression manager.
203+ auth (Auth | None): Authentication information.
204+ token (JwtToken | None): JWT token.
205+
206+ Raises:
207+ ValueError: If neither token nor auth is provided.
208+ """
209+
210+ def __init__ (
211+ self ,
212+ sessions : List [Any ],
213+ host_resolver : HostResolver ,
214+ http_client : HTTPClient ,
215+ db_name : str ,
216+ compression : Optional [CompressionManager ] = None ,
217+ auth : Optional [Auth ] = None ,
218+ token : Optional [JwtToken ] = None ,
219+ ) -> None :
220+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
221+ self ._auth = auth
222+ self ._expire_leeway : int = 0
223+ self ._token : Optional [JwtToken ] = None
224+ self ._auth_header : Optional [str ] = None
225+ self .token = token
226+
227+ if self ._token is None and self ._auth is None :
228+ raise ValueError ("Either token or auth must be provided." )
229+
230+ @property
231+ def token (self ) -> Optional [JwtToken ]:
232+ """Get the JWT token.
233+
234+ Returns:
235+ JwtToken | None: JWT token.
236+ """
237+ return self ._token
238+
239+ @token .setter
240+ def token (self , token : Optional [JwtToken ]) -> None :
241+ """Set the JWT token.
242+
243+ Args:
244+ token (JwtToken | None): JWT token.
245+ Setting it to None will cause the token to be automatically
246+ refreshed on the next request, if auth information is provided.
247+ """
248+ self ._token = token
249+ self ._auth_header = f"bearer { self ._token .token } " if self ._token else None
250+
251+ async def refresh_token (self ) -> None :
252+ """Refresh the JWT token.
253+
254+ Raises:
255+ JWTRefreshError: If the token can't be refreshed.
256+ """
257+ if self ._auth is None :
258+ raise JWTRefreshError ("Auth must be provided to refresh the token." )
259+
260+ data = json .dumps (
261+ dict (username = self ._auth .username , password = self ._auth .password ),
262+ separators = ("," , ":" ),
263+ ensure_ascii = False ,
264+ )
265+ request = Request (
266+ method = Method .POST ,
267+ endpoint = "/_open/auth" ,
268+ data = data .encode ("utf-8" ),
269+ )
270+
271+ try :
272+ resp = await self .process_request (request )
273+ except ClientConnectionAbortedError as e :
274+ raise JWTRefreshError (str (e )) from e
275+ except ServerConnectionError as e :
276+ raise JWTRefreshError (str (e )) from e
277+
278+ if not resp .is_success :
279+ raise JWTRefreshError (
280+ f"Failed to refresh the JWT token: "
281+ f"{ resp .status_code } { resp .status_text } "
282+ )
283+
284+ token = json .loads (resp .raw_body )
285+ try :
286+ self .token = JwtToken (token ["jwt" ])
287+ except jwt .ExpiredSignatureError as e :
288+ raise JWTRefreshError (
289+ "Failed to refresh the JWT token: got an expired token"
290+ ) from e
291+
292+ async def send_request (self , request : Request ) -> Response :
293+ """Send an HTTP request to the ArangoDB server.
294+
295+ Args:
296+ request (Request): HTTP request.
297+
298+ Returns:
299+ Response: HTTP response
300+
301+ Raises:
302+ ArangoClientError: If an error occurred from the client side.
303+ ArangoServerError: If an error occurred from the server side.
304+ """
305+ if self ._auth_header is None :
306+ await self .refresh_token ()
307+
308+ if self ._auth_header is None :
309+ raise AuthHeaderError ("Failed to generate authorization header." )
310+
311+ request .headers ["authorization" ] = self ._auth_header
312+
313+ try :
314+ resp = await self .process_request (request )
315+ if (
316+ resp .status_code == 401 # Unauthorized
317+ and self ._token is not None
318+ and self ._token .needs_refresh (self ._expire_leeway )
319+ ):
320+ await self .refresh_token ()
321+ return await self .process_request (request ) # Retry with new token
322+ except ServerConnectionError :
323+ # TODO modify after refactoring of prep_response, so we can inspect response
324+ await self .refresh_token ()
325+ return await self .process_request (request ) # Retry with new token
0 commit comments