Skip to content

Commit 6296fed

Browse files
authored
Merge pull request #117 from kraken-tech/support-decorators-without-parentheses
Support decorators without parentheses
2 parents 681e781 + 34e7654 commit 6296fed

File tree

3 files changed

+163
-28
lines changed

3 files changed

+163
-28
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
1515
This helps catch order-of-execution bugs in tests.
1616
The error can be silenced by setting the `SUBATOMIC_CATCH_UNHANDLED_AFTER_COMMIT_CALLBACKS_IN_TESTS` setting to `False`
1717
to facilitate gradual adoption of this stricter rule.
18+
- The `transaction`, `transaction_required` and `transaction_if_not_already` decorators can now be used without parentheses. Fixes #103.
1819

1920
### Fixed
2021

src/django_subatomic/db.py

Lines changed: 91 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import contextlib
44
import functools
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, overload
66

77
import attrs
88
from django import db as django_db
@@ -28,8 +28,21 @@
2828
]
2929

3030

31-
@contextlib.contextmanager
32-
def transaction(*, using: str | None = None) -> Generator[None]:
31+
@overload
32+
def transaction(
33+
func: None = None, *, using: str | None = None
34+
) -> contextlib._GeneratorContextManager[None, None, None]: ...
35+
36+
37+
@overload
38+
def transaction[**P, R](
39+
func: Callable[P, R], *, using: str | None = None
40+
) -> Callable[P, R]: ...
41+
42+
43+
def transaction[**P, R](
44+
func: Callable[P, R] | None = None, *, using: str | None = None
45+
) -> contextlib._GeneratorContextManager[None, None, None] | Callable[P, R]:
3346
"""
3447
Create a database transaction.
3548
@@ -41,17 +54,38 @@ def transaction(*, using: str | None = None) -> Generator[None]:
4154
Raises:
4255
RuntimeError: if we call this from inside another existing transaction.
4356
"""
44-
# Note that `savepoint=False` is not required here because
45-
# the `savepoint` flag is ignored when `durable` is `True`.
46-
with (
47-
_execute_on_commit_callbacks_in_tests(using),
48-
django_transaction.atomic(using=using, durable=True),
49-
):
50-
yield
5157

58+
@contextlib.contextmanager
59+
def _transaction(*, using: str | None) -> Generator[None]:
60+
# Note that `savepoint=False` is not required here because
61+
# the `savepoint` flag is ignored when `durable` is `True`.
62+
with (
63+
_execute_on_commit_callbacks_in_tests(using),
64+
django_transaction.atomic(using=using, durable=True),
65+
):
66+
yield
5267

53-
@contextlib.contextmanager
54-
def transaction_if_not_already(*, using: str | None = None) -> Generator[None]:
68+
decorator = _transaction(using=using)
69+
if func is None:
70+
return decorator
71+
return decorator(func)
72+
73+
74+
@overload
75+
def transaction_if_not_already(
76+
func: None = None, *, using: str | None = None
77+
) -> contextlib._GeneratorContextManager[None, None, None]: ...
78+
79+
80+
@overload
81+
def transaction_if_not_already[**P, R](
82+
func: Callable[P, R], *, using: str | None = None
83+
) -> Callable[P, R]: ...
84+
85+
86+
def transaction_if_not_already[**P, R](
87+
func: Callable[P, R] | None = None, *, using: str | None = None
88+
) -> contextlib._GeneratorContextManager[None, None, None] | Callable[P, R]:
5589
"""
5690
Create a transaction if one isn't already open.
5791
@@ -78,16 +112,24 @@ def transaction_if_not_already(*, using: str | None = None) -> Generator[None]:
78112
- In functions which can unambiguously control transactions,
79113
use [`transaction`][django_subatomic.db.transaction].
80114
"""
81-
# If the innermost atomic block is from a test case, we should create a SAVEPOINT here.
82-
# This allows for a rollback when an exception propagates out of this block, and so
83-
# better simulates a production transaction behaviour in tests.
84-
savepoint = _innermost_atomic_block_wraps_testcase(using=using)
85115

86-
with (
87-
_execute_on_commit_callbacks_in_tests(using),
88-
django_transaction.atomic(using=using, savepoint=savepoint),
89-
):
90-
yield
116+
@contextlib.contextmanager
117+
def _transaction_if_not_already(*, using: str | None = None) -> Generator[None]:
118+
# If the innermost atomic block is from a test case, we should create a SAVEPOINT here.
119+
# This allows for a rollback when an exception propagates out of this block, and so
120+
# better simulates a production transaction behaviour in tests.
121+
savepoint = _innermost_atomic_block_wraps_testcase(using=using)
122+
123+
with (
124+
_execute_on_commit_callbacks_in_tests(using),
125+
django_transaction.atomic(using=using, savepoint=savepoint),
126+
):
127+
yield
128+
129+
decorator = _transaction_if_not_already(using=using)
130+
if func is None:
131+
return decorator
132+
return decorator(func)
91133

