diff --git a/CHANGELOG.md b/CHANGELOG.md index 3098be39c9..1a380a6494 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2905](https://github.com/Pycord-Development/pycord/pull/2905)) - `view=None` in various methods causing an AttributeError. ([#2915](https://github.com/Pycord-Development/pycord/pull/2915)) +- Fixed Async I/O errors that could be raised when using `Client.run`. + ([#2645](https://github.com/Pycord-Development/pycord/pull/2645)) - `View.message` being `None` when it had not been interacted with yet. ([#2916](https://github.com/Pycord-Development/pycord/pull/2916)) - Fixed a crash when processing message edit events while message cache was disabled. diff --git a/discord/bot.py b/discord/bot.py index 7dd246afe3..b5ee86422d 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -1386,7 +1386,7 @@ def before_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -1418,7 +1418,7 @@ def after_invoke(self, coro): The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro diff --git a/discord/client.py b/discord/client.py index ed6cdb8991..fe17e464e9 100644 --- a/discord/client.py +++ b/discord/client.py @@ -26,8 +26,8 @@ from __future__ import annotations import asyncio +import inspect import logging -import signal import sys import traceback from types import TracebackType @@ -69,6 +69,7 @@ if TYPE_CHECKING: from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .channel import DMChannel + from .ext.tasks import Loop as TaskLoop from .interactions import Interaction from .member import Member from .message import Message @@ -120,12 +121,39 @@ def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: loop.close() +class LoopTaskSet: + def __init__(self) -> None: + self.tasks: set[TaskLoop] = set() + self.client: Client | None = None + + def add_loop(self, loop: TaskLoop) -> None: + if self.client is not None: + running = asyncio.get_running_loop() + loop.loop = running + loop.start() + else: + self.tasks.add(loop) + + def start(self, client: Client) -> None: + self.client = client + for task in self.tasks: + loop = client.loop + task.loop = loop + task.start() + + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. A number of options can be passed to the :class:`Client`. + .. container:: operations + + .. describe:: async with x + + Asynchronously initializes the client. + Parameters ----------- max_messages: Optional[:class:`int`] @@ -222,6 +250,8 @@ class Client: The event loop that the client uses for asynchronous operations. """ + _pending_loops = LoopTaskSet() + def __init__( self, *, @@ -230,9 +260,14 @@ def __init__( ): # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + pass + + self._loop: asyncio.AbstractEventLoop | None = loop self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = ( {} ) @@ -248,7 +283,7 @@ def __init__( proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, - loop=self.loop, + loop=self._loop, ) self._handlers: dict[str, Callable] = {"ready": self._handle_ready} @@ -260,11 +295,13 @@ def __init__( self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count - self._closed: bool = False + self._closed: asyncio.Event = asyncio.Event() + self._closing_task: asyncio.Lock = asyncio.Lock() self._ready: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self self._event_handlers: dict[str, list[Coro]] = {} + self._in_context: bool = False if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False @@ -274,12 +311,25 @@ def __init__( self._tasks = set() async def __aenter__(self) -> Client: - loop = asyncio.get_running_loop() - self.loop = loop - self.http.loop = loop - self._connection.loop = loop + self._in_context = True + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + # No event loop was found, this should not happen + # because entering on this context manager means a + # loop is already active, but we need to handle it + # anyways just to prevent future errors. + + # Maybe handle different system event loop policies? + self._loop = asyncio.new_event_loop() + + self._pending_loops.start(self) + self.http.loop = self.loop + self._connection.loop = self.loop self._ready = asyncio.Event() + self._closed = asyncio.Event() return self @@ -289,6 +339,7 @@ async def __aexit__( exc_v: BaseException | None, exc_tb: TracebackType | None, ) -> None: + self._in_context = False if not self.is_closed(): await self.close() @@ -305,13 +356,28 @@ def _get_state(self, **options: Any) -> ConnectionState: handlers=self._handlers, hooks=self._hooks, http=self.http, - loop=self.loop, + loop=self._loop, **options, ) def _handle_ready(self) -> None: self._ready.set() + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The event loop that the client uses for asynchronous operations.""" + if self._loop is None: + raise RuntimeError("loop is not set") + return self._loop + + @loop.setter + def loop(self, value: asyncio.AbstractEventLoop) -> None: + if not isinstance(value, asyncio.AbstractEventLoop): + raise TypeError( + f"expected a AbstractEventLoop object, got {value.__class__.__name__!r} instead" + ) + self._loop = value + @property def latency(self) -> float: """Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. If no websocket @@ -469,7 +535,6 @@ def _schedule_event( return task def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug("Dispatching event %s", event) method = f"on_{event}" listeners = self._listeners.get(event) @@ -764,23 +829,24 @@ async def close(self) -> None: Closes the connection to Discord. """ - if self._closed: - return + async with self._closing_task: + if self.is_closed(): + return - await self.http.close() - self._closed = True + await self.http.close() - for voice in self.voice_clients: - try: - await voice.disconnect(force=True) - except Exception: - # if an error happens during disconnects, disregard it. - pass + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # if an error happens during disconnects, disregard it. + pass - if self.ws is not None and self.ws.open: - await self.ws.close(code=1000) + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - self._ready.clear() + self._ready.clear() + self._closed.set() def clear(self) -> None: """Clears the internal state of the bot. @@ -789,7 +855,7 @@ def clear(self) -> None: and :meth:`is_ready` both return ``False`` along with the bot's internal cache cleared. """ - self._closed = False + self._closed.clear() self._ready.clear() self._connection.clear() self.http.recreate() @@ -804,10 +870,29 @@ async def start(self, token: str, *, reconnect: bool = True) -> None: TypeError An unexpected keyword argument was received. """ + if not self._in_context and self._loop is None: + # Update the loop to get the running one in case the one set is MISSING + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + + self._pending_loops.start(self) + self.http.loop = self.loop + self._connection.loop = self.loop + + self._ready = asyncio.Event() + self._closed = asyncio.Event() + await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args: Any, **kwargs: Any) -> None: + def run( + self, + token: str, + *, + reconnect: bool = True, + ) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -818,12 +903,20 @@ def run(self, *args: Any, **kwargs: Any) -> None: Roughly Equivalent to: :: try: - loop.run_until_complete(start(*args, **kwargs)) + asyncio.run(start(token)) except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + return + + Parameters + ---------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + reconnect: :class:`bool` + If we should attempt reconnecting to the gateway, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). .. warning:: @@ -831,47 +924,36 @@ def run(self, *args: Any, **kwargs: Any) -> None: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.add_signal_handler(signal.SIGTERM, loop.stop) - except (NotImplementedError, RuntimeError): - pass async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() + async with self: + await self.start(token=token, reconnect=reconnect) - def stop_loop_on_completion(f): - loop.stop() + try: + run = self.loop.run_until_complete + requires_cleanup = True + except RuntimeError: + run = asyncio.run + requires_cleanup = False - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() - except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") + run(runner()) finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) + # Ensure the bot is closed + if not self.is_closed(): + self.loop.run_until_complete(self.close()) - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # asyncio.run automatically does the cleanup tasks, so if we use + # it we don't need to clean up the tasks. + if requires_cleanup: + _log.info("Cleaning up tasks.") + _cleanup_loop(self.loop) # properties def is_closed(self) -> bool: """Indicates if the WebSocket connection is closed.""" - return self._closed + return self._closed.is_set() @property def activity(self) -> ActivityTypes | None: @@ -1331,7 +1413,7 @@ async def my_message(message): pass if not name.startswith("on_"): raise ValueError("The 'name' parameter must start with 'on_'") - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Listeners must be coroutines") if name in self._event_handlers: @@ -1411,7 +1493,7 @@ def decorator(func: Coro) -> Coro: self.add_listener(func, name) return func - if asyncio.iscoroutinefunction(name): + if inspect.iscoroutinefunction(name): coro = name name = coro.__name__ return decorator(coro) @@ -1446,7 +1528,7 @@ async def on_ready(): print('Ready!') """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("event registered must be a coroutine function") setattr(self, coro.__name__, coro) diff --git a/discord/commands/core.py b/discord/commands/core.py index 90e7c5aa60..e30cccfe25 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -503,7 +503,7 @@ def error(self, coro): The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The error handler must be a coroutine.") self.on_error = coro @@ -532,7 +532,7 @@ def before_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -557,7 +557,7 @@ def after_invoke(self, coro): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro @@ -734,7 +734,7 @@ def __new__(cls, *args, **kwargs) -> SlashCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func @@ -1125,7 +1125,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): else: result = option.autocomplete(ctx) - if asyncio.iscoroutinefunction(option.autocomplete): + if inspect.iscoroutinefunction(option.autocomplete): result = await result choices = [ @@ -1653,7 +1653,7 @@ def __new__(cls, *args, **kwargs) -> ContextMenuCommand: def __init__(self, func: Callable, *args, **kwargs) -> None: super().__init__(func, **kwargs) - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") self.callback = func diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 1a0d8a09a2..a86634fcaf 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -325,7 +325,7 @@ def __init__( ), **kwargs: Any, ): - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") name = kwargs.get("name") or func.__name__ @@ -993,7 +993,7 @@ def error(self, coro: ErrorT) -> ErrorT: The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The error handler must be a coroutine.") self.on_error: Error = coro @@ -1027,7 +1027,7 @@ def before_invoke(self, coro: HookT) -> HookT: TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro @@ -1054,7 +1054,7 @@ def after_invoke(self, coro: HookT) -> HookT: TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): + if not inspect.iscoroutinefunction(coro): raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index af34cc6844..b7d2dbb4aa 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -28,10 +28,11 @@ import asyncio import datetime import inspect +import logging import sys import traceback from collections.abc import Sequence -from typing import Any, Awaitable, Callable, Generic, TypeVar, cast +from typing import Any, Awaitable, Callable, Generic, TypeVar import aiohttp @@ -43,26 +44,61 @@ T = TypeVar("T") _func = Callable[..., Awaitable[Any]] +_log = logging.getLogger(__name__) LF = TypeVar("LF", bound=_func) FT = TypeVar("FT", bound=_func) ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) +def is_ambiguous(dt: datetime.datetime) -> bool: + if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone): + return False + + before = dt.replace(fold=0) + after = dt.replace(fold=1) + + same_offset = before.utcoffset() == after.utcoffset() + same_dst = before.dst() == after.dst() + return not (same_offset and same_dst) + + +def is_imaginary(dt: datetime.datetime) -> bool: + if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone): + return False + + tz = dt.tzinfo + dt = dt.replace(tzinfo=None) + roundtrip = ( + dt.replace(tzinfo=tz) + .astimezone(datetime.timezone.utc) + .astimezone(tz) + .replace(tzinfo=None) + ) + return dt != roundtrip + + class SleepHandle: __slots__ = ("future", "loop", "handle") def __init__( self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop ) -> None: - self.loop = loop - self.future = future = loop.create_future() + self.loop: asyncio.AbstractEventLoop = loop + self.future: asyncio.Future[None] = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = loop.call_later(relative_delta, future.set_result, True) + self.handle = loop.call_later(relative_delta, self._safe_result, self.future) + + @staticmethod + def _safe_result(future: asyncio.Future) -> None: + if not future.done(): + future.set_result(None) def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + self.handle = self.loop.call_later( + relative_delta, self._safe_result, self.future + ) def wait(self) -> asyncio.Future[Any]: return self.future @@ -90,11 +126,26 @@ def __init__( time: datetime.time | Sequence[datetime.time], count: int | None, reconnect: bool, - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop | None, + create_loop: bool, + name: str | None, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect - self.loop: asyncio.AbstractEventLoop = loop + + if create_loop is True and loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + self.loop: asyncio.AbstractEventLoop | None = loop + + self.name: str = ( + f"pycord-ext-task ({id(self):#x}): {coro.__qualname__}" + if name in (None, MISSING) + else name + ) self.count: int | None = count self._current_loop = 0 self._handle: SleepHandle = MISSING @@ -107,6 +158,7 @@ def __init__( aiohttp.ClientError, asyncio.TimeoutError, ) + self._create_loop = create_loop self._before_loop = None self._after_loop = None @@ -129,6 +181,9 @@ def __init__( f"Expected coroutine function, not {type(self.coro).__name__!r}." ) + if loop is None and not create_loop: + discord.Client._pending_loops.add_loop(self) + async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: coro = getattr(self, f"_{name}") if coro is None: @@ -146,45 +201,69 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non setattr(self, f"_{name}_running", False) def _try_sleep_until(self, dt: datetime.datetime): - self._handle = SleepHandle(dt=dt, loop=self.loop) + self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop()) return self._handle.wait() + def _rel_time(self) -> bool: + return self._time is MISSING + + def _expl_time(self) -> bool: + return self._time is not MISSING + async def _loop(self, *args: Any, **kwargs: Any) -> None: backoff = ExponentialBackoff() await self._call_loop_function("before_loop") self._last_iteration_failed = False - if self._time is not MISSING: - # the time index should be prepared every time the internal loop is started - self._prepare_time_index() + if self._expl_time(): self._next_iteration = self._get_next_sleep_time() else: self._next_iteration = datetime.datetime.now(datetime.timezone.utc) + try: - await self._try_sleep_until(self._next_iteration) + if self._stop_next_iteration: + return + while True: + if self._expl_time(): + await self._try_sleep_until(self._next_iteration) if not self._last_iteration_failed: self._last_iteration = self._next_iteration self._next_iteration = self._get_next_sleep_time() + + while ( + self._expl_time() + and self._next_iteration <= self._last_iteration + ): + _log.warning( + "Task %s woke up at %s, which was before expected (%s). Sleeping again to fix it...", + self.coro.__name__, + discord.utils.utcnow(), + self._next_iteration, + ) + await self._try_sleep_until(self._next_iteration) + self._next_iteration = self._get_next_sleep_time() try: await self.coro(*args, **kwargs) self._last_iteration_failed = False - backoff = ExponentialBackoff() - except self._valid_exception: + except self._valid_exception as exc: self._last_iteration_failed = True if not self.reconnect: raise - await asyncio.sleep(backoff.delay()) - else: - await self._try_sleep_until(self._next_iteration) + delay = backoff.delay() + _log.warning( + "Received an exception which was in the valid exception set. Task will run again in %s.2f seconds", + self.coro.__name__, + delay, + exc_info=exc, + ) + await asyncio.sleep(delay) + else: if self._stop_next_iteration: return - now = datetime.datetime.now(datetime.timezone.utc) - if now > self._next_iteration: - self._next_iteration = now - if self._time is not MISSING: - self._prepare_time_index(now) + if self._rel_time(): + await self._try_sleep_until(self._next_iteration) self._current_loop += 1 if self._current_loop == self.count: @@ -199,7 +278,8 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: raise exc finally: await self._call_loop_function("after_loop") - self._handle.cancel() + if self._handle: + self._handle.cancel() self._is_being_cancelled = False self._current_loop = 0 self._stop_next_iteration = False @@ -217,7 +297,9 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]: time=self._time, count=self.count, reconnect=self.reconnect, + name=self.name, loop=self.loop, + create_loop=self._create_loop, ) copy._injected = obj copy._before_loop = self._before_loop @@ -303,7 +385,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: return await self.coro(*args, **kwargs) - def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: + def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None] | None: r"""Starts the internal task in the event loop. Parameters @@ -324,16 +406,24 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: The task that has been created. """ + if self.loop is None: + _log.warning( + f"The task {self.name} has been set to be bound to a discord.Client instance, and will start running automatically " + "when the client starts. If you want this task to be executed without it being bound to a discord.Client, " + "set the create_loop parameter in the decorator to True, and don't forget to set the client.loop to the loop.loop" + ) + return None + if self._task is not MISSING and not self._task.done(): raise RuntimeError("Task is already launched and is not completed.") if self._injected is not None: args = (self._injected, *args) - if self.loop is MISSING: - self.loop = asyncio.get_event_loop() - - self._task = self.loop.create_task(self._loop(*args, **kwargs)) + self._task = asyncio.ensure_future( + self.loop.create_task(self._loop(*args, **kwargs), name=self.name), + loop=self.loop, + ) return self._task def stop(self) -> None: @@ -567,66 +657,53 @@ def error(self, coro: ET) -> ET: self._error = coro # type: ignore return coro - def _get_next_sleep_time(self) -> datetime.datetime: + def _get_next_sleep_time( + self, now: datetime.datetime = MISSING + ) -> datetime.datetime: if self._sleep is not MISSING: return self._last_iteration + datetime.timedelta(seconds=self._sleep) - if self._time_index >= len(self._time): - self._time_index = 0 - if self._current_loop == 0: - # if we're at the last index on the first iteration, we need to sleep until tomorrow - return datetime.datetime.combine( - datetime.datetime.now(self._time[0].tzinfo or datetime.timezone.utc) - + datetime.timedelta(days=1), - self._time[0], - ) + if now is MISSING: + now = datetime.datetime.now(datetime.timezone.utc) - next_time = self._time[self._time_index] - - if self._current_loop == 0: - self._time_index += 1 - if ( - next_time - > datetime.datetime.now( - next_time.tzinfo or datetime.timezone.utc - ).timetz() - ): - return datetime.datetime.combine( - datetime.datetime.now(next_time.tzinfo or datetime.timezone.utc), - next_time, - ) - else: - return datetime.datetime.combine( - datetime.datetime.now(next_time.tzinfo or datetime.timezone.utc) - + datetime.timedelta(days=1), - next_time, - ) + index = self._start_time_relative_to(now) - next_date = cast( - datetime.datetime, self._last_iteration.astimezone(next_time.tzinfo) - ) - if next_time < next_date.timetz(): - next_date += datetime.timedelta(days=1) + if index is None: + time = self._time[0] + tomorrow = now.astimezone(time.tzinfo) + datetime.timedelta(days=1) + date = tomorrow.date() + else: + time = self._time[index] + date = now.astimezone(time.tzinfo).date() - self._time_index += 1 - return datetime.datetime.combine(next_date, next_time) + dt = datetime.datetime.combine(date, time, tzinfo=time.tzinfo) + + if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone): + return dt + + if is_imaginary(dt): + tomorrow = dt + datetime.timedelta(days=1) + yesterday = dt - datetime.timedelta(days=1) + return dt + (tomorrow.utcoffset() - yesterday.utcoffset()) # type: ignore + elif is_ambiguous(dt): + return dt.replace(fold=1) + else: + return dt - def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: + def _start_time_relative_to(self, now: datetime.datetime) -> int | None: # now kwarg should be a datetime.datetime representing the time "now" # to calculate the next time index from # pre-condition: self._time is set - time_now = ( - now - if now is not MISSING - else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) - ) for idx, time in enumerate(self._time): - if time >= time_now.astimezone(time.tzinfo).timetz(): - self._time_index = idx - break + # Convert the current time to the target timezone + # e.g. 18:00 UTC -> 03:00 UTC+9 + # Then compare the time instances to see if they're the same + start = now.astimezone(time.tzinfo) + if time >= start.timetz(): + return idx else: - self._time_index = 0 + return None def _get_time_parameter( self, @@ -716,15 +793,9 @@ def change_interval( self._time = self._get_time_parameter(time) self._sleep = self._seconds = self._minutes = self._hours = MISSING - if self.is_running() and not ( - self._before_loop_running or self._after_loop_running - ): - if self._time is not MISSING: - # prepare the next time index starting from after the last iteration - self._prepare_time_index(now=self._last_iteration) - + if self.is_running() and self._last_iteration is not MISSING: self._next_iteration = self._get_next_sleep_time() - if not self._handle.done(): + if self._handle and not self._handle.done(): # the loop is sleeping, recalculate based on new interval self._handle.recalculate(self._next_iteration) @@ -737,7 +808,9 @@ def loop( time: datetime.time | Sequence[datetime.time] = MISSING, count: int | None = None, reconnect: bool = True, - loop: asyncio.AbstractEventLoop = MISSING, + loop: asyncio.AbstractEventLoop | None = None, + name: str | None = MISSING, + create_loop: bool = False, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -770,9 +843,17 @@ def loop( Whether to handle errors and restart the task using an exponential back-off algorithm similar to the one used in :meth:`discord.Client.connect`. - loop: :class:`asyncio.AbstractEventLoop` - The loop to use to register the task, if not given - defaults to :func:`asyncio.get_event_loop`. + loop: Optional[:class:`asyncio.AbstractEventLoop`] + The loop to use to register the task, defaults to ``None``. + name: Optional[:class:`str`] + The name to create the task with, defaults to ``None``. + + .. versionadded:: 2.7 + create_loop: :class:`bool` + Whether this task should create their own event loop to start running it + without a client bound to it. + + .. versionadded:: 2.7 Raises ------ @@ -792,7 +873,9 @@ def decorator(func: LF) -> Loop[LF]: count=count, time=time, reconnect=reconnect, + name=name, loop=loop, + create_loop=create_loop, ) return decorator diff --git a/discord/http.py b/discord/http.py index bfefed91d1..0f530af088 100644 --- a/discord/http.py +++ b/discord/http.py @@ -177,9 +177,7 @@ def __init__( loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + self.loop: asyncio.AbstractEventLoop = loop or MISSING self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() diff --git a/discord/state.py b/discord/state.py index dc982d43bb..3de8943ad4 100644 --- a/discord/state.py +++ b/discord/state.py @@ -95,6 +95,8 @@ CS = TypeVar("CS", bound="ConnectionState") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] +MISSING = utils.MISSING + class ChunkRequest: def __init__( @@ -167,16 +169,16 @@ def __init__( handlers: dict[str, Callable], hooks: dict[str, Callable], http: HTTPClient, - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop | None, **options: Any, ) -> None: - self.loop: asyncio.AbstractEventLoop = loop + self.loop: asyncio.AbstractEventLoop = loop or MISSING self.http: HTTPClient = http self.max_messages: int | None = options.get("max_messages", 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 - self.dispatch: Callable = dispatch + self._dispatch: Callable = dispatch self.handlers: dict[str, Callable] = handlers self.hooks: dict[str, Callable] = hooks self.shard_count: int | None = None @@ -261,6 +263,10 @@ def __init__( self.clear() + def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: + _log.debug("Dispatching event %s", event) + return self._dispatch(event, *args, **kwargs) + def clear(self, *, views: bool = True) -> None: self.user: ClientUser | None = None # Originally, this code used WeakValueDictionary to maintain references to the diff --git a/discord/utils.py b/discord/utils.py index c42a51cbd8..ba1e8f3840 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -31,6 +31,7 @@ import datetime import functools import importlib.resources +import inspect import itertools import json import logging @@ -1377,7 +1378,7 @@ def _filter(ctx: AutocompleteContext, item: Any) -> bool: gen = (val for val in _values if _filter(ctx, val)) - elif asyncio.iscoroutinefunction(filter): + elif inspect.iscoroutinefunction(filter): gen = (val for val in _values if await filter(ctx, val)) elif callable(filter):