Skip to content

Commit b33c34e

Browse files
committed
progress
1 parent cced863 commit b33c34e

File tree

6 files changed

+104
-60
lines changed

6 files changed

+104
-60
lines changed

django_mongodb_backend/base.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import os
33

44
from django.core.exceptions import ImproperlyConfigured
5+
from django.db import DEFAULT_DB_ALIAS
56
from django.db.backends.base.base import BaseDatabaseWrapper
7+
from django.db.backends.utils import debug_transaction
68
from django.utils.asyncio import async_unsafe
79
from django.utils.functional import cached_property
810
from pymongo.collection import Collection
@@ -32,6 +34,17 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
3234
pass
3335

3436

37+
def requires_transaction_support(func):
38+
"""Make a method a no-op if transactions aren't supported."""
39+
40+
def wrapper(self, *args, **kwargs):
41+
if not self.features._supports_transactions:
42+
return
43+
func(self, *args, **kwargs)
44+
45+
return wrapper
46+
47+
3548
class DatabaseWrapper(BaseDatabaseWrapper):
3649
data_types = {
3750
"AutoField": "int",
@@ -142,6 +155,10 @@ def _isnull_operator(a, b):
142155
ops_class = DatabaseOperations
143156
validation_class = DatabaseValidation
144157

158+
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
159+
super().__init__(settings_dict, alias=alias)
160+
self.session = None
161+
145162
def get_collection(self, name, **kwargs):
146163
collection = Collection(self.database, name, **kwargs)
147164
if self.queries_logged:
@@ -212,6 +229,10 @@ def close(self):
212229

213230
def close_pool(self):
214231
"""Close the MongoClient."""
232+
# Clear commit hooks and session.
233+
self.run_on_commit = []
234+
if self.session:
235+
self._end_session()
215236
connection = self.connection
216237
if connection is None:
217238
return
@@ -227,6 +248,56 @@ def close_pool(self):
227248
def cursor(self):
228249
return Cursor()
229250

251+
@requires_transaction_support
252+
def validate_no_broken_transaction(self):
253+
super().validate_no_broken_transaction()
254+
230255
def get_database_version(self):
231256
"""Return a tuple of the database's version."""
232257
return tuple(self.connection.server_info()["versionArray"])
258+
259+
@requires_transaction_support
260+
def _start_transaction(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
261+
# Besides @transaction.atomic() (which uses
262+
# _start_transaction_under_autocommit(), disabling autocommit is
263+
# another way to start a transaction.
264+
# if not autocommit:
265+
# self._start_transaction()
266+
# def _start_transaction(self):
267+
# Private API, specific to this backend.
268+
if self.session is None:
269+
self.session = self.connection.start_session()
270+
with debug_transaction(self, "session.start_transaction()"):
271+
self.session.start_transaction()
272+
273+
@requires_transaction_support
274+
def _commit_transaction(self):
275+
self.validate_thread_sharing()
276+
self.validate_no_atomic_block()
277+
if self.session:
278+
with debug_transaction(self, "session.commit_transaction()"):
279+
self.session.commit_transaction()
280+
self._end_session()
281+
# A successful commit means that the database connection works.
282+
self.errors_occurred = False
283+
self.run_commit_hooks_on_set_autocommit_on = True
284+
285+
@async_unsafe
286+
@requires_transaction_support
287+
def _rollback_transaction(self):
288+
"""Roll back a MongoDB transaction and reset the dirty flag."""
289+
self.validate_thread_sharing()
290+
self.validate_no_atomic_block()
291+
if self.session:
292+
with debug_transaction(self, "session.abort_transaction()"):
293+
self.session.abort_transaction()
294+
self._end_session()
295+
# A successful rollback means that the database connection works.
296+
self.errors_occurred = False
297+
self.needs_rollback = False
298+
self.run_on_commit = []
299+
300+
def _end_session(self):
301+
# Private API, specific to this backend.
302+
self.session.end_session()
303+
self.session = None

django_mongodb_backend/compiler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,10 @@ def execute_sql(self, returning_fields=None):
678678
@wrap_database_errors
679679
def insert(self, docs, returning_fields=None):
680680
"""Store a list of documents using field columns as element names."""
681-
inserted_ids = self.collection.insert_many(docs).inserted_ids
681+
self.connection.validate_no_broken_transaction()
682+
inserted_ids = self.collection.insert_many(
683+
docs, session=self.connection.session
684+
).inserted_ids
682685
return [(x,) for x in inserted_ids] if returning_fields else []
683686

684687
@cached_property
@@ -759,7 +762,10 @@ def execute_sql(self, result_type):
759762

760763
@wrap_database_errors
761764
def update(self, criteria, pipeline):
762-
return self.collection.update_many(criteria, pipeline).matched_count
765+
self.connection.validate_no_broken_transaction()
766+
return self.collection.update_many(
767+
criteria, pipeline, session=self.connection.session
768+
).matched_count
763769

764770
def check_query(self):
765771
super().check_query()

django_mongodb_backend/query.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,23 @@ def __repr__(self):
6161
@wrap_database_errors
6262
def delete(self):
6363
"""Execute a delete query."""
64+
self.compiler.connection.validate_no_broken_transaction()
6465
if self.compiler.subqueries:
6566
raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.")
66-
return self.compiler.collection.delete_many(self.match_mql).deleted_count
67+
return self.compiler.collection.delete_many(
68+
self.match_mql, session=self.compiler.connection.session
69+
).deleted_count
6770

6871
@wrap_database_errors
6972
def get_cursor(self):
7073
"""
7174
Return a pymongo CommandCursor that can be iterated on to give the
7275
results of the query.
7376
"""
74-
return self.compiler.collection.aggregate(self.get_pipeline())
77+
self.compiler.connection.validate_no_broken_transaction()
78+
return self.compiler.collection.aggregate(
79+
self.get_pipeline(), session=self.compiler.connection.session
80+
)
7581

7682
def get_pipeline(self):
7783
pipeline = []

django_mongodb_backend/queryset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, pipeline, using, model):
3535
def _execute_query(self):
3636
connection = connections[self.using]
3737
collection = connection.get_collection(self.model._meta.db_table)
38-
self.cursor = collection.aggregate(self.pipeline)
38+
self.cursor = collection.aggregate(self.pipeline, session=connection.session)
3939

4040
def __str__(self):
4141
return str(self.pipeline)

django_mongodb_backend/transaction.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,6 @@ def get_connection(using=None):
2323
return connections[using]
2424

2525

26-
def get_autocommit(using=None):
27-
"""Get the autocommit status of the connection."""
28-
return get_connection(using).get_autocommit()
29-
30-
31-
def set_autocommit(autocommit, using=None):
32-
"""Set the autocommit status of the connection."""
33-
return get_connection(using).set_autocommit(autocommit)
34-
35-
3626
def commit(using=None):
3727
"""Commit a transaction."""
3828
get_connection(using).commit()
@@ -141,7 +131,7 @@ def __enter__(self):
141131
and not connection.atomic_blocks[-1]._from_testcase
142132
):
143133
raise RuntimeError(
144-
"A durable atomic block cannot be nested within another " "atomic block."
134+
"A durable atomic block cannot be nested within another atomic block."
145135
)
146136
if not connection.in_atomic_block:
147137
# Reset state when entering an outermost atomic block.
@@ -158,7 +148,9 @@ def __enter__(self):
158148
# We're already in a transaction
159149
pass
160150
else:
161-
connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True)
151+
connection._start_transaction(
152+
False, force_begin_transaction_with_broken_autocommit=True
153+
)
162154
connection.in_atomic_block = True
163155

