Skip to content

Commit 4c7503d

Browse files
committed
Add wrapper for creating background tasks
The wrapper ensures that the reference to the task is not lost before it completes, and also logs uncaught errors.
1 parent 98140aa commit 4c7503d

File tree

10 files changed

+76
-15
lines changed

10 files changed

+76
-15
lines changed

mautrix/appservice/as_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SerializerError,
2828
UserID,
2929
)
30+
from mautrix.util import background_task
3031

3132
HandlerFunc = Callable[[Event], Awaitable]
3233

@@ -314,7 +315,7 @@ async def try_handle(handler_func: HandlerFunc):
314315

315316
for handler in self.event_handlers:
316317
# TODO add option to handle events synchronously
317-
asyncio.create_task(try_handle(handler))
318+
background_task.create(try_handle(handler))
318319

319320
def matrix_event_handler(self, func: HandlerFunc) -> HandlerFunc:
320321
self.event_handlers.append(func)

mautrix/bridge/custom_puppet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
SyncToken,
4242
UserID,
4343
)
44+
from mautrix.util import background_task
4445

4546
from .. import bridge as br
4647

@@ -409,7 +410,7 @@ def _handle_sync(self, sync_resp: dict) -> None:
409410

410411
# Deserialize and handle all events
411412
for event in chain(ephemeral_events, presence_events):
412-
asyncio.create_task(self.mx.try_handle_sync_event(Event.deserialize(event)))
413+
background_task.create(self.mx.try_handle_sync_event(Event.deserialize(event)))
413414

414415
async def _try_sync(self) -> None:
415416
try:

