66 "JwtSuperuserConnection" ,
77]
88
9- import json
109from abc import ABC , abstractmethod
10+ from json import JSONDecodeError
1111from typing import Any , List , Optional
1212
13- import jwt
13+ from jwt import ExpiredSignatureError
1414
1515from arangoasync import errno , logger
1616from arangoasync .auth import Auth , JwtToken
2626from arangoasync .request import Method , Request
2727from arangoasync .resolver import HostResolver
2828from arangoasync .response import Response
29+ from arangoasync .serialization import (
30+ DefaultDeserializer ,
31+ DefaultSerializer ,
32+ Deserializer ,
33+ Serializer ,
34+ )
2935
3036
3137class BaseConnection (ABC ):
@@ -37,6 +43,10 @@ class BaseConnection(ABC):
3743 http_client (HTTPClient): HTTP client.
3844 db_name (str): Database name.
3945 compression (CompressionManager | None): Compression manager.
46+ serializer (Serializer | None): For custom serialization.
47+ Leave `None` for default.
48+ deserializer (Deserializer | None): For custom deserialization.
49+ Leave `None` for default.
4050 """
4151
4252 def __init__ (
@@ -46,19 +56,33 @@ def __init__(
4656 http_client : HTTPClient ,
4757 db_name : str ,
4858 compression : Optional [CompressionManager ] = None ,
59+ serializer : Optional [Serializer ] = None ,
60+ deserializer : Optional [Deserializer ] = None ,
4961 ) -> None :
5062 self ._sessions = sessions
5163 self ._db_endpoint = f"/_db/{ db_name } "
5264 self ._host_resolver = host_resolver
5365 self ._http_client = http_client
5466 self ._db_name = db_name
5567 self ._compression = compression
68+ self ._serializer = serializer or DefaultSerializer ()
69+ self ._deserializer = deserializer or DefaultDeserializer ()
5670
5771 @property
5872 def db_name (self ) -> str :
5973 """Return the database name."""
6074 return self ._db_name
6175
76+ @property
77+ def serializer (self ) -> Serializer :
78+ """Return the serializer."""
79+ return self ._serializer
80+
81+ @property
82+ def deserializer (self ) -> Deserializer :
83+ """Return the deserializer."""
84+ return self ._deserializer
85+
6286 @staticmethod
6387 def raise_for_status (request : Request , resp : Response ) -> None :
6488 """Raise an exception based on the response.
@@ -75,8 +99,7 @@ def raise_for_status(request: Request, resp: Response) -> None:
7599 if not resp .is_success :
76100 raise ServerConnectionError (resp , request , "Bad server response." )
77101
78- @staticmethod
79- def prep_response (request : Request , resp : Response ) -> Response :
102+ def prep_response (self , request : Request , resp : Response ) -> Response :
80103 """Prepare response for return.
81104
82105 Args:
@@ -89,8 +112,8 @@ def prep_response(request: Request, resp: Response) -> Response:
89112 resp .is_success = 200 <= resp .status_code < 300
90113 if not resp .is_success :
91114 try :
92- body = json . loads (resp .raw_body )
93- except json . JSONDecodeError as e :
115+ body = self . _deserializer . from_bytes (resp .raw_body )
116+ except JSONDecodeError as e :
94117 logger .debug (
95118 f"Failed to decode response body: { e } (from request { request } )"
96119 )
@@ -202,6 +225,8 @@ class BasicConnection(BaseConnection):
202225 http_client (HTTPClient): HTTP client.
203226 db_name (str): Database name.
204227 compression (CompressionManager | None): Compression manager.
228+ serializer (Serializer | None): For custom serialization.
229+ deserializer (Deserializer | None): For custom deserialization.
205230 auth (Auth | None): Authentication information.
206231 """
207232
@@ -212,9 +237,19 @@ def __init__(
212237 http_client : HTTPClient ,
213238 db_name : str ,
214239 compression : Optional [CompressionManager ] = None ,
240+ serializer : Optional [Serializer ] = None ,
241+ deserializer : Optional [Deserializer ] = None ,
215242 auth : Optional [Auth ] = None ,
216243 ) -> None :
217- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
244+ super ().__init__ (
245+ sessions ,
246+ host_resolver ,
247+ http_client ,
248+ db_name ,
249+ compression ,
250+ serializer ,
251+ deserializer ,
252+ )
218253 self ._auth = auth
219254
220255 async def send_request (self , request : Request ) -> Response :
@@ -249,6 +284,8 @@ class JwtConnection(BaseConnection):
249284 http_client (HTTPClient): HTTP client.
250285 db_name (str): Database name.
251286 compression (CompressionManager | None): Compression manager.
287+ serializer (Serializer | None): For custom serialization.
288+ deserializer (Deserializer | None): For custom deserialization.
252289 auth (Auth | None): Authentication information.
253290 token (JwtToken | None): JWT token.
254291
@@ -263,10 +300,20 @@ def __init__(
263300 http_client : HTTPClient ,
264301 db_name : str ,
265302 compression : Optional [CompressionManager ] = None ,
303+ serializer : Optional [Serializer ] = None ,
304+ deserializer : Optional [Deserializer ] = None ,
266305 auth : Optional [Auth ] = None ,
267306 token : Optional [JwtToken ] = None ,
268307 ) -> None :
269- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
308+ super ().__init__ (
309+ sessions ,
310+ host_resolver ,
311+ http_client ,
312+ db_name ,
313+ compression ,
314+ serializer ,
315+ deserializer ,
316+ )
270317 self ._auth = auth
271318 self ._expire_leeway : int = 0
272319 self ._token : Optional [JwtToken ] = token
@@ -306,10 +353,8 @@ async def refresh_token(self) -> None:
306353 if self ._auth is None :
307354 raise JWTRefreshError ("Auth must be provided to refresh the token." )
308355
309- auth_data = json . dumps (
356+ auth_data = self . _serializer . to_str (
310357 dict (username = self ._auth .username , password = self ._auth .password ),
311- separators = ("," , ":" ),
312- ensure_ascii = False ,
313358 )
314359 request = Request (
315360 method = Method .POST ,
@@ -330,10 +375,10 @@ async def refresh_token(self) -> None:
330375 f"{ resp .status_code } { resp .status_text } "
331376 )
332377
333- token = json . loads (resp .raw_body )
378+ token = self . _deserializer . from_bytes (resp .raw_body )
334379 try :
335380 self .token = JwtToken (token ["jwt" ])
336- except jwt . ExpiredSignatureError as e :
381+ except ExpiredSignatureError as e :
337382 raise JWTRefreshError (
338383 "Failed to refresh the JWT token: got an expired token"
339384 ) from e
@@ -385,6 +430,8 @@ class JwtSuperuserConnection(BaseConnection):
385430 http_client (HTTPClient): HTTP client.
386431 db_name (str): Database name.
387432 compression (CompressionManager | None): Compression manager.
433+ serializer (Serializer | None): For custom serialization.
434+ deserializer (Deserializer | None): For custom deserialization.
388435 token (JwtToken | None): JWT token.
389436 """
390437
@@ -395,10 +442,19 @@ def __init__(
395442 http_client : HTTPClient ,
396443 db_name : str ,
397444 compression : Optional [CompressionManager ] = None ,
445+ serializer : Optional [Serializer ] = None ,
446+ deserializer : Optional [Deserializer ] = None ,
398447 token : Optional [JwtToken ] = None ,
399448 ) -> None :
400- super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
401- self ._expire_leeway : int = 0
449+ super ().__init__ (
450+ sessions ,
451+ host_resolver ,
452+ http_client ,
453+ db_name ,
454+ compression ,
455+ serializer ,
456+ deserializer ,
457+ )
402458 self ._token : Optional [JwtToken ] = token
403459 self ._auth_header : Optional [str ] = None
404460 self .token = self ._token
0 commit comments