92134

93135
@_utils.contextmanager
@@ -118,8 +160,21 @@ def savepoint(*, using: str | None = None) -> Generator[None]:
118160
yield
119161

120162

121-
@contextlib.contextmanager
122-
def transaction_required(*, using: str | None = None) -> Generator[None]:
163+
@overload
164+
def transaction_required(
165+
func: None = None, *, using: str | None = None
166+
) -> contextlib._GeneratorContextManager[None, None, None]: ...
167+
168+
169+
@overload
170+
def transaction_required[**P, R](
171+
func: Callable[P, R], *, using: str | None = None
172+
) -> Callable[P, R]: ...
173+
174+
175+
def transaction_required[**P, R](
176+
func: Callable[P, R] | None = None, *, using: str | None = None
177+
) -> contextlib._GeneratorContextManager[None, None, None] | Callable[P, R]:
123178
"""
124179
Make sure that code is always executed in a transaction.
125180
@@ -132,12 +187,20 @@ def transaction_required(*, using: str | None = None) -> Generator[None]:
132187
Raises:
133188
_MissingRequiredTransaction: if we are not in a transaction.
134189
"""
135-
if using is None:
136-
using = django_db.DEFAULT_DB_ALIAS
137190

138-
if not in_transaction(using=using):
139-
raise _MissingRequiredTransaction(database=using)
140-
yield
191+
@contextlib.contextmanager
192+
def _transaction_required(*, using: str | None = None) -> Generator[None]:
193+
if using is None:
194+
using = django_db.DEFAULT_DB_ALIAS
195+
196+
if not in_transaction(using=using):
197+
raise _MissingRequiredTransaction(database=using)
198+
yield
199+
200+
decorator = _transaction_required(using=using)
201+
if func is None:
202+
return decorator
203+
return decorator(func)
141204

142205

143206
def durable[**P, R](func: Callable[P, R]) -> Callable[P, R]:

tests/test_db.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,25 @@ def inner() -> None:
118118
assert was_called is True
119119
assert django_transaction.get_autocommit() is True
120120

121+
@pytest.mark.django_db(transaction=True)
122+
def test_decorator_without_parentheses(self) -> None:
123+
"""
124+
`transaction` can be used as a decorator without parentheses.
125+
"""
126+
was_called = False
127+
128+
@db.transaction
129+
def inner() -> None:
130+
assert django_transaction.get_autocommit() is False
131+
nonlocal was_called
132+
was_called = True
133+
134+
assert django_transaction.get_autocommit() is True
135+
inner()
136+
137+
assert was_called is True
138+
assert django_transaction.get_autocommit() is True
139+
121140
def test_works_in_tests(self) -> None:
122141
"""
123142
Tests can call `db.transaction` without a fuss.
@@ -317,6 +336,20 @@ def inner() -> None: ...
317336

318337
assert exc.value.database == DEFAULT
319338

339+
@_parametrize_transaction_testcase
340+
def test_decorator_without_parentheses_fails_when_not_in_transaction(self) -> None:
341+
"""
342+
An error is raised when we're not in a transaction.
343+
"""
344+
345+
@db.transaction_required
346+
def inner() -> None: ...
347+
348+
with pytest.raises(db._MissingRequiredTransaction) as exc: # noqa: SLF001
349+
inner()
350+
351+
assert exc.value.database == DEFAULT
352+
320353
@_parametrize_transaction_testcase
321354
def test_no_error_when_in_transaction(self) -> None:
322355
"""
@@ -432,6 +465,44 @@ def test_can_query_after_exception_in_test_case(self) -> None:
432465
with django_db.connections["default"].cursor() as cursor:
433466
cursor.execute("SELECT 1")
434467

468+
@pytest.mark.django_db(transaction=True)
469+
def test_decorator(self) -> None:
470+
"""
471+
`transaction_if_not_already` can be used as a decorator.
472+
"""
473+
was_called = False
474+
475+
@db.transaction_if_not_already()
476+
def inner() -> None:
477+
assert django_transaction.get_autocommit() is False
478+
nonlocal was_called
479+
was_called = True
480+
481+
assert django_transaction.get_autocommit() is True
482+
inner()
483+
484+
assert was_called is True
485+
assert django_transaction.get_autocommit() is True
486+
487+
@pytest.mark.django_db(transaction=True)
488+
def test_decorator_without_parentheses(self) -> None:
489+
"""
490+
`transaction_if_not_already` can be used as a decorator without parentheses.
491+
"""
492+
was_called = False
493+
494+
@db.transaction_if_not_already
495+
def inner() -> None:
496+
assert django_transaction.get_autocommit() is False
497+
nonlocal was_called
498+
was_called = True
499+
500+
assert django_transaction.get_autocommit() is True
501+
inner()
502+
503+
assert was_called is True
504+
assert django_transaction.get_autocommit() is True
505+
435506

436507
@db.durable
437508
def _durable_example() -> bool:

0 commit comments

Comments
 (0)