Skip to content

Commit aba538b

Browse files
committed
tasks
1 parent daab987 commit aba538b

File tree

3 files changed

+58
-15
lines changed

3 files changed

+58
-15
lines changed

discord/client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from .soundboard import SoundboardSound
7676
from .ui.item import Item
7777
from .voice_client import VoiceProtocol
78+
from .ext.tasks import Loop as TaskLoop
7879

7980
__all__ = ("Client",)
8081

@@ -119,6 +120,27 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
119120
loop.close()
120121

121122

123+
class LoopTaskSet:
124+
def __init__(self) -> None:
125+
self.tasks: set[TaskLoop] = set()
126+
self.client: Client | None = None
127+
128+
def add_loop(self, loop: TaskLoop) -> None:
129+
if self.client is not None:
130+
running = asyncio.get_running_loop()
131+
loop.loop = running
132+
loop.start()
133+
else:
134+
self.tasks.add(loop)
135+
136+
def start(self, client: Client) -> None:
137+
self.client = client
138+
for task in self.tasks:
139+
loop = client.loop
140+
task.loop = loop
141+
task.start()
142+
143+
122144
class Client:
123145
r"""Represents a client connection that connects to Discord.
124146
This class is used to interact with the Discord WebSocket and API.
@@ -227,6 +249,8 @@ class Client:
227249
The event loop that the client uses for asynchronous operations.
228250
"""
229251

252+
_pending_loops = LoopTaskSet()
253+
230254
def __init__(
231255
self,
232256
*,
@@ -297,6 +321,7 @@ async def __aenter__(self) -> Client:
297321
# Maybe handle different system event loop policies?
298322
self._loop = asyncio.new_event_loop()
299323

324+
self._pending_loops.start(self)
300325
self.http.loop = self.loop
301326
self._connection.loop = self.loop
302327

@@ -506,7 +531,6 @@ def _schedule_event(
506531
return task
507532

508533
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
509-
_log.debug("Dispatching event %s", event)
510534
method = f"on_{event}"
511535

512536
listeners = self._listeners.get(event)

discord/ext/tasks/__init__.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import sys
3333
import traceback
3434
from collections.abc import Sequence
35-
from typing import Any, Awaitable, Callable, Generic, TypeVar, cast
35+
from typing import Any, Awaitable, Callable, Generic, TypeVar
3636

3737
import aiohttp
3838

@@ -120,18 +120,19 @@ def __init__(
120120
count: int | None,
121121
reconnect: bool,
122122
loop: asyncio.AbstractEventLoop | None,
123+
create_loop: bool,
123124
name: str | None,
124125
) -> None:
125126
self.coro: LF = coro
126127
self.reconnect: bool = reconnect
127128

128-
if loop is None:
129+
if create_loop is True and loop is None:
129130
try:
130131
loop = asyncio.get_running_loop()
131132
except RuntimeError:
132133
loop = asyncio.new_event_loop()
133134

134-
self.loop = loop
135+
self.loop: asyncio.AbstractEventLoop | None = loop
135136

136137
self.name: str = f'pycord-ext-task ({id(self):#x}): {coro.__qualname__}' if name in (None, MISSING) else name
137138
self.count: int | None = count
@@ -146,6 +147,7 @@ def __init__(
146147
aiohttp.ClientError,
147148
asyncio.TimeoutError,
148149
)
150+
self._create_loop = create_loop
149151

150152
self._before_loop = None
151153
self._after_loop = None
@@ -168,6 +170,9 @@ def __init__(
168170
f"Expected coroutine function, not {type(self.coro).__name__!r}."
169171
)
170172

173+
if loop is None and not create_loop:
174+
discord.Client._pending_loops.add_loop(self)
175+
171176
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
172177
coro = getattr(self, f"_{name}")
173178
if coro is None:
@@ -280,6 +285,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
280285
reconnect=self.reconnect,
281286
name=self.name,
282287
loop=self.loop,
288+
create_loop=self._create_loop,
283289
)
284290
copy._injected = obj
285291
copy._before_loop = self._before_loop
@@ -365,7 +371,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
365371

366372
return await self.coro(*args, **kwargs)
367373

368-
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
374+
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None] | None:
369375
r"""Starts the internal task in the event loop.
370376
371377
Parameters
@@ -386,13 +392,21 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
386392
The task that has been created.
387393
"""
388394