mautrix/bridge/matrix.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
Version,
5757
VersionsResponse,
5858
)
59-
from mautrix.util import markdown
59+
from mautrix.util import background_task, markdown
6060
from mautrix.util.logging import TraceLogger
6161
from mautrix.util.message_send_checkpoint import (
6262
CHECKPOINT_TYPES,
@@ -382,7 +382,7 @@ async def handle_puppet_nonportal_invite(
382382
members = await intent.get_room_members(room_id)
383383
except MatrixError:
384384
self.log.exception(f"Failed to get state after joining {room_id} as {intent.mxid}")
385-
asyncio.create_task(intent.leave_room(room_id, reason="Internal error"))
385+
background_task.create(intent.leave_room(room_id, reason="Internal error"))
386386
return
387387
if create_evt.type == RoomType.SPACE:
388388
await self.handle_puppet_space_invite(room_id, puppet, invited_by, evt)
@@ -798,7 +798,7 @@ async def _handle_encrypted_wait(
798798
f"Couldn't find session {err.session_id} trying to decrypt {evt.event_id},"
799799
" waiting even longer"
800800
)
801-
asyncio.create_task(
801+
background_task.create(
802802
self.e2ee.crypto.request_room_key(
803803
evt.room_id,
804804
evt.content.sender_key,
@@ -875,7 +875,7 @@ def _send_message_checkpoint(
875875
info=str(err) if err else None,
876876
retry_num=retry_num,
877877
)
878-
asyncio.create_task(checkpoint.send(endpoint, self.az.as_token, self.log))
878+
background_task.create(checkpoint.send(endpoint, self.az.as_token, self.log))
879879

880880
allowed_event_classes: tuple[type, ...] = (
881881
MessageEvent,

mautrix/bridge/portal.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
TextMessageEventContent,
3131
UserID,
3232
)
33+
from mautrix.util import background_task
3334
from mautrix.util.logging import TraceLogger
3435
from mautrix.util.simple_lock import SimpleLock
3536

@@ -402,7 +403,7 @@ async def restart_scheduled_disappearing(cls) -> None:
402403
for msg in msgs:
403404
portal = await cls.bridge.get_portal(msg.room_id)
404405
if portal and portal.mxid:
405-
asyncio.create_task(portal._disappear_event(msg))
406+
background_task.create(portal._disappear_event(msg))
406407
else:
407408
await msg.delete()
408409

@@ -418,7 +419,7 @@ async def schedule_disappearing(self) -> None:
418419
for msg in msgs:
419420
msg.start_timer()
420421
await msg.update()
421-
asyncio.create_task(self._disappear_event(msg))
422+
background_task.create(self._disappear_event(msg))
422423

423424
async def _send_message(
424425
self,
@@ -431,7 +432,7 @@ async def _send_message(
431432
event_type, content = await self.matrix.e2ee.encrypt(self.mxid, event_type, content)
432433
event_id = await intent.send_message_event(self.mxid, event_type, content, **kwargs)
433434
if intent.api.is_real_user:
434-
asyncio.create_task(intent.mark_read(self.mxid, event_id))
435+
background_task.create(intent.mark_read(self.mxid, event_id))
435436
return event_id
436437

437438
@property

mautrix/bridge/user.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mautrix.appservice import AppService
1717
from mautrix.errors import MNotFound
1818
from mautrix.types import EventID, EventType, Membership, MessageType, RoomID, UserID
19+
from mautrix.util import background_task
1920
from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
2021
from mautrix.util.logging import TraceLogger
2122
from mautrix.util.message_send_checkpoint import (
@@ -244,7 +245,7 @@ def send_remote_checkpoint(
244245
"""
245246
if not self.bridge.config["homeserver.message_send_checkpoint_endpoint"]:
246247
return WrappedTask(task=None)
247-
task = asyncio.create_task(
248+
task = background_task.create(
248249
MessageSendCheckpoint(
249250
event_id=event_id,
250251
room_id=room_id,

mautrix/client/api/modules/media_repository.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MXOpenGraph,
2121
SerializerError,
2222
)
23+
from mautrix.util import background_task
2324
from mautrix.util.async_body import async_iter_bytes
2425
from mautrix.util.opt_prometheus import Histogram
2526

@@ -157,7 +158,7 @@ async def _try_upload():
157158
except Exception as e:
158159
self.log.error(f"Failed to upload {mxc}: {type(e).__name__}: {e}")
159160

160-
asyncio.create_task(_try_upload())
161+
background_task.create(_try_upload())
161162
return mxc
162163
else:
163164
with self._observe_upload_time(size):

mautrix/crypto/decrypt_olm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ToDeviceEvent,
2020
UserID,
2121
)
22+
from mautrix.util import background_task
2223

2324
from .base import BaseOlmMachine
2425
from .sessions import Session
@@ -74,19 +75,19 @@ async def _decrypt_olm_ciphertext(
7475
f"Found matching session yet decryption failed for sender {sender}"
7576
f" with key {sender_key}"
7677
)
77-
asyncio.create_task(self._unwedge_session(sender, sender_key))
78+
background_task.create(self._unwedge_session(sender, sender_key))
7879
raise
7980

8081
if not plaintext:
8182
if message.type != OlmMsgType.PREKEY:
82-
asyncio.create_task(self._unwedge_session(sender, sender_key))
83+
background_task.create(self._unwedge_session(sender, sender_key))
8384
raise DecryptionError("Decryption failed for normal message")
8485

8586
self.log.trace(f"Trying to create inbound session for {sender}/{sender_key}")
8687
try:
8788
session = await self._create_inbound_session(sender_key, message.body)
8889
except olm.OlmSessionError as e:
89-
asyncio.create_task(self._unwedge_session(sender, sender_key))
90+
background_task.create(self._unwedge_session(sender, sender_key))
9091
raise DecryptionError("Failed to create new session from prekey message") from e
9192
self.log.debug(
9293
f"Created inbound session {session.id} for {sender} (sender key: {sender_key})"

mautrix/crypto/machine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TrustState,
2525
UserID,
2626
)
27+
from mautrix.util import background_task
2728
from mautrix.util.logging import TraceLogger
2829

2930
from .account import OlmAccount
@@ -109,7 +110,7 @@ async def handle_as_otk_counts(
109110
self.log.warning(f"Got OTK count for unknown device {user_id}/{device_id}")
110111

111112
async def handle_as_device_lists(self, device_lists: DeviceLists) -> None:
112-
asyncio.create_task(self.handle_device_lists(device_lists))
113+
background_task.create(self.handle_device_lists(device_lists))
113114

114115
async def handle_as_to_device_event(self, evt: ASToDeviceEvent) -> None:
115116
if evt.to_user_id != self.client.mxid or evt.to_device_id != self.client.device_id:

mautrix/util/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# File modules
99
"async_body",
1010
"async_getter_lock",
11+
"background_task",
1112
"bridge_state",
1213
"color_log",
1314
"ffmpeg",

mautrix/util/background_task.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) 2023 Tulir Asokan
2+
#
3+
# This Source Code Form is subject to the terms of the Mozilla Public
4+
# License, v. 2.0. If a copy of the MPL was not distributed with this
5+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
from __future__ import annotations
7+
8+
from typing import Coroutine
9+
import asyncio
10+
import logging
11+
12+
_tasks = set()
13+
log = logging.getLogger("mau.background_task")
14+
15+
16+
async def catch(coro: Coroutine, caller: str) -> None:
17+
try:
18+
await coro
19+
except Exception:
20+
log.exception(f"Uncaught error in background task (created in {caller})")
21+
22+
23+
# Logger.findCaller finds the 3rd stack frame, so add an intermediate function
24+
# to get the caller of create().
25+
def _find_caller() -> tuple[str, int, str, None]:
26+
return log.findCaller()
27+
28+
29+
def create(coro: Coroutine, *, name: str | None = None, catch_errors: bool = True) -> asyncio.Task:
30+
"""
31+
Create a background asyncio task safely, ensuring a reference is kept until the task completes.
32+
It also catches and logs uncaught errors (unless disabled via the parameter).
33+
34+
Args:
35+
coro: The coroutine to wrap in a task and execute.
36+
name: An optional name for the created task.
37+
catch_errors: Should the task be wrapped in a try-except block to log any uncaught errors?
38+
39+
Returns:
40+
An asyncio Task object wrapping the given coroutine.
41+
"""
42+
if catch_errors:
43+
try:
44+
file_name, line_number, function_name, _ = _find_caller()
45+
caller = f"{function_name} at {file_name}:{line_number}"
46+
except ValueError:
47+
caller = "unknown function"
48+
task = asyncio.create_task(catch(coro, caller), name=name)
49+
else:
50+
task = asyncio.create_task(coro, name=name)
51+
_tasks.add(task)
52+
task.add_done_callback(_tasks.discard)
53+
return task

0 commit comments

Comments
 (0)