164156
if connection.in_atomic_block:
@@ -173,22 +165,22 @@ def __exit__(self, exc_type, exc_value, traceback):
173165
# Prematurely unset this flag to allow using commit or rollback.
174166
connection._in_atomic_block = False
175167
try:
176-
if connection._closed_in_transaction:
168+
if connection.closed_in_transaction:
177169
# The database will perform a rollback by itself.
178170
# Wait until we exit the outermost block.
179171
pass
180172

181-
elif exc_type is None and not connection._needs_rollback:
173+
elif exc_type is None and not connection.needs_rollback:
182174
if connection._in_atomic_block:
183175
# Release savepoint if there is one
184176
pass
185177
else:
186178
# Commit transaction
187179
try:
188-
connection._commit()
180+
connection._commit_transaction()
189181
except DatabaseError:
190182
try:
191-
connection._rollback()
183+
connection._rollback_transaction()
192184
except Error:
193185
# An error during rollback means that something
194186
# went wrong with the connection. Drop it.
@@ -204,7 +196,7 @@ def __exit__(self, exc_type, exc_value, traceback):
204196
else:
205197
# Roll back transaction
206198
try:
207-
connection.rollback()
199+
connection._rollback_transaction()
208200
except Error:
209201
# An error during rollback means that something
210202
# went wrong with the connection. Drop it.
@@ -214,8 +206,8 @@ def __exit__(self, exc_type, exc_value, traceback):
214206
if not connection.in_atomic_block:
215207
if connection.closed_in_transaction:
216208
connection.connection = None
217-
else:
218-
connection.set_autocommit(True)
209+
# else:
210+
# connection.set_autocommit(True)
219211
# Outermost block exit when autocommit was disabled.
220212
elif not connection.commit_on_exit:
221213
if connection.closed_in_transaction:

