Skip to content

Commit 6a6c803

Browse files
committed
feat(common): add support for waiting for a dispatch to complete.
This also waits for error handlers. Conflicts with #253 is to be expected, sorry!
1 parent d03c161 commit 6a6c803

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

nextcore/common/dispatcher.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import annotations
2323

24-
from asyncio import CancelledError, Future, create_task
24+
from asyncio import CancelledError, Future, Task, create_task, gather
2525
from collections import defaultdict
2626
from logging import getLogger
2727
from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, cast, overload
@@ -423,7 +423,7 @@ async def wait_for(
423423
return result # type: ignore [return-value]
424424

425425
# Dispatching
426-
async def dispatch(self, event_name: EventNameT, *args: Any) -> None:
426+
async def dispatch(self, event_name: EventNameT, *args: Any, wait: bool = False) -> None:
427427
"""Dispatch a event
428428
429429
**Example usage:**
@@ -438,25 +438,35 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None:
438438
The event name to dispatch to.
439439
args:
440440
The event arguments. This will be passed to the listeners.
441+
wait:
442+
Wait for all listeners to complete.
441443
"""
442444
logger.debug("Dispatching event %s", event_name)
443445

446+
tasks: list[Task[None]] = []
447+
444448
# Event handlers
445449
# Tasks are used here as some event handler/check might take a long time.
446450
for handler in self._global_event_handlers:
447451
logger.debug("Dispatching to a global handler")
448-
create_task(self._run_global_event_handler(handler, event_name, *args))
452+
tasks.append(create_task(self._run_global_event_handler(handler, event_name, *args)))
449453
for handler in self._event_handlers.get(event_name, []):
450454
logger.debug("Dispatching to a local handler")
451-
create_task(self._run_event_handler(handler, event_name, *args))
455+
tasks.append(create_task(self._run_event_handler(handler, event_name, *args)))
452456

453457
# Wait for handlers
454458
for check, future in self._wait_for_handlers.get(event_name, []):
455459
logger.debug("Dispatching to a wait_for handler")
456-
create_task(self._run_wait_for_handler(check, future, event_name, *args))
460+
tasks.append(create_task(self._run_wait_for_handler(check, future, event_name, *args)))
457461
for check, future in self._global_wait_for_handlers:
458462
logger.debug("Dispatching to a global wait_for handler")
459-
create_task(self._run_global_wait_for_handler(check, future, event_name, *args))
463+
tasks.append(create_task(self._run_global_wait_for_handler(check, future, event_name, *args)))
464+
465+
# Optional waiting
466+
logger.debug("Dispatching via %s tasks", len(tasks))
467+
468+
if wait:
469+
await gather(*tasks)
460470

461471
async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None:
462472
"""Run event with exception handlers"""

tests/common/test_dispatcher.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from __future__ import annotations
22

3-
from asyncio import Future
4-
from asyncio import TimeoutError as AsyncioTimeoutError
5-
from asyncio import create_task, get_running_loop, wait_for
3+
from asyncio import TimeoutError as AsyncioTimeoutError, Future, sleep, create_task, get_running_loop, wait_for
64
from typing import TYPE_CHECKING
5+
from tests.utils import match_time
76

87
from pytest import mark, raises
98

@@ -192,3 +191,15 @@ def false_callback(event: str | None = None) -> bool:
192191
assert error_count == 0, "Logged errors where present"
193192

194193
dispatcher.close()
194+
195+
@mark.asyncio
196+
@match_time(1, 0.1)
197+
async def test_dispatch_wait():
198+
dispatcher: Dispatcher[str] = Dispatcher()
199+
200+
async def handler():
201+
await sleep(1)
202+
203+
dispatcher.add_listener(handler, "test")
204+
205+
await dispatcher.dispatch("test", wait=True)

0 commit comments

Comments
 (0)