11import pickle
22from datetime import datetime , timezone
33
4+ from bson import SON
45from django .core .cache .backends .base import DEFAULT_TIMEOUT , BaseCache
56from django .db import connections , router
67from django .utils .functional import cached_property
8+ from pymongo .errors import DuplicateKeyError
79
810
911class MongoSerializer :
@@ -58,39 +60,11 @@ class MongoDBCache(BaseDatabaseCache):
5860 def __init__ (self , * args , ** options ):
5961 super ().__init__ (* args , ** options )
6062 # don't know If I can set the capped collection here.
61- coll_info = self .collection .options ()
62- collections = set (self ._db .database .list_collection_names ())
63- coll_exists = self ._collection_name in collections
64- if coll_exists and not coll_info .get ("capped" , False ):
65- self ._db .database .command (
66- "convertToCapped" , self ._collection_name , size = self ._max_entries
67- )
68- elif coll_exists and coll_info .get ("size" ) != self ._max_entries :
69- new_coll = self ._copy_collection ()
70- self .collection .drop ()
71- new_coll .rename (self ._collection_name )
72- self .create_indexes ()
7363
7464 def create_indexes (self ):
75- self .collection .create_index ("expire_at " , expireAfterSeconds = 0 )
65+ self .collection .create_index ("expires_at " , expireAfterSeconds = 0 )
7666 self .collection .create_index ("key" , unique = True )
7767
78- def _copy_collection (self ):
79- collection_name = self ._get_tmp_collection_name ()
80- self .collection .aggregate ([{"$out" : collection_name }])
81- return self ._db .get_collection (collection_name )
82-
83- def _get_tmp_collection_name (self ):
84- collections = set (self ._db .database .list_collection_names ())
85- template_collection_name = "tmp__collection__{num}"
86- num = 0
87- while True :
88- tmp_collection_name = template_collection_name .format (num = num )
89- if tmp_collection_name not in collections :
90- break
91- num += 1
92- return tmp_collection_name
93-
9468 @cached_property
9569 def serializer (self ):
9670 return MongoSerializer ()
@@ -104,29 +78,38 @@ def collection(self):
10478 return self ._db .get_collection (self ._collection_name )
10579
10680 def get (self , key , default = None , version = None ):
107- key = self .make_and_validate_key (key , version = version )
108- result = self .collection .find_one ({"key" : key })
109- if result is not None :
110- return self .serializer .loads (result ["value" ])
81+ result = self .get_many ([key ], version )
82+ if result :
83+ return result [key ]
11184 return default
11285
86+ def _filter_expired (self , expired = False ):
87+ not_expired_filter = [{"expires_at" : {"$gte" : datetime .utcnow ()}}, {"expires_at" : None }]
88+ operator = "$nor" if expired else "$or"
89+ return {operator : not_expired_filter }
90+
11391 def get_many (self , keys , version = None ):
11492 if not keys :
11593 return {}
11694 keys_map = {self .make_and_validate_key (key , version = version ): key for key in keys }
117- with self .collection .find ({"key" : {"$in" : tuple (keys_map )}}) as cursor :
118- return {keys_map [row ["key" ]]: row ["value" ] for row in cursor }
95+ with self .collection .find (
96+ {"key" : {"$in" : tuple (keys_map )}, ** self ._filter_expired (expired = False )}
97+ ) as cursor :
98+ return {keys_map [row ["key" ]]: self .serializer .loads (row ["value" ]) for row in cursor }
11999
120100 def set (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
121101 key = self .make_and_validate_key (key , version = version )
122102 serialized_data = self .serializer .dumps (value )
103+ num = self .collection .count_documents ({})
104+ if num > self ._max_entries :
105+ self ._cull (num )
123106 return self .collection .update_one (
124107 {"key" : key },
125108 {
126109 "$set" : {
127110 "key" : key ,
128111 "value" : serialized_data ,
129- "expire_at " : self ._get_expiration_time (timeout ),
112+ "expires_at " : self ._get_expiration_time (timeout ),
130113 }
131114 },
132115 True ,
@@ -135,24 +118,54 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
135118 def add (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
136119 key = self .make_and_validate_key (key , version = version )
137120 serialized_data = self .serializer .dumps (value )
121+ num = self .collection .count_documents ({})
122+ if num > self ._max_entries :
123+ self ._cull (num )
138124 try :
139- self .collection .insert_one (
125+ self .collection .update_one (
126+ {"key" : key , ** self ._filter_expired (expired = True )},
140127 {
141- "key" : key ,
142- "value" : serialized_data ,
143- "expire_at" : self ._get_expiration_time (timeout ),
144- }
128+ "$set" : {
129+ "key" : key ,
130+ "value" : serialized_data ,
131+ "expires_at" : self ._get_expiration_time (timeout ),
132+ }
133+ },
134+ True ,
145135 )
146- except Exception :
147- # check the exception name to catch when the key exists
136+ except DuplicateKeyError :
137+ # Check the exception name to catch when the key exists.
148138 return False
149139 return True
150140
141+ def _cull (self , num ):
142+ if self ._cull_frequency == 0 :
143+ self .clear ()
144+ else :
145+ cull_num = num // self ._cull_frequency
146+ try :
147+ # Delete the first expiration date.
148+ deleted_from = next (
149+ self .collection .aggregate (
150+ [
151+ {"$sort" : SON ([("expired_at" , 1 ), ("key" , 1 )])},
152+ {"$skip" : cull_num },
153+ {"$limit" : 1 },
154+ {"$project" : {"key" : 1 }},
155+ ]
156+ )
157+ )
158+ except StopIteration :
159+ pass
160+ else :
161+ self .collection .delete_many ({"key" : {"$lt" : deleted_from ["key" ]}})
162+
151163 def touch (self , key , timeout = DEFAULT_TIMEOUT , version = None ):
152164 key = self .make_and_validate_key (key , version = version )
153- return self .collection .update_one (
154- {"key" : key }, {"$set" : {"expire_at " : self ._get_expiration_time (timeout )}}
165+ res = self .collection .update_one (
166+ {"key" : key }, {"$set" : {"expires_at " : self ._get_expiration_time (timeout )}}
155167 )
168+ return res .matched_count > 0
156169
157170 def _get_expiration_time (self , timeout = None ):
158171 timestamp = self .get_backend_timeout (timeout )
@@ -162,17 +175,22 @@ def _get_expiration_time(self, timeout=None):
162175 return datetime .fromtimestamp (timestamp , tz = timezone .utc )
163176
164177 def delete (self , key , version = None ):
165- return self .delete_many ([key ], version )
178+ return self ._delete_many ([key ], version )
166179
167180 def delete_many (self , keys , version = None ):
181+ self ._delete_many (keys , version )
182+
183+ def _delete_many (self , keys , version = None ):
168184 if not keys :
169185 return False
170186 keys = [self .make_and_validate_key (key , version = version ) for key in keys ]
171187 return bool (self .collection .delete_many ({"key" : {"$in" : tuple (keys )}}).deleted_count )
172188
173189 def has_key (self , key , version = None ):
174190 key = self .make_and_validate_key (key , version = version )
175- return self .collection .count_documents ({"key" : key }) > 0
191+ return (
192+ self .collection .count_documents ({"key" : key , ** self ._filter_expired (expired = False )}) > 0
193+ )
176194
177195 def clear (self ):
178196 self .collection .delete_many ({})
0 commit comments