tests/transactions_/tests.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -103,40 +103,6 @@ def test_nested_rollback_rollback(self):
103103
raise Exception("Oops, that's his first name")
104104
self.assertSequenceEqual(Reporter.objects.all(), [])
105105

106-
def test_merged_commit_commit(self):
107-
with transaction.atomic():
108-
reporter1 = Reporter.objects.create(first_name="Tintin")
109-
with transaction.atomic(savepoint=False):
110-
reporter2 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
111-
self.assertSequenceEqual(Reporter.objects.all(), [reporter2, reporter1])
112-
113-
def test_merged_commit_rollback(self):
114-
with transaction.atomic():
115-
Reporter.objects.create(first_name="Tintin")
116-
with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(savepoint=False):
117-
Reporter.objects.create(first_name="Haddock")
118-
raise Exception("Oops, that's his last name")
119-
# Writes in the outer block are rolled back too.
120-
self.assertSequenceEqual(Reporter.objects.all(), [])
121-
122-
def test_merged_rollback_commit(self):
123-
with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic():
124-
Reporter.objects.create(last_name="Tintin")
125-
with transaction.atomic(savepoint=False):
126-
Reporter.objects.create(last_name="Haddock")
127-
raise Exception("Oops, that's his first name")
128-
self.assertSequenceEqual(Reporter.objects.all(), [])
129-
130-
def test_merged_rollback_rollback(self):
131-
with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic():
132-
Reporter.objects.create(last_name="Tintin")
133-
with self.assertRaisesMessage(Exception, "Oops"):
134-
with transaction.atomic(savepoint=False):
135-
Reporter.objects.create(first_name="Haddock")
136-
raise Exception("Oops, that's his last name")
137-
raise Exception("Oops, that's his first name")
138-
self.assertSequenceEqual(Reporter.objects.all(), [])
139-
140106
def test_reuse_commit_commit(self):
141107
atomic = transaction.atomic()
142108
with atomic:
@@ -212,8 +178,11 @@ class AtomicErrorsTests(TransactionTestCase):
212178

213179
def test_atomic_prevents_setting_autocommit(self):
214180
autocommit = transaction.get_autocommit()
215-
with transaction.atomic(), self.assertRaisesMessage(
216-
transaction.TransactionManagementError, self.forbidden_atomic_msg
181+
with (
182+
transaction.atomic(),
183+
self.assertRaisesMessage(
184+
transaction.TransactionManagementError, self.forbidden_atomic_msg
185+
),
217186
):
218187
transaction.set_autocommit(not autocommit)
219188
# Make sure autocommit wasn't changed.

0 commit comments

Comments
 (0)