|
5 | 5 | from django.core.cache.backends.db import Options
|
6 | 6 | from django.db import connections, router
|
7 | 7 | from django.utils.functional import cached_property
|
8 |
| -from pymongo import IndexModel |
9 |
| -from pymongo.errors import DuplicateKeyError |
| 8 | +from pymongo import IndexModel, ReturnDocument |
| 9 | +from pymongo.errors import DuplicateKeyError, OperationFailure |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class MongoSerializer:
|
13 | 13 | def __init__(self, protocol=None):
|
14 | 14 | self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
|
15 | 15 |
|
16 | 16 | def dumps(self, obj):
|
17 |
| - # Integers do not need serialization. |
18 |
| - if isinstance(obj, int): |
| 17 | + # Only skip pickling for integers, a int subclasses as bool should be |
| 18 | + # pickled. |
| 19 | + if type(obj) is int: # noqa: E721 |
19 | 20 | return obj
|
20 | 21 | return pickle.dumps(obj, self.protocol)
|
21 | 22 |
|
@@ -83,7 +84,7 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
|
83 | 84 | num = self.collection_for_write.count_documents({}, hint="_id_")
|
84 | 85 | if num >= self._max_entries:
|
85 | 86 | self._cull(num)
|
86 |
| - return self.collection_for_write.update_one( |
| 87 | + self.collection_for_write.update_one( |
87 | 88 | {"key": key},
|
88 | 89 | {
|
89 | 90 | "$set": {
|
@@ -158,6 +159,24 @@ def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
|
158 | 159 | )
|
159 | 160 | return res.matched_count > 0
|
160 | 161 |
|
| 162 | + def incr(self, key, delta=1, version=None): |
| 163 | + serialized_key = self.make_and_validate_key(key, version=version) |
| 164 | + |
| 165 | + try: |
| 166 | + updated = self.collection_for_write.find_one_and_update( |
| 167 | + {"key": serialized_key, **self._filter_expired(expired=False)}, |
| 168 | + { |
| 169 | + "$inc": {"value": delta}, |
| 170 | + }, |
| 171 | + return_document=ReturnDocument.AFTER, |
| 172 | + ) |
| 173 | + except OperationFailure as ex: |
| 174 | + raise TypeError("Cannot apply incr to a value of non-numeric type") from ex |
| 175 | + # Not exists |
| 176 | + if updated is None: |
| 177 | + raise ValueError(f"Key '{key}' not found.") from None |
| 178 | + return updated["value"] |
| 179 | + |
161 | 180 | def _get_expiration_time(self, timeout=None):
|
162 | 181 | if timeout is None:
|
163 | 182 | return None
|
|
0 commit comments