Skip to content

Commit 3036fdd

Browse files
committed
edits.
1 parent 9146e54 commit 3036fdd

File tree

2 files changed

+68
-50
lines changed

2 files changed

+68
-50
lines changed

django_mongodb_backend/cache.py

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pickle
22
from datetime import datetime, timezone
33

4+
from bson import SON
45
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
56
from django.db import connections, router
67
from django.utils.functional import cached_property
8+
from pymongo.errors import DuplicateKeyError
79

810

911
class MongoSerializer:
@@ -58,39 +60,11 @@ class MongoDBCache(BaseDatabaseCache):
5860
def __init__(self, *args, **options):
5961
super().__init__(*args, **options)
6062
# don't know If I can set the capped collection here.
61-
coll_info = self.collection.options()
62-
collections = set(self._db.database.list_collection_names())
63-
coll_exists = self._collection_name in collections
64-
if coll_exists and not coll_info.get("capped", False):
65-
self._db.database.command(
66-
"convertToCapped", self._collection_name, size=self._max_entries
67-
)
68-
elif coll_exists and coll_info.get("size") != self._max_entries:
69-
new_coll = self._copy_collection()
70-
self.collection.drop()
71-
new_coll.rename(self._collection_name)
72-
self.create_indexes()
7363

7464
def create_indexes(self):
75-
self.collection.create_index("expire_at", expireAfterSeconds=0)
65+
self.collection.create_index("expires_at", expireAfterSeconds=0)
7666
self.collection.create_index("key", unique=True)
7767

78-
def _copy_collection(self):
79-
collection_name = self._get_tmp_collection_name()
80-
self.collection.aggregate([{"$out": collection_name}])
81-
return self._db.get_collection(collection_name)
82-
83-
def _get_tmp_collection_name(self):
84-
collections = set(self._db.database.list_collection_names())
85-
template_collection_name = "tmp__collection__{num}"
86-
num = 0
87-
while True:
88-
tmp_collection_name = template_collection_name.format(num=num)
89-
if tmp_collection_name not in collections:
90-
break
91-
num += 1
92-
return tmp_collection_name
93-
9468
@cached_property
9569
def serializer(self):
9670
return MongoSerializer()
@@ -104,29 +78,38 @@ def collection(self):
10478
return self._db.get_collection(self._collection_name)
10579

10680
def get(self, key, default=None, version=None):
107-
key = self.make_and_validate_key(key, version=version)
108-
result = self.collection.find_one({"key": key})
109-
if result is not None:
110-
return self.serializer.loads(result["value"])
81+
result = self.get_many([key], version)
82+
if result:
83+
return result[key]
11184
return default
11285

86+
def _filter_expired(self, expired=False):
87+
not_expired_filter = [{"expires_at": {"$gte": datetime.utcnow()}}, {"expires_at": None}]
88+
operator = "$nor" if expired else "$or"
89+
return {operator: not_expired_filter}
90+
11391
def get_many(self, keys, version=None):
11492
if not keys:
11593
return {}
11694
keys_map = {self.make_and_validate_key(key, version=version): key for key in keys}
117-
with self.collection.find({"key": {"$in": tuple(keys_map)}}) as cursor:
118-
return {keys_map[row["key"]]: row["value"] for row in cursor}
95+
with self.collection.find(
96+
{"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)}
97+
) as cursor:
98+
return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor}
11999

