@@ -13,10 +13,7 @@ def __init__(self, protocol=None):
def dumps(self, obj):
    """Serialize *obj* for storage in MongoDB.

    Plain ints are returned as-is (presumably so they are stored as native
    BSON integers rather than pickled blobs -- TODO confirm against the
    backend's incr/decr expectations); everything else is pickled with the
    configured protocol.

    Raises whatever ``pickle.dumps`` raises for unpicklable objects.
    """
    # Exact-type check, not isinstance: bool and other int subclasses must
    # go through pickle so they round-trip with their original type intact
    # (isinstance(True, int) is True, which would store True as a raw value).
    if type(obj) is int:
        return obj
    return pickle.dumps(obj, self.protocol)
2017
2118 def loads (self , data ):
2219 try :
@@ -60,19 +57,58 @@ class CacheEntry:
6057class MongoDBCache (BaseDatabaseCache ):
def __init__(self, *args, **options):
    """Initialize the backend and reconcile the backing collection's capped state.

    NOTE(review): capping is reconciled lazily here at init time; unclear
    whether the collection could instead be created capped up front -- confirm.
    """
    super().__init__(*args, **options)
    coll_options = self.collection.options()
    existing_names = set(self._db.database.list_collection_names())
    exists = self._collection_name in existing_names
    if exists and not coll_options.get("capped", False):
        # Collection exists but is not capped: convert it in place.
        # NOTE(review): convertToCapped's "size" is a byte count, while
        # _max_entries looks like an entry count -- verify the units match.
        self._db.database.command(
            "convertToCapped", self._collection_name, size=self._max_entries
        )
    elif exists and coll_options.get("size") != self._max_entries:
        # Capped size changed: rebuild via a temporary copy, then swap it in.
        # NOTE(review): the $out copy is uncapped, so the renamed collection
        # loses its capped property -- confirm this is intended.
        replacement = self._copy_collection()
        self.collection.drop()
        replacement.rename(self._collection_name)
    self.create_indexes()
def create_indexes(self):
    """Ensure the supporting indexes on the cache collection exist.

    A TTL index on ``expire_at`` (expireAfterSeconds=0 means each document
    expires at the instant stored in its own ``expire_at`` field) and a
    unique index on ``key``.
    """
    self.collection.create_index("expire_at", expireAfterSeconds=0)
    self.collection.create_index("key", unique=True)
77+
78+ def _copy_collection (self ):
79+ collection_name = self ._get_tmp_collection_name ()
80+ self .collection .aggregate ([{"$out" : collection_name }])
81+ return self ._db .get_collection (collection_name )
82+
83+ def _get_tmp_collection_name (self ):
84+ collections = set (self ._db .database .list_collection_names ())
85+ template_collection_name = "tmp__collection__{num}"
86+ num = 0
87+ while True :
88+ tmp_collection_name = template_collection_name .format (num = num )
89+ if tmp_collection_name not in collections :
90+ break
91+ num += 1
92+ return tmp_collection_name
6393
@cached_property
def serializer(self):
    """Serializer for cache values, built lazily once per backend instance."""
    return MongoSerializer()
6797
@property
def _db(self):
    """Database connection for this cache, as chosen by the router for reads.

    Resolved on every access (a plain property, not cached) so router
    decisions are honored per call.
    """
    alias = router.db_for_read(self.cache_model_class)
    return connections[alias]
101+
@property
def collection(self):
    """Handle to the MongoDB collection that stores this cache's entries."""
    return self._db.get_collection(self._collection_name)
72105
def get(self, key, default=None, version=None):
    """Return the deserialized value stored under *key*, or *default*.

    The key is validated and versioned before lookup. NOTE(review): expired
    documents are removed lazily by MongoDB's TTL monitor, so a stale entry
    may still be returned briefly -- confirm callers tolerate this.
    """
    key = self.make_and_validate_key(key, version=version)
    document = self.collection.find_one({"key": key})
    if document is None:
        return default
    return self.serializer.loads(document["value"])
76112
77113 def get_many (self , keys , version = None ):
78114 if not keys :
@@ -86,8 +122,14 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
86122 serialized_data = self .serializer .dumps (value )
87123 return self .collection .update_one (
88124 {"key" : key },
89- {"key" : key , "value" : serialized_data , "expire_at" : self ._get_expiration_time (timeout )},
90- {"upsert" : True },
125+ {
126+ "$set" : {
127+ "key" : key ,
128+ "value" : serialized_data ,
129+ "expire_at" : self ._get_expiration_time (timeout ),
130+ }
131+ },
132+ True ,
91133 )
92134
93135 def add (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
@@ -130,7 +172,7 @@ def delete_many(self, keys, version=None):
130172
def has_key(self, key, version=None):
    """Return True if a document exists for the validated, versioned key."""
    full_key = self.make_and_validate_key(key, version=version)
    return bool(self.collection.count_documents({"key": full_key}))
134176
def clear(self):
    """Delete every cache entry from the backing collection."""
    self.collection.delete_many({})
0 commit comments