Skip to content

Commit ba883ea

Browse files
committed
Allow filtering event handlers by sync source
1 parent 970f371 commit ba883ea

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

mautrix/client/syncer.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import Any, Awaitable, Callable, Type, TypeVar
8+
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Type, TypeVar
99
from abc import ABC, abstractmethod
10-
from contextlib import suppress
1110
from enum import Enum, Flag, auto
1211
import asyncio
12+
import itertools
1313
import time
1414

1515
from mautrix.errors import MUnknownToken
@@ -25,7 +25,6 @@
2525
Filter,
2626
FilterID,
2727
GenericEvent,
28-
MessageEvent,
2928
PresenceState,
3029
SerializerError,
3130
StateEvent,
@@ -79,13 +78,18 @@ class InternalEventType(Enum):
7978
DEVICE_OTK_COUNT = auto()
8079

8180

81+
class EventHandlerProps(NamedTuple):
82+
wait_sync: bool
83+
sync_stream: Optional[SyncStream]
84+
85+
8286
class Syncer(ABC):
8387
loop: asyncio.AbstractEventLoop
8488
log: TraceLogger
8589
mxid: UserID
8690

87-
global_event_handlers: list[tuple[EventHandler, bool]]
88-
event_handlers: dict[EventType | InternalEventType, list[tuple[EventHandler, bool]]]
91+
global_event_handlers: dict[EventHandler, EventHandlerProps]
92+
event_handlers: dict[EventType | InternalEventType, dict[EventHandler, EventHandlerProps]]
8993
dispatchers: dict[Type[dispatcher.Dispatcher], dispatcher.Dispatcher]
9094
syncing_task: asyncio.Task | None
9195
ignore_initial_sync: bool
@@ -95,7 +99,7 @@ class Syncer(ABC):
9599
sync_store: SyncStore
96100

97101
def __init__(self, sync_store: SyncStore) -> None:
98-
self.global_event_handlers = []
102+
self.global_event_handlers = {}
99103
self.event_handlers = {}
100104
self.dispatchers = {}
101105
self.syncing_task = None
@@ -158,6 +162,7 @@ def add_event_handler(
158162
event_type: InternalEventType | EventType,
159163
handler: EventHandler,
160164
wait_sync: bool = False,
165+
sync_stream: Optional[SyncStream] = None,
161166
) -> None:
162167
"""
163168
Add a new event handler.
@@ -167,13 +172,15 @@ def add_event_handler(
167172
event types.
168173
handler: The handler function to add.
169174
wait_sync: Whether or not the handler should be awaited before the next sync request.
175+
sync_stream: The sync streams to listen to. Defaults to all.
170176
"""
171177
if not isinstance(event_type, (EventType, InternalEventType)):
172178
raise ValueError("Invalid event type")
179+
props = EventHandlerProps(wait_sync=wait_sync, sync_stream=sync_stream)
173180
if event_type == EventType.ALL:
174-
self.global_event_handlers.append((handler, wait_sync))
181+
self.global_event_handlers[handler] = props
175182
else:
176-
self.event_handlers.setdefault(event_type, []).append((handler, wait_sync))
183+
self.event_handlers.setdefault(event_type, {})[handler] = props
177184

178185
def remove_event_handler(
179186
self, event_type: EventType | InternalEventType, handler: EventHandler
@@ -197,11 +204,7 @@ def remove_event_handler(
197204
# No handlers for this event type registered
198205
return
199206

200-
# FIXME this is a bit hacky
201-
with suppress(ValueError):
202-
handler_list.remove((handler, True))
203-
with suppress(ValueError):
204-
handler_list.remove((handler, False))
207+
handler_list.pop(handler, None)
205208

206209
if len(handler_list) == 0 and event_type != EventType.ALL:
207210
del self.event_handlers[event_type]
@@ -229,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci
229232
else:
230233
event.type = event.type.with_class(EventType.Class.MESSAGE)
231234
setattr(event, "source", source)
232-
return self.dispatch_manual_event(event.type, event, include_global_handlers=True)
235+
return self.dispatch_manual_event(
236+
event.type, event, include_global_handlers=True, source=source
237+
)
233238

234239
async def _catch_errors(self, handler: EventHandler, data: Any) -> None:
235240
try:
@@ -243,13 +248,22 @@ def dispatch_manual_event(
243248
data: Any,
244249
include_global_handlers: bool = False,
245250
force_synchronous: bool = False,
251+
source: Optional[SyncStream] = None,
246252
) -> list[asyncio.Task]:
247-
handlers = self.event_handlers.get(event_type, [])
253+
handlers = self.event_handlers.get(event_type, {}).items()
248254
if include_global_handlers:
249-
handlers = self.global_event_handlers + handlers
255+
handlers = itertools.chain(self.global_event_handlers.items(), handlers)
250256
tasks = []
251-
for handler, wait_sync in handlers:
252-
if force_synchronous or wait_sync:
257+
if source is None:
258+
source = getattr(data, "source", None)
259+
for handler, props in handlers:
260+
if (
261+
props.sync_stream is not None
262+
and source is not None
263+
and not props.sync_stream & source
264+
):
265+
continue
266+
if force_synchronous or props.wait_sync:
253267
tasks.append(asyncio.create_task(self._catch_errors(handler, data)))
254268
else:
255269
background_task.create(self._catch_errors(handler, data))
@@ -263,6 +277,7 @@ async def run_internal_event(
263277
event_type,
264278
custom_type if custom_type is not None else kwargs,
265279
include_global_handlers=False,
280+
source=SyncStream.INTERNAL,
266281
)
267282
await asyncio.gather(*tasks)
268283

@@ -274,6 +289,7 @@ def dispatch_internal_event(
274289
event_type,
275290
custom_type if custom_type is not None else kwargs,
276291
include_global_handlers=False,
292+
source=SyncStream.INTERNAL,
277293
)
278294

279295
def _try_deserialize(self, type: Type[T], data: JSON) -> T | GenericEvent:

0 commit comments

Comments
 (0)