1
1
import pickle
2
2
from datetime import datetime , timezone
3
3
4
+ from bson import SON
4
5
from django .core .cache .backends .base import DEFAULT_TIMEOUT , BaseCache
5
6
from django .db import connections , router
6
7
from django .utils .functional import cached_property
8
+ from pymongo .errors import DuplicateKeyError
7
9
8
10
9
11
class MongoSerializer :
@@ -58,39 +60,11 @@ class MongoDBCache(BaseDatabaseCache):
58
60
def __init__ (self , * args , ** options ):
59
61
super ().__init__ (* args , ** options )
60
62
# don't know If I can set the capped collection here.
61
- coll_info = self .collection .options ()
62
- collections = set (self ._db .database .list_collection_names ())
63
- coll_exists = self ._collection_name in collections
64
- if coll_exists and not coll_info .get ("capped" , False ):
65
- self ._db .database .command (
66
- "convertToCapped" , self ._collection_name , size = self ._max_entries
67
- )
68
- elif coll_exists and coll_info .get ("size" ) != self ._max_entries :
69
- new_coll = self ._copy_collection ()
70
- self .collection .drop ()
71
- new_coll .rename (self ._collection_name )
72
- self .create_indexes ()
73
63
74
64
def create_indexes (self ):
75
- self .collection .create_index ("expire_at " , expireAfterSeconds = 0 )
65
+ self .collection .create_index ("expires_at " , expireAfterSeconds = 0 )
76
66
self .collection .create_index ("key" , unique = True )
77
67
78
- def _copy_collection (self ):
79
- collection_name = self ._get_tmp_collection_name ()
80
- self .collection .aggregate ([{"$out" : collection_name }])
81
- return self ._db .get_collection (collection_name )
82
-
83
- def _get_tmp_collection_name (self ):
84
- collections = set (self ._db .database .list_collection_names ())
85
- template_collection_name = "tmp__collection__{num}"
86
- num = 0
87
- while True :
88
- tmp_collection_name = template_collection_name .format (num = num )
89
- if tmp_collection_name not in collections :
90
- break
91
- num += 1
92
- return tmp_collection_name
93
-
94
68
@cached_property
95
69
def serializer (self ):
96
70
return MongoSerializer ()
@@ -104,29 +78,38 @@ def collection(self):
104
78
return self ._db .get_collection (self ._collection_name )
105
79
106
80
def get (self , key , default = None , version = None ):
107
- key = self .make_and_validate_key (key , version = version )
108
- result = self .collection .find_one ({"key" : key })
109
- if result is not None :
110
- return self .serializer .loads (result ["value" ])
81
+ result = self .get_many ([key ], version )
82
+ if result :
83
+ return result [key ]
111
84
return default
112
85
86
+ def _filter_expired (self , expired = False ):
87
+ not_expired_filter = [{"expires_at" : {"$gte" : datetime .utcnow ()}}, {"expires_at" : None }]
88
+ operator = "$nor" if expired else "$or"
89
+ return {operator : not_expired_filter }
90
+
113
91
def get_many (self , keys , version = None ):
114
92
if not keys :
115
93
return {}
116
94
keys_map = {self .make_and_validate_key (key , version = version ): key for key in keys }
117
- with self .collection .find ({"key" : {"$in" : tuple (keys_map )}}) as cursor :
118
- return {keys_map [row ["key" ]]: row ["value" ] for row in cursor }
95
+ with self .collection .find (
96
+ {"key" : {"$in" : tuple (keys_map )}, ** self ._filter_expired (expired = False )}
97
+ ) as cursor :
98
+ return {keys_map [row ["key" ]]: self .serializer .loads (row ["value" ]) for row in cursor }
119
99
120
100
def set (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
121
101
key = self .make_and_validate_key (key , version = version )
122
102
serialized_data = self .serializer .dumps (value )
103
+ num = self .collection .count_documents ({})
104
+ if num > self ._max_entries :
105
+ self ._cull (num )
123
106
return self .collection .update_one (
124
107
{"key" : key },
125
108
{
126
109
"$set" : {
127
110
"key" : key ,
128
111
"value" : serialized_data ,
129
- "expire_at " : self ._get_expiration_time (timeout ),
112
+ "expires_at " : self ._get_expiration_time (timeout ),
130
113
}
131
114
},
132
115
True ,
@@ -135,24 +118,54 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
135
118
def add (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
136
119
key = self .make_and_validate_key (key , version = version )
137
120
serialized_data = self .serializer .dumps (value )
121
+ num = self .collection .count_documents ({})
122
+ if num > self ._max_entries :
123
+ self ._cull (num )
138
124
try :
139
- self .collection .insert_one (
125
+ self .collection .update_one (
126
+ {"key" : key , ** self ._filter_expired (expired = True )},
140
127
{
141
- "key" : key ,
142
- "value" : serialized_data ,
143
- "expire_at" : self ._get_expiration_time (timeout ),
144
- }
128
+ "$set" : {
129
+ "key" : key ,
130
+ "value" : serialized_data ,
131
+ "expires_at" : self ._get_expiration_time (timeout ),
132
+ }
133
+ },
134
+ True ,
145
135
)
146
- except Exception :
147
- # check the exception name to catch when the key exists
136
+ except DuplicateKeyError :
137
+ # Check the exception name to catch when the key exists.
148
138
return False
149
139
return True
150
140
141
+ def _cull (self , num ):
142
+ if self ._cull_frequency == 0 :
143
+ self .clear ()
144
+ else :
145
+ cull_num = num // self ._cull_frequency
146
+ try :
147
+ # Delete the first expiration date.
148
+ deleted_from = next (
149
+ self .collection .aggregate (
150
+ [
151
+ {"$sort" : SON ([("expired_at" , 1 ), ("key" , 1 )])},
152
+ {"$skip" : cull_num },
153
+ {"$limit" : 1 },
154
+ {"$project" : {"key" : 1 }},
155
+ ]
156
+ )
157
+ )
158
+ except StopIteration :
159
+ pass
160
+ else :
161
+ self .collection .delete_many ({"key" : {"$lt" : deleted_from ["key" ]}})
162
+
151
163
def touch (self , key , timeout = DEFAULT_TIMEOUT , version = None ):
152
164
key = self .make_and_validate_key (key , version = version )
153
- return self .collection .update_one (
154
- {"key" : key }, {"$set" : {"expire_at " : self ._get_expiration_time (timeout )}}
165
+ res = self .collection .update_one (
166
+ {"key" : key }, {"$set" : {"expires_at " : self ._get_expiration_time (timeout )}}
155
167
)
168
+ return res .matched_count > 0
156
169
157
170
def _get_expiration_time (self , timeout = None ):
158
171
timestamp = self .get_backend_timeout (timeout )
@@ -162,17 +175,22 @@ def _get_expiration_time(self, timeout=None):
162
175
return datetime .fromtimestamp (timestamp , tz = timezone .utc )
163
176
164
177
def delete (self , key , version = None ):
165
- return self .delete_many ([key ], version )
178
+ return self ._delete_many ([key ], version )
166
179
167
180
def delete_many (self , keys , version = None ):
181
+ self ._delete_many (keys , version )
182
+
183
+ def _delete_many (self , keys , version = None ):
168
184
if not keys :
169
185
return False
170
186
keys = [self .make_and_validate_key (key , version = version ) for key in keys ]
171
187
return bool (self .collection .delete_many ({"key" : {"$in" : tuple (keys )}}).deleted_count )
172
188
173
189
def has_key (self , key , version = None ):
174
190
key = self .make_and_validate_key (key , version = version )
175
- return self .collection .count_documents ({"key" : key }) > 0
191
+ return (
192
+ self .collection .count_documents ({"key" : key , ** self ._filter_expired (expired = False )}) > 0
193
+ )
176
194
177
195
def clear (self ):
178
196
self .collection .delete_many ({})
0 commit comments