55# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66from __future__ import annotations
77
8- from typing import Any , Awaitable , Callable , Type , TypeVar
8+ from typing import Any , Awaitable , Callable , NamedTuple , Optional , Type , TypeVar
99from abc import ABC , abstractmethod
10- from contextlib import suppress
1110from enum import Enum , Flag , auto
1211import asyncio
12+ import itertools
1313import time
1414
1515from mautrix .errors import MUnknownToken
2525 Filter ,
2626 FilterID ,
2727 GenericEvent ,
28- MessageEvent ,
2928 PresenceState ,
3029 SerializerError ,
3130 StateEvent ,
@@ -79,13 +78,18 @@ class InternalEventType(Enum):
7978 DEVICE_OTK_COUNT = auto ()
8079
8180
81+ class EventHandlerProps (NamedTuple ):
82+ wait_sync : bool
83+ sync_stream : Optional [SyncStream ]
84+
85+
8286class Syncer (ABC ):
8387 loop : asyncio .AbstractEventLoop
8488 log : TraceLogger
8589 mxid : UserID
8690
87- global_event_handlers : list [ tuple [ EventHandler , bool ] ]
88- event_handlers : dict [EventType | InternalEventType , list [ tuple [ EventHandler , bool ] ]]
91+ global_event_handlers : dict [ EventHandler , EventHandlerProps ]
92+ event_handlers : dict [EventType | InternalEventType , dict [ EventHandler , EventHandlerProps ]]
8993 dispatchers : dict [Type [dispatcher .Dispatcher ], dispatcher .Dispatcher ]
9094 syncing_task : asyncio .Task | None
9195 ignore_initial_sync : bool
@@ -95,7 +99,7 @@ class Syncer(ABC):
9599 sync_store : SyncStore
96100
97101 def __init__ (self , sync_store : SyncStore ) -> None :
98- self .global_event_handlers = []
102+ self .global_event_handlers = {}
99103 self .event_handlers = {}
100104 self .dispatchers = {}
101105 self .syncing_task = None
@@ -158,6 +162,7 @@ def add_event_handler(
158162 event_type : InternalEventType | EventType ,
159163 handler : EventHandler ,
160164 wait_sync : bool = False ,
165+ sync_stream : Optional [SyncStream ] = None ,
161166 ) -> None :
162167 """
163168 Add a new event handler.
@@ -167,13 +172,15 @@ def add_event_handler(
167172 event types.
168173 handler: The handler function to add.
169174 wait_sync: Whether or not the handler should be awaited before the next sync request.
175+ sync_stream: The sync streams to listen to. Defaults to all.
170176 """
171177 if not isinstance (event_type , (EventType , InternalEventType )):
172178 raise ValueError ("Invalid event type" )
179+ props = EventHandlerProps (wait_sync = wait_sync , sync_stream = sync_stream )
173180 if event_type == EventType .ALL :
174- self .global_event_handlers . append (( handler , wait_sync ))
181+ self .global_event_handlers [ handler ] = props
175182 else :
176- self .event_handlers .setdefault (event_type , []). append (( handler , wait_sync ))
183+ self .event_handlers .setdefault (event_type , {})[ handler ] = props
177184
178185 def remove_event_handler (
179186 self , event_type : EventType | InternalEventType , handler : EventHandler
@@ -197,11 +204,7 @@ def remove_event_handler(
197204 # No handlers for this event type registered
198205 return
199206
200- # FIXME this is a bit hacky
201- with suppress (ValueError ):
202- handler_list .remove ((handler , True ))
203- with suppress (ValueError ):
204- handler_list .remove ((handler , False ))
207+ handler_list .pop (handler , None )
205208
206209 if len (handler_list ) == 0 and event_type != EventType .ALL :
207210 del self .event_handlers [event_type ]
@@ -229,7 +232,9 @@ def dispatch_event(self, event: Event | None, source: SyncStream) -> list[asynci
229232 else :
230233 event .type = event .type .with_class (EventType .Class .MESSAGE )
231234 setattr (event , "source" , source )
232- return self .dispatch_manual_event (event .type , event , include_global_handlers = True )
235+ return self .dispatch_manual_event (
236+ event .type , event , include_global_handlers = True , source = source
237+ )
233238
234239 async def _catch_errors (self , handler : EventHandler , data : Any ) -> None :
235240 try :
@@ -243,13 +248,22 @@ def dispatch_manual_event(
243248 data : Any ,
244249 include_global_handlers : bool = False ,
245250 force_synchronous : bool = False ,
251+ source : Optional [SyncStream ] = None ,
246252 ) -> list [asyncio .Task ]:
247- handlers = self .event_handlers .get (event_type , [] )
253+ handlers = self .event_handlers .get (event_type , {}). items ( )
248254 if include_global_handlers :
249- handlers = self .global_event_handlers + handlers
255+ handlers = itertools . chain ( self .global_event_handlers . items (), handlers )
250256 tasks = []
251- for handler , wait_sync in handlers :
252- if force_synchronous or wait_sync :
257+ if source is None :
258+ source = getattr (data , "source" , None )
259+ for handler , props in handlers :
260+ if (
261+ props .sync_stream is not None
262+ and source is not None
263+ and not props .sync_stream & source
264+ ):
265+ continue
266+ if force_synchronous or props .wait_sync :
253267 tasks .append (asyncio .create_task (self ._catch_errors (handler , data )))
254268 else :
255269 background_task .create (self ._catch_errors (handler , data ))
@@ -263,6 +277,7 @@ async def run_internal_event(
263277 event_type ,
264278 custom_type if custom_type is not None else kwargs ,
265279 include_global_handlers = False ,
280+ source = SyncStream .INTERNAL ,
266281 )
267282 await asyncio .gather (* tasks )
268283
@@ -274,6 +289,7 @@ def dispatch_internal_event(
274289 event_type ,
275290 custom_type if custom_type is not None else kwargs ,
276291 include_global_handlers = False ,
292+ source = SyncStream .INTERNAL ,
277293 )
278294
279295 def _try_deserialize (self , type : Type [T ], data : JSON ) -> T | GenericEvent :
0 commit comments