@@ -13,10 +13,7 @@ def __init__(self, protocol=None):
13
13
def dumps (self , obj ):
14
14
if isinstance (obj , int ):
15
15
return obj
16
- try :
17
- return pickle .dumps (obj , self .protocol )
18
- except pickle .PickleError as ex :
19
- raise ValueError ("Object cannot be pickled" ) from ex
16
+ return pickle .dumps (obj , self .protocol )
20
17
21
18
def loads (self , data ):
22
19
try :
@@ -60,19 +57,58 @@ class CacheEntry:
60
57
class MongoDBCache (BaseDatabaseCache ):
61
58
def __init__ (self , * args , ** options ):
62
59
super ().__init__ (* args , ** options )
60
+ # don't know If I can set the capped collection here.
61
+ coll_info = self .collection .options ()
62
+ collections = set (self ._db .database .list_collection_names ())
63
+ coll_exists = self ._collection_name in collections
64
+ if coll_exists and not coll_info .get ("capped" , False ):
65
+ self ._db .database .command (
66
+ "convertToCapped" , self ._collection_name , size = self ._max_entries
67
+ )
68
+ elif coll_exists and coll_info .get ("size" ) != self ._max_entries :
69
+ new_coll = self ._copy_collection ()
70
+ self .collection .drop ()
71
+ new_coll .rename (self ._collection_name )
72
+ self .create_indexes ()
73
+
74
+ def create_indexes (self ):
75
+ self .collection .create_index ("expire_at" , expireAfterSeconds = 0 )
76
+ self .collection .create_index ("key" , unique = True )
77
+
78
+ def _copy_collection (self ):
79
+ collection_name = self ._get_tmp_collection_name ()
80
+ self .collection .aggregate ([{"$out" : collection_name }])
81
+ return self ._db .get_collection (collection_name )
82
+
83
+ def _get_tmp_collection_name (self ):
84
+ collections = set (self ._db .database .list_collection_names ())
85
+ template_collection_name = "tmp__collection__{num}"
86
+ num = 0
87
+ while True :
88
+ tmp_collection_name = template_collection_name .format (num = num )
89
+ if tmp_collection_name not in collections :
90
+ break
91
+ num += 1
92
+ return tmp_collection_name
63
93
64
94
@cached_property
65
95
def serializer (self ):
66
96
return MongoSerializer ()
67
97
68
- @cached_property
98
+ @property
99
+ def _db (self ):
100
+ return connections [router .db_for_read (self .cache_model_class )]
101
+
102
+ @property
69
103
def collection (self ):
70
- db = router .db_for_read (self .cache_model_class )
71
- return connections [db ].get_collection (self ._collection_name )
104
+ return self ._db .get_collection (self ._collection_name )
72
105
73
106
def get (self , key , default = None , version = None ):
74
107
key = self .make_and_validate_key (key , version = version )
75
- return self .collection .find_one ({"key" : key }) or default
108
+ result = self .collection .find_one ({"key" : key })
109
+ if result is not None :
110
+ return self .serializer .loads (result ["value" ])
111
+ return default
76
112
77
113
def get_many (self , keys , version = None ):
78
114
if not keys :
@@ -86,8 +122,14 @@ def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
86
122
serialized_data = self .serializer .dumps (value )
87
123
return self .collection .update_one (
88
124
{"key" : key },
89
- {"key" : key , "value" : serialized_data , "expire_at" : self ._get_expiration_time (timeout )},
90
- {"upsert" : True },
125
+ {
126
+ "$set" : {
127
+ "key" : key ,
128
+ "value" : serialized_data ,
129
+ "expire_at" : self ._get_expiration_time (timeout ),
130
+ }
131
+ },
132
+ True ,
91
133
)
92
134
93
135
def add (self , key , value , timeout = DEFAULT_TIMEOUT , version = None ):
@@ -130,7 +172,7 @@ def delete_many(self, keys, version=None):
130
172
131
173
def has_key (self , key , version = None ):
132
174
key = self .make_and_validate_key (key , version = version )
133
- return self .collection .count ({"key" : key }) > 0
175
+ return self .collection .count_documents ({"key" : key }) > 0
134
176
135
177
def clear (self ):
136
178
self .collection .delete_many ({})
0 commit comments