395+
if self.loop is None:
396+
_log.warning(
397+
f"The task {self.name} has been set to be bound to a discord.Client instance, and will start running automatically "
398+
"when the client starts. If you want this task to be executed without it being bound to a discord.Client, "
399+
"set the create_loop parameter in the decorator to True, and don't forget to set the client.loop to the loop.loop"
400+
)
401+
return None
402+
389403
if self._task is not MISSING and not self._task.done():
390404
raise RuntimeError("Task is already launched and is not completed.")
391405

392406
if self._injected is not None:
393407
args = (self._injected, *args)
394408

395-
self._task = self.loop.create_task(self._loop(*args, **kwargs), name=self.name)
409+
self._task = asyncio.ensure_future(self.loop.create_task(self._loop(*args, **kwargs), name=self.name))
396410
return self._task
397411

398412
def stop(self) -> None:
@@ -760,15 +774,9 @@ def change_interval(
760774
self._time = self._get_time_parameter(time)
761775
self._sleep = self._seconds = self._minutes = self._hours = MISSING
762776

763-
if self.is_running() and not (
764-
self._before_loop_running or self._after_loop_running
765-
):
766-
if self._time is not MISSING:
767-
# prepare the next time index starting from after the last iteration
768-
self._prepare_time_index(now=self._last_iteration)
769-
777+
if self.is_running() and self._last_iteration is not MISSING:
770778
self._next_iteration = self._get_next_sleep_time()
771-
if not self._handle.done():
779+
if self._handle and not self._handle.done():
772780
# the loop is sleeping, recalculate based on new interval
773781
self._handle.recalculate(self._next_iteration)
774782

@@ -783,6 +791,7 @@ def loop(
783791
reconnect: bool = True,
784792
loop: asyncio.AbstractEventLoop | None = None,
785793
name: str | None = MISSING,
794+
create_loop: bool = False,
786795
) -> Callable[[LF], Loop[LF]]:
787796
"""A decorator that schedules a task in the background for you with
788797
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -820,6 +829,11 @@ def loop(
820829
name: Optional[:class:`str`]
821830
The name to create the task with, defaults to ``None``.
822831
832+
.. versionadded:: 2.7
833+
create_loop: :class:`bool`
834+
Whether this task should create their own event loop to start running it
835+
without a client bound to it.
836+
823837
.. versionadded:: 2.7
824838
825839
Raises
@@ -842,6 +856,7 @@ def decorator(func: LF) -> Loop[LF]:
842856
reconnect=reconnect,
843857
name=name,
844858
loop=loop,
859+
create_loop=create_loop,
845860
)
846861

847862
return decorator

discord/state.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(
178178
if self.max_messages is not None and self.max_messages <= 0:
179179
self.max_messages = 1000
180180

181-
self.dispatch: Callable = dispatch
181+
self._dispatch: Callable = dispatch
182182
self.handlers: dict[str, Callable] = handlers
183183
self.hooks: dict[str, Callable] = hooks
184184
self.shard_count: int | None = None
@@ -263,6 +263,10 @@ def __init__(
263263

264264
self.clear()
265265

266+
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any:
267+
_log.debug('Dispatching event %s', event)
268+
return self._dispatch(event, *args, **kwargs)
269+
266270
def clear(self, *, views: bool = True) -> None:
267271
self.user: ClientUser | None = None
268272
# Originally, this code used WeakValueDictionary to maintain references to the

0 commit comments

Comments
 (0)