120100
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
121101
key = self.make_and_validate_key(key, version=version)
122102
serialized_data = self.serializer.dumps(value)
103+
num = self.collection.count_documents({})
104+
if num > self._max_entries:
105+
self._cull(num)
123106
return self.collection.update_one(
124107
{"key": key},
125108
{
126109
"$set": {
127110
"key": key,
128111
"value": serialized_data,
129-
"expire_at": self._get_expiration_time(timeout),
112+
"expires_at": self._get_expiration_time(timeout),
130113
}
131114
},
132115
True,
@@ -135,24 +118,54 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
135118
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
136119
key = self.make_and_validate_key(key, version=version)
137120
serialized_data = self.serializer.dumps(value)
121+
num = self.collection.count_documents({})
122+
if num > self._max_entries:
123+
self._cull(num)
138124
try:
139-
self.collection.insert_one(
125+
self.collection.update_one(
126+
{"key": key, **self._filter_expired(expired=True)},
140127
{
141-
"key": key,
142-
"value": serialized_data,
143-
"expire_at": self._get_expiration_time(timeout),
144-
}
128+
"$set": {
129+
"key": key,
130+
"value": serialized_data,
131+
"expires_at": self._get_expiration_time(timeout),
132+
}
133+
},
134+
True,
145135
)
146-
except Exception:
147-
# check the exception name to catch when the key exists
136+
except DuplicateKeyError:
137+
# Check the exception name to catch when the key exists.
148138
return False
149139
return True
150140

141+
def _cull(self, num):
142+
if self._cull_frequency == 0:
143+
self.clear()
144+
else:
145+
cull_num = num // self._cull_frequency
146+
try:
147+
# Delete the first expiration date.
148+
deleted_from = next(
149+
self.collection.aggregate(
150+
[
151+
{"$sort": SON([("expired_at", 1), ("key", 1)])},
152+
{"$skip": cull_num},
153+
{"$limit": 1},
154+
{"$project": {"key": 1}},
155+
]
156+
)
157+
)
158+
except StopIteration:
159+
pass
160+
else:
161+
self.collection.delete_many({"key": {"$lt": deleted_from["key"]}})
162+
151163
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
152164
key = self.make_and_validate_key(key, version=version)
153-
return self.collection.update_one(
154-
{"key": key}, {"$set": {"expire_at": self._get_expiration_time(timeout)}}
165+
res = self.collection.update_one(
166+
{"key": key}, {"$set": {"expires_at": self._get_expiration_time(timeout)}}
155167
)
168+
return res.matched_count > 0
156169

157170
def _get_expiration_time(self, timeout=None):
158171
timestamp = self.get_backend_timeout(timeout)
@@ -162,17 +175,22 @@ def _get_expiration_time(self, timeout=None):
162175
return datetime.fromtimestamp(timestamp, tz=timezone.utc)
163176

164177
def delete(self, key, version=None):
165-
return self.delete_many([key], version)
178+
return self._delete_many([key], version)
166179

167180
def delete_many(self, keys, version=None):
181+
self._delete_many(keys, version)
182+
183+
def _delete_many(self, keys, version=None):
168184
if not keys:
169185
return False
170186
keys = [self.make_and_validate_key(key, version=version) for key in keys]
171187
return bool(self.collection.delete_many({"key": {"$in": tuple(keys)}}).deleted_count)
172188

173189
def has_key(self, key, version=None):
174190
key = self.make_and_validate_key(key, version=version)
175-
return self.collection.count_documents({"key": key}) > 0
191+
return (
192+
self.collection.count_documents({"key": key, **self._filter_expired(expired=False)}) > 0
193+
)
176194

177195
def clear(self):
178196
self.collection.delete_many({})

tests/cache_/tests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def test_invalid_key_length(self):
538538
# memcached limits key length to 250.
539539
key = ("a" * 250) + "清"
540540
expected_warning = (
541-
"Cache key will cause errors if used with memcached: " f"{key} (longer than {250})"
541+
"Cache key will cause errors if used with memcached: " f"'{key}' (longer than 250)"
542542
)
543543
self._perform_invalid_key_test(key, expected_warning)
544544

@@ -938,11 +938,11 @@ class DBCacheTests(BaseCacheTests, TestCase):
938938
def setUp(self):
939939
# The super calls needs to happen first for the settings override.
940940
super().setUp()
941-
self.create_table()
941+
self.create_cache_collection()
942942
self.addCleanup(self.drop_collection)
943943

944944
def drop_collection(self):
945945
cache.collection.drop()
946946

947-
def create_table(self):
947+
def create_cache_collection(self):
948948
management.call_command("createcachecollection", verbosity=0)

0 commit comments

Comments
 (0)