2
2
import os
3
3
4
4
from django .core .exceptions import ImproperlyConfigured
5
+ from django .db import DEFAULT_DB_ALIAS
5
6
from django .db .backends .base .base import BaseDatabaseWrapper
7
+ from django .db .backends .utils import debug_transaction
6
8
from django .utils .asyncio import async_unsafe
7
9
from django .utils .functional import cached_property
8
10
from pymongo .collection import Collection
@@ -32,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
32
34
pass
33
35
34
36
37
+ def requires_transaction_support (func ):
38
+ """Make a method a no-op if transactions aren't supported."""
39
+
40
+ def wrapper (self , * args , ** kwargs ):
41
+ if not self .features ._supports_transactions :
42
+ return
43
+ func (self , * args , ** kwargs )
44
+
45
+ return wrapper
46
+
47
+
35
48
class DatabaseWrapper (BaseDatabaseWrapper ):
36
49
data_types = {
37
50
"AutoField" : "int" ,
@@ -142,6 +155,10 @@ def _isnull_operator(a, b):
142
155
ops_class = DatabaseOperations
143
156
validation_class = DatabaseValidation
144
157
158
+ def __init__ (self , settings_dict , alias = DEFAULT_DB_ALIAS ):
159
+ super ().__init__ (settings_dict , alias = alias )
160
+ self .session = None
161
+
145
162
def get_collection (self , name , ** kwargs ):
146
163
collection = Collection (self .database , name , ** kwargs )
147
164
if self .queries_logged :
@@ -212,6 +229,10 @@ def close(self):
212
229
213
230
def close_pool (self ):
214
231
"""Close the MongoClient."""
232
+ # Clear commit hooks and session.
233
+ self .run_on_commit = []
234
+ if self .session :
235
+ self ._end_session ()
215
236
connection = self .connection
216
237
if connection is None :
217
238
return
@@ -227,6 +248,56 @@ def close_pool(self):
227
248
def cursor (self ):
228
249
return Cursor ()
229
250
251
+ @requires_transaction_support
252
+ def validate_no_broken_transaction (self ):
253
+ super ().validate_no_broken_transaction ()
254
+
230
255
def get_database_version (self ):
231
256
"""Return a tuple of the database's version."""
232
257
return tuple (self .connection .server_info ()["versionArray" ])
258
+
259
+ @requires_transaction_support
260
+ def _start_transaction (self , autocommit , force_begin_transaction_with_broken_autocommit = False ):
261
+ # Besides @transaction.atomic() (which uses
262
+ # _start_transaction_under_autocommit(), disabling autocommit is
263
+ # another way to start a transaction.
264
+ # if not autocommit:
265
+ # self._start_transaction()
266
+ # def _start_transaction(self):
267
+ # Private API, specific to this backend.
268
+ if self .session is None :
269
+ self .session = self .connection .start_session ()
270
+ with debug_transaction (self , "session.start_transaction()" ):
271
+ self .session .start_transaction ()
272
+
273
+ @requires_transaction_support
274
+ def _commit_transaction (self ):
275
+ self .validate_thread_sharing ()
276
+ self .validate_no_atomic_block ()
277
+ if self .session :
278
+ with debug_transaction (self , "session.commit_transaction()" ):
279
+ self .session .commit_transaction ()
280
+ self ._end_session ()
281
+ # A successful commit means that the database connection works.
282
+ self .errors_occurred = False
283
+ self .run_commit_hooks_on_set_autocommit_on = True
284
+
285
+ @async_unsafe
286
+ @requires_transaction_support
287
+ def _rollback_transaction (self ):
288
+ """Roll back a MongoDB transaction and reset the dirty flag."""
289
+ self .validate_thread_sharing ()
290
+ self .validate_no_atomic_block ()
291
+ if self .session :
292
+ with debug_transaction (self , "session.abort_transaction()" ):
293
+ self .session .abort_transaction ()
294
+ self ._end_session ()
295
+ # A successful rollback means that the database connection works.
296
+ self .errors_occurred = False
297
+ self .needs_rollback = False
298
+ self .run_on_commit = []
299
+
300
+ def _end_session (self ):
301
+ # Private API, specific to this backend.
302
+ self .session .end_session ()
303
+ self .session = None
0 commit comments