22
33import contextlib
44import functools
5- from typing import TYPE_CHECKING
5+ from typing import TYPE_CHECKING , overload
66
77import attrs
88from django import db as django_db
2828]
2929
3030
31- @contextlib .contextmanager
32- def transaction (* , using : str | None = None ) -> Generator [None ]:
31+ @overload
32+ def transaction (
33+ func : None = None , * , using : str | None = None
34+ ) -> contextlib ._GeneratorContextManager [None , None , None ]: ...
35+
36+
37+ @overload
38+ def transaction [** P , R ](
39+ func : Callable [P , R ], * , using : str | None = None
40+ ) -> Callable [P , R ]: ...
41+
42+
43+ def transaction [** P , R ](
44+ func : Callable [P , R ] | None = None , * , using : str | None = None
45+ ) -> contextlib ._GeneratorContextManager [None , None , None ] | Callable [P , R ]:
3346 """
3447 Create a database transaction.
3548
@@ -41,17 +54,38 @@ def transaction(*, using: str | None = None) -> Generator[None]:
4154 Raises:
4255 RuntimeError: if we call this from inside another existing transaction.
4356 """
44- # Note that `savepoint=False` is not required here because
45- # the `savepoint` flag is ignored when `durable` is `True`.
46- with (
47- _execute_on_commit_callbacks_in_tests (using ),
48- django_transaction .atomic (using = using , durable = True ),
49- ):
50- yield
5157
58+ @contextlib .contextmanager
59+ def _transaction (* , using : str | None ) -> Generator [None ]:
60+ # Note that `savepoint=False` is not required here because
61+ # the `savepoint` flag is ignored when `durable` is `True`.
62+ with (
63+ _execute_on_commit_callbacks_in_tests (using ),
64+ django_transaction .atomic (using = using , durable = True ),
65+ ):
66+ yield
5267
53- @contextlib .contextmanager
54- def transaction_if_not_already (* , using : str | None = None ) -> Generator [None ]:
68+ decorator = _transaction (using = using )
69+ if func is None :
70+ return decorator
71+ return decorator (func )
72+
73+
74+ @overload
75+ def transaction_if_not_already (
76+ func : None = None , * , using : str | None = None
77+ ) -> contextlib ._GeneratorContextManager [None , None , None ]: ...
78+
79+
80+ @overload
81+ def transaction_if_not_already [** P , R ](
82+ func : Callable [P , R ], * , using : str | None = None
83+ ) -> Callable [P , R ]: ...
84+
85+
86+ def transaction_if_not_already [** P , R ](
87+ func : Callable [P , R ] | None = None , * , using : str | None = None
88+ ) -> contextlib ._GeneratorContextManager [None , None , None ] | Callable [P , R ]:
5589 """
5690 Create a transaction if one isn't already open.
5791
@@ -78,16 +112,24 @@ def transaction_if_not_already(*, using: str | None = None) -> Generator[None]:
78112 - In functions which can unambiguously control transactions,
79113 use [`transaction`][django_subatomic.db.transaction].
80114 """
81- # If the innermost atomic block is from a test case, we should create a SAVEPOINT here.
82- # This allows for a rollback when an exception propagates out of this block, and so
83- # better simulates a production transaction behaviour in tests.
84- savepoint = _innermost_atomic_block_wraps_testcase (using = using )
85115
86- with (
87- _execute_on_commit_callbacks_in_tests (using ),
88- django_transaction .atomic (using = using , savepoint = savepoint ),
89- ):
90- yield
116+ @contextlib .contextmanager
117+ def _transaction_if_not_already (* , using : str | None = None ) -> Generator [None ]:
118+ # If the innermost atomic block is from a test case, we should create a SAVEPOINT here.
119+ # This allows for a rollback when an exception propagates out of this block, and so
120+ # better simulates a production transaction behaviour in tests.
121+ savepoint = _innermost_atomic_block_wraps_testcase (using = using )
122+
123+ with (
124+ _execute_on_commit_callbacks_in_tests (using ),
125+ django_transaction .atomic (using = using , savepoint = savepoint ),
126+ ):
127+ yield
128+
129+ decorator = _transaction_if_not_already (using = using )
130+ if func is None :
131+ return decorator
132+ return decorator (func )
91133
92134
93135@_utils .contextmanager
@@ -118,8 +160,21 @@ def savepoint(*, using: str | None = None) -> Generator[None]:
118160 yield
119161
120162
121- @contextlib .contextmanager
122- def transaction_required (* , using : str | None = None ) -> Generator [None ]:
163+ @overload
164+ def transaction_required (
165+ func : None = None , * , using : str | None = None
166+ ) -> contextlib ._GeneratorContextManager [None , None , None ]: ...
167+
168+
169+ @overload
170+ def transaction_required [** P , R ](
171+ func : Callable [P , R ], * , using : str | None = None
172+ ) -> Callable [P , R ]: ...
173+
174+
175+ def transaction_required [** P , R ](
176+ func : Callable [P , R ] | None = None , * , using : str | None = None
177+ ) -> contextlib ._GeneratorContextManager [None , None , None ] | Callable [P , R ]:
123178 """
124179 Make sure that code is always executed in a transaction.
125180
@@ -132,12 +187,20 @@ def transaction_required(*, using: str | None = None) -> Generator[None]:
132187 Raises:
133188 _MissingRequiredTransaction: if we are not in a transaction.
134189 """
135- if using is None :
136- using = django_db .DEFAULT_DB_ALIAS
137190
138- if not in_transaction (using = using ):
139- raise _MissingRequiredTransaction (database = using )
140- yield
191+ @contextlib .contextmanager
192+ def _transaction_required (* , using : str | None = None ) -> Generator [None ]:
193+ if using is None :
194+ using = django_db .DEFAULT_DB_ALIAS
195+
196+ if not in_transaction (using = using ):
197+ raise _MissingRequiredTransaction (database = using )
198+ yield
199+
200+ decorator = _transaction_required (using = using )
201+ if func is None :
202+ return decorator
203+ return decorator (func )
141204
142205
143206def durable [** P , R ](func : Callable [P , R ]) -> Callable [P , R ]:
0 commit comments