2323from typing import (
2424 TYPE_CHECKING ,
2525 Any ,
26+ Awaitable ,
2627 Callable ,
2728 Collection ,
2829 Dict ,
5758from synapse .storage .background_updates import BackgroundUpdater
5859from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
5960from synapse .storage .types import Connection , Cursor
60- from synapse .util .async_helpers import delay_cancellation , maybe_awaitable
61+ from synapse .util .async_helpers import delay_cancellation
6162from synapse .util .iterutils import batch_iter
6263
6364if TYPE_CHECKING :
@@ -168,6 +169,7 @@ def cursor(
168169 * ,
169170 txn_name : Optional [str ] = None ,
170171 after_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
172+ async_after_callbacks : Optional [List ["_AsyncCallbackListEntry" ]] = None ,
171173 exception_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
172174 ) -> "LoggingTransaction" :
173175 if not txn_name :
@@ -178,6 +180,7 @@ def cursor(
178180 name = txn_name ,
179181 database_engine = self .engine ,
180182 after_callbacks = after_callbacks ,
183+ async_after_callbacks = async_after_callbacks ,
181184 exception_callbacks = exception_callbacks ,
182185 )
183186
@@ -209,6 +212,9 @@ def __getattr__(self, name: str) -> Any:
209212
210213# The type of entry which goes on our after_callbacks and exception_callbacks lists.
211214_CallbackListEntry = Tuple [Callable [..., object ], Tuple [object , ...], Dict [str , object ]]
215+ _AsyncCallbackListEntry = Tuple [
216+ Callable [..., Awaitable ], Tuple [object , ...], Dict [str , object ]
217+ ]
212218
213219P = ParamSpec ("P" )
214220R = TypeVar ("R" )
@@ -227,6 +233,10 @@ class LoggingTransaction:
227233 that have been added by `call_after` which should be run on
228234 successful completion of the transaction. None indicates that no
229235 callbacks should be allowed to be scheduled to run.
236+ async_after_callbacks: A list that asynchronous callbacks will be appended
237+ to by `async_call_after` which should run, before after_callbacks, on
238+ successful completion of the transaction. None indicates that no
239+ callbacks should be allowed to be scheduled to run.
230240 exception_callbacks: A list that callbacks will be appended
231241 to that have been added by `call_on_exception` which should be run
232242 if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
238248 "name" ,
239249 "database_engine" ,
240250 "after_callbacks" ,
251+ "async_after_callbacks" ,
241252 "exception_callbacks" ,
242253 ]
243254
@@ -247,12 +258,14 @@ def __init__(
247258 name : str ,
248259 database_engine : BaseDatabaseEngine ,
249260 after_callbacks : Optional [List [_CallbackListEntry ]] = None ,
261+ async_after_callbacks : Optional [List [_AsyncCallbackListEntry ]] = None ,
250262 exception_callbacks : Optional [List [_CallbackListEntry ]] = None ,
251263 ):
252264 self .txn = txn
253265 self .name = name
254266 self .database_engine = database_engine
255267 self .after_callbacks = after_callbacks
268+ self .async_after_callbacks = async_after_callbacks
256269 self .exception_callbacks = exception_callbacks
257270
258271 def call_after (
@@ -277,6 +290,28 @@ def call_after(
277290 # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
278291 self .after_callbacks .append ((callback , args , kwargs )) # type: ignore[arg-type]
279292
293+ def async_call_after (
294+ self , callback : Callable [P , Awaitable ], * args : P .args , ** kwargs : P .kwargs
295+ ) -> None :
296+ """Call the given asynchronous callback on the main twisted thread after
297+ the transaction has finished (but before those added in `call_after`).
298+
299+ Mostly used to invalidate remote caches after transactions.
300+
301+ Note that transactions may be retried a few times if they encounter database
302+ errors such as serialization failures. Callbacks given to `async_call_after`
303+ will accumulate across transaction attempts and will _all_ be called once a
304+ transaction attempt succeeds, regardless of whether previous transaction
305+ attempts failed. Otherwise, if all transaction attempts fail, all
306+ `call_on_exception` callbacks will be run instead.
307+ """
308+ # if self.async_after_callbacks is None, that means that whatever constructed the
309+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
310+ # is not the case.
311+ assert self .async_after_callbacks is not None
312+ # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
313+ self .async_after_callbacks .append ((callback , args , kwargs )) # type: ignore[arg-type]
314+
280315 def call_on_exception (
281316 self , callback : Callable [P , object ], * args : P .args , ** kwargs : P .kwargs
282317 ) -> None :
@@ -574,6 +609,7 @@ def new_transaction(
574609 conn : LoggingDatabaseConnection ,
575610 desc : str ,
576611 after_callbacks : List [_CallbackListEntry ],
612+ async_after_callbacks : List [_AsyncCallbackListEntry ],
577613 exception_callbacks : List [_CallbackListEntry ],
578614 func : Callable [Concatenate [LoggingTransaction , P ], R ],
579615 * args : P .args ,
@@ -597,6 +633,7 @@ def new_transaction(
597633 conn
598634 desc
599635 after_callbacks
636+ async_after_callbacks
600637 exception_callbacks
601638 func
602639 *args
@@ -659,6 +696,7 @@ def new_transaction(
659696 cursor = conn .cursor (
660697 txn_name = name ,
661698 after_callbacks = after_callbacks ,
699+ async_after_callbacks = async_after_callbacks ,
662700 exception_callbacks = exception_callbacks ,
663701 )
664702 try :
@@ -798,6 +836,7 @@ async def runInteraction(
798836
799837 async def _runInteraction () -> R :
800838 after_callbacks : List [_CallbackListEntry ] = []
839+ async_after_callbacks : List [_AsyncCallbackListEntry ] = []
801840 exception_callbacks : List [_CallbackListEntry ] = []
802841
803842 if not current_context ():
@@ -809,6 +848,7 @@ async def _runInteraction() -> R:
809848 self .new_transaction ,
810849 desc ,
811850 after_callbacks ,
851+ async_after_callbacks ,
812852 exception_callbacks ,
813853 func ,
814854 * args ,
@@ -817,15 +857,17 @@ async def _runInteraction() -> R:
817857 ** kwargs ,
818858 )
819859
860+ # We order these assuming that async functions call out to external
861+ # systems (e.g. to invalidate a cache) and the sync functions make these
862+ # changes on any local in-memory caches/similar, and thus must be second.
863+ for async_callback , async_args , async_kwargs in async_after_callbacks :
864+ await async_callback (* async_args , ** async_kwargs )
820865 for after_callback , after_args , after_kwargs in after_callbacks :
821- await maybe_awaitable (after_callback (* after_args , ** after_kwargs ))
822-
866+ after_callback (* after_args , ** after_kwargs )
823867 return cast (R , result )
824868 except Exception :
825869 for exception_callback , after_args , after_kwargs in exception_callbacks :
826- await maybe_awaitable (
827- exception_callback (* after_args , ** after_kwargs )
828- )
870+ exception_callback (* after_args , ** after_kwargs )
829871 raise
830872
831873 # To handle cancellation, we ensure that `after_callback`s and
0 commit comments