Skip to content

Commit e873721

Browse files
committed
add transaction hooks
1 parent f68e459 commit e873721

File tree

6 files changed

+246
-10
lines changed

6 files changed

+246
-10
lines changed

django_mongodb_backend/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import logging
23
import os
34

45
from django.core.exceptions import ImproperlyConfigured
@@ -46,6 +47,9 @@ def wrapper(self, *args, **kwargs):
4647
return wrapper
4748

4849

50+
logger = logging.getLogger("django.db.backends.base")
51+
52+
4953
class DatabaseWrapper(BaseDatabaseWrapper):
5054
data_types = {
5155
"AutoField": "int",
@@ -319,3 +323,24 @@ def _end_session(self):
319323
# Private API, specific to this backend.
320324
self.session.end_session()
321325
self.session = None
326+
327+
def on_commit(self, func, robust=False):
328+
if not callable(func):
329+
raise TypeError("on_commit()'s callback must be a callable.")
330+
if self.in_atomic_block_mongo:
331+
# Transaction in progress; save for execution on commit.
332+
self.run_on_commit.append((set(self.savepoint_ids), func, robust))
333+
else:
334+
# No transaction in progress and in autocommit mode; execute
335+
# immediately.
336+
if robust:
337+
try:
338+
func()
339+
except Exception as e:
340+
logger.exception(
341+
"Error calling %s in on_commit() (%s).",
342+
func.__qualname__,
343+
e,
344+
)
345+
else:
346+
func()

django_mongodb_backend/transaction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def __exit__(self, exc_type, exc_value, traceback):
115115
finally:
116116
# Outermost block exit when autocommit was enabled.
117117
if not connection.in_atomic_block_mongo:
118-
pass
118+
if connection.run_commit_hooks_on_set_autocommit_on:
119+
connection.run_and_clear_commit_hooks()
119120
# connection.set_autocommit(True)
120121
# Outermost block exit when autocommit was disabled.
121122
elif not connection.commit_on_exit:

tests/transaction_hooks_/__init__.py

Whitespace-only changes.

tests/transaction_hooks_/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from django.db import models
2+
3+
4+
class Thing(models.Model):
5+
num = models.IntegerField()

tests/transaction_hooks_/tests.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
from django.db import connection
2+
from django.test import TransactionTestCase, skipUnlessDBFeature
3+
4+
from django_mongodb_backend import transaction
5+
6+
from .models import Thing
7+
8+
9+
class ForcedError(Exception):
10+
pass
11+
12+
13+
@skipUnlessDBFeature("_supports_transactions")
14+
class TestConnectionOnCommit(TransactionTestCase):
15+
"""
16+
Tests for transaction.on_commit().
17+
18+
Creation/checking of database objects in parallel with callback tracking is
19+
to verify that the behavior of the two match in all tested cases.
20+
"""
21+
22+
available_apps = ["transaction_hooks_"]
23+
24+
def setUp(self):
25+
self.notified = []
26+
27+
def notify(self, id_):
28+
if id_ == "error":
29+
raise ForcedError()
30+
self.notified.append(id_)
31+
32+
def do(self, num):
33+
"""Create a Thing instance and notify about it."""
34+
Thing.objects.create(num=num)
35+
transaction.on_commit(lambda: self.notify(num))
36+
37+
def assertDone(self, nums):
38+
self.assertNotified(nums)
39+
self.assertEqual(sorted(t.num for t in Thing.objects.all()), sorted(nums))
40+
41+
def assertNotified(self, nums):
42+
self.assertEqual(self.notified, nums)
43+
44+
def test_executes_immediately_if_no_transaction(self):
45+
self.do(1)
46+
self.assertDone([1])
47+
48+
def test_robust_if_no_transaction(self):
49+
def robust_callback():
50+
raise ForcedError("robust callback")
51+
52+
with self.assertLogs("django.db.backends.base", "ERROR") as cm:
53+
transaction.on_commit(robust_callback, robust=True)
54+
self.do(1)
55+
56+
self.assertDone([1])
57+
log_record = cm.records[0]
58+
self.assertEqual(
59+
log_record.getMessage(),
60+
"Error calling TestConnectionOnCommit.test_robust_if_no_transaction."
61+
"<locals>.robust_callback in on_commit() (robust callback).",
62+
)
63+
self.assertIsNotNone(log_record.exc_info)
64+
raised_exception = log_record.exc_info[1]
65+
self.assertIsInstance(raised_exception, ForcedError)
66+
self.assertEqual(str(raised_exception), "robust callback")
67+
68+
def test_robust_transaction(self):
69+
def robust_callback():
70+
raise ForcedError("robust callback")
71+
72+
with self.assertLogs("django.db.backends", "ERROR") as cm, transaction.atomic():
73+
transaction.on_commit(robust_callback, robust=True)
74+
self.do(1)
75+
76+
self.assertDone([1])
77+
log_record = cm.records[0]
78+
self.assertEqual(
79+
log_record.getMessage(),
80+
"Error calling TestConnectionOnCommit.test_robust_transaction.<locals>."
81+
"robust_callback in on_commit() during transaction (robust callback).",
82+
)
83+
self.assertIsNotNone(log_record.exc_info)
84+
raised_exception = log_record.exc_info[1]
85+
self.assertIsInstance(raised_exception, ForcedError)
86+
self.assertEqual(str(raised_exception), "robust callback")
87+
88+
def test_delays_execution_until_after_transaction_commit(self):
89+
with transaction.atomic():
90+
self.do(1)
91+
self.assertNotified([])
92+
self.assertDone([1])
93+
94+
def test_does_not_execute_if_transaction_rolled_back(self):
95+
try:
96+
with transaction.atomic():
97+
self.do(1)
98+
raise ForcedError()
99+
except ForcedError:
100+
pass
101+
102+
self.assertDone([])
103+
104+
def test_executes_only_after_final_transaction_committed(self):
105+
with transaction.atomic():
106+
with transaction.atomic():
107+
self.do(1)
108+
self.assertNotified([])
109+
self.assertNotified([])
110+
self.assertDone([1])
111+
112+
def test_no_hooks_run_from_failed_transaction(self):
113+
"""If outer transaction fails, no hooks from within it run."""
114+
try:
115+
with transaction.atomic():
116+
with transaction.atomic():
117+
self.do(1)
118+
raise ForcedError()
119+
except ForcedError:
120+
pass
121+
122+
self.assertDone([])
123+
124+
def test_runs_hooks_in_order_registered(self):
125+
with transaction.atomic():
126+
self.do(1)
127+
with transaction.atomic():
128+
self.do(2)
129+
self.do(3)
130+
131+
self.assertDone([1, 2, 3])
132+
133+
def test_hooks_cleared_after_successful_commit(self):
134+
with transaction.atomic():
135+
self.do(1)
136+
with transaction.atomic():
137+
self.do(2)
138+
139+
self.assertDone([1, 2]) # not [1, 1, 2]
140+
141+
def test_hooks_cleared_after_rollback(self):
142+
try:
143+
with transaction.atomic():
144+
self.do(1)
145+
raise ForcedError()
146+
except ForcedError:
147+
pass
148+
149+
with transaction.atomic():
150+
self.do(2)
151+
152+
self.assertDone([2])
153+
154+
@skipUnlessDBFeature("test_db_allows_multiple_connections")
155+
def test_hooks_cleared_on_reconnect(self):
156+
with transaction.atomic():
157+
self.do(1)
158+
connection.close_pool()
159+
160+
connection.connect()
161+
162+
with transaction.atomic():
163+
self.do(2)
164+
165+
self.assertDone([2])
166+
167+
def test_error_in_hook_doesnt_prevent_clearing_hooks(self):
168+
try:
169+
with transaction.atomic():
170+
transaction.on_commit(lambda: self.notify("error"))
171+
except ForcedError:
172+
pass
173+
174+
with transaction.atomic():
175+
self.do(1)
176+
177+
self.assertDone([1])
178+
179+
def test_db_query_in_hook(self):
180+
with transaction.atomic():
181+
Thing.objects.create(num=1)
182+
transaction.on_commit(lambda: [self.notify(t.num) for t in Thing.objects.all()])
183+
184+
self.assertDone([1])
185+
186+
def test_transaction_in_hook(self):
187+
def on_commit():
188+
with transaction.atomic():
189+
t = Thing.objects.create(num=1)
190+
self.notify(t.num)
191+
192+
with transaction.atomic():
193+
transaction.on_commit(on_commit)
194+
195+
self.assertDone([1])
196+
197+
def test_hook_in_hook(self):
198+
def on_commit(i, add_hook):
199+
with transaction.atomic():
200+
if add_hook:
201+
transaction.on_commit(lambda: on_commit(i + 10, False))
202+
t = Thing.objects.create(num=i)
203+
self.notify(t.num)
204+
205+
with transaction.atomic():
206+
transaction.on_commit(lambda: on_commit(1, True))
207+
transaction.on_commit(lambda: on_commit(2, True))
208+
209+
self.assertDone([1, 11, 2, 12])
210+
211+
def test_raises_exception_non_callable(self):
212+
msg = "on_commit()'s callback must be a callable."
213+
with self.assertRaisesMessage(TypeError, msg):
214+
transaction.on_commit(None)

tests/transactions_/models.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
"""
2-
Transactions
3-
4-
Django handles transactions in three different ways. The default is to commit
5-
each transaction upon a write, but you can decorate a function to get
6-
commit-on-success behavior. Alternatively, you can manage the transaction
7-
manually.
8-
"""
9-
101
from django.db import models
112

123

0 commit comments

Comments
 (0)