1- from typing import Optional , Set , Any , Dict
1+ from typing import Optional , Set
22import logging
33import json
44from datetime import datetime
5+ from base64 import b64encode , b64decode
56
67import boto3
78from boto3 .dynamodb .conditions import Key , Attr
@@ -117,48 +118,9 @@ def index_definition(index_name, keys, gsi=False):
117118 return schema
118119
119120
120- class DynamoIterableResult (IterableResult ):
121- def __init__ (self , cls , result , serialized_items ):
122- super (DynamoIterableResult , self ).__init__ (cls , serialized_items , result .get ("Count" ))
123-
124- self .last_evaluated_key = None
125- lsk = result .get ("LastEvaluatedKey" )
126- if lsk :
127- _key = [lsk [cls .Config .hash_key ]]
128- if cls .Config .range_key :
129- _key .append (lsk [cls .Config .range_key ])
130- self .last_evaluated_key = tuple (_key )
131-
132- self .scanned_count = result ["ScannedCount" ]
133-
134-
135- class Backend :
136- def __init__ (self , cls ):
137- cfg = cls .Config
138- self .cls = cls
139- self .schema = cls .schema ()
140- self .hash_key = cfg .hash_key
141- self .range_key = getattr (cfg , 'range_key' , None )
142- self .table_name = cls .get_table_name ()
143-
144- self .local_indexes = getattr (cfg , "local_indexes" , {})
145- self .global_indexes = getattr (cfg , "global_indexes" , {})
146- self .index_map = {(self .hash_key ,): None }
147- self .possible_keys = {self .hash_key }
148- if self .range_key :
149- self .possible_keys .add (self .range_key )
150- self .index_map = {(self .hash_key , self .range_key ): None }
151-
152- for name , keys in dict (** self .local_indexes , ** self .global_indexes ).items ():
153- self .index_map [keys ] = name
154- for key in keys :
155- self .possible_keys .add (key )
156-
157- self .dynamodb = boto3 .resource (
158- "dynamodb" ,
159- region_name = getattr (cfg , "region" , "us-east-2" ),
160- endpoint_url = getattr (cfg , "endpoint" , None ),
161- )
121+ class DynamoSerializer :
122+ def __init__ (self , schema ):
123+ self .schema = schema
162124
163125 def _serialize_field (self , field_name , value ):
164126 definition = self .schema .get ("definitions" )
@@ -175,7 +137,7 @@ def _serialize_field(self, field_name, value):
175137 log .debug (f"No serializer for field_type { field_type } " )
176138 return value # do nothing but log it.
177139
178- def _serialize_record (self , data_dict ) -> dict :
140+ def serialize_record (self , data_dict ) -> dict :
179141 """
180142 Apply converters to non-native types
181143 """
@@ -198,7 +160,7 @@ def _deserialize_field(self, field_name, value):
198160 log .debug (f"No deserializer for field_type { field_type } " )
199161 return value # do nothing but log it.
200162
201- def _deserialize_record (self , data_dict ) -> dict :
163+ def deserialize_record (self , data_dict ) -> dict :
202164 """
203165 Apply converters to non-native types
204166 """
@@ -207,6 +169,44 @@ def _deserialize_record(self, data_dict) -> dict:
207169 for field_name , value in data_dict .items ()
208170 }
209171
172+
173+ class DynamoIterableResult (IterableResult ):
174+ def __init__ (self , cls , result , serialized_items ):
175+ super (DynamoIterableResult , self ).__init__ (cls , serialized_items , result .get ("Count" ))
176+
177+ self .last_evaluated_key = result .get ("LastEvaluatedKey" )
178+ self .scanned_count = result ["ScannedCount" ]
179+
180+
181+ class Backend :
182+ def __init__ (self , cls ):
183+ cfg = cls .Config
184+ self .cls = cls
185+ self .schema = cls .schema ()
186+ self .serializer = DynamoSerializer (self .schema )
187+ self .hash_key = cfg .hash_key
188+ self .range_key = getattr (cfg , 'range_key' , None )
189+ self .table_name = cls .get_table_name ()
190+
191+ self .local_indexes = getattr (cfg , "local_indexes" , {})
192+ self .global_indexes = getattr (cfg , "global_indexes" , {})
193+ self .index_map = {(self .hash_key ,): None }
194+ self .possible_keys = {self .hash_key }
195+ if self .range_key :
196+ self .possible_keys .add (self .range_key )
197+ self .index_map = {(self .hash_key , self .range_key ): None }
198+
199+ for name , keys in dict (** self .local_indexes , ** self .global_indexes ).items ():
200+ self .index_map [keys ] = name
201+ for key in keys :
202+ self .possible_keys .add (key )
203+
204+ self .dynamodb = boto3 .resource (
205+ "dynamodb" ,
206+ region_name = getattr (cfg , "region" , "us-east-2" ),
207+ endpoint_url = getattr (cfg , "endpoint" , None ),
208+ )
209+
210210 def _key_param_to_dict (self , key ):
211211 _key = {
212212 self .hash_key : key ,
@@ -295,7 +295,7 @@ def query(self,
295295 query_expr : Optional [Rule ] = None ,
296296 filter_expr : Optional [Rule ] = None ,
297297 limit : Optional [int ] = None ,
298- exclusive_start_key : Optional [tuple [ Any ] ] = None ,
298+ exclusive_start_key : Optional [str ] = None ,
299299 order : str = 'asc' ,
300300 ):
301301 table = self .get_table ()
@@ -306,7 +306,7 @@ def query(self,
306306 if limit :
307307 params ["Limit" ] = limit
308308 if exclusive_start_key :
309- params ["ExclusiveStartKey" ] = self . _key_param_to_dict ( exclusive_start_key )
309+ params ["ExclusiveStartKey" ] = exclusive_start_key
310310 if f_expr :
311311 params ["FilterExpression" ] = f_expr
312312
@@ -347,7 +347,7 @@ def query(self,
347347 return []
348348 raise e
349349
350- return DynamoIterableResult (self .cls , resp , (self ._deserialize_record (rec ) for rec in resp ["Items" ]))
350+ return DynamoIterableResult (self .cls , resp , (self .serializer . deserialize_record (rec ) for rec in resp ["Items" ]))
351351
352352 def get (self , key ):
353353 _key = self ._key_param_to_dict (key )
@@ -363,10 +363,10 @@ def get(self, key):
363363 _key = key
364364 raise DoesNotExist (f'{ self .table_name } "{ _key } " does not exist' )
365365
366- return self ._deserialize_record (resp ["Item" ])
366+ return self .serializer . deserialize_record (resp ["Item" ])
367367
368368 def save (self , item , condition : Optional [Rule ] = None ) -> bool :
369- data = self ._serialize_record (item .dict (by_alias = True ))
369+ data = self .serializer . serialize_record (item .dict (by_alias = True ))
370370
371371 try :
372372 if condition :
0 commit comments