Skip to content

Commit daab987

Browse files
committed
fix tasks things
1 parent b620fcd commit daab987

File tree

2 files changed

+121
-80
lines changed

2 files changed

+121
-80
lines changed

discord/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ def __init__(
235235
):
236236
# self.ws is set in the connect method
237237
self.ws: DiscordWebSocket = None # type: ignore
238+
239+
if loop is None:
240+
try:
241+
loop = asyncio.get_running_loop()
242+
except RuntimeError:
243+
pass
244+
238245
self._loop: asyncio.AbstractEventLoop | None = loop
239246
self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = (
240247
{}

discord/ext/tasks/__init__.py

Lines changed: 114 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import asyncio
2929
import datetime
3030
import inspect
31+
import logging
3132
import sys
3233
import traceback
3334
from collections.abc import Sequence
@@ -43,26 +44,54 @@
4344

4445
T = TypeVar("T")
4546
_func = Callable[..., Awaitable[Any]]
47+
_log = logging.getLogger(__name__)
4648
LF = TypeVar("LF", bound=_func)
4749
FT = TypeVar("FT", bound=_func)
4850
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
4951

5052

53+
def is_ambiguous(dt: datetime.datetime) -> bool:
54+
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
55+
return False
56+
57+
before = dt.replace(fold=0)
58+
after = dt.replace(fold=1)
59+
60+
same_offset = before.utcoffset() == after.utcoffset()
61+
same_dst = before.dst() == after.dst()
62+
return not (same_offset and same_dst)
63+
64+
65+
def is_imaginary(dt: datetime.datetime) -> bool:
66+
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
67+
return False
68+
69+
tz = dt.tzinfo
70+
dt = dt.replace(tzinfo=None)
71+
roundtrip = dt.replace(tzinfo=tz).astimezone(datetime.timezone.utc).astimezone(tz).replace(tzinfo=None)
72+
return dt != roundtrip
73+
74+
5175
class SleepHandle:
5276
__slots__ = ("future", "loop", "handle")
5377

5478
def __init__(
5579
self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop
5680
) -> None:
57-
self.loop = loop
58-
self.future = future = loop.create_future()
81+
self.loop: asyncio.AbstractEventLoop = loop
82+
self.future: asyncio.Future[None] = loop.create_future()
5983
relative_delta = discord.utils.compute_timedelta(dt)
60-
self.handle = loop.call_later(relative_delta, future.set_result, True)
84+
self.handle = loop.call_later(relative_delta, self._safe_result, self.future)
85+
86+
@staticmethod
87+
def _safe_result(future: asyncio.Future) -> None:
88+
if not future.done():
89+
future.set_result(None)
6190

6291
def recalculate(self, dt: datetime.datetime) -> None:
6392
self.handle.cancel()
6493
relative_delta = discord.utils.compute_timedelta(dt)
65-
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
94+
self.handle = self.loop.call_later(relative_delta, self._safe_result, self.future)
6695

6796
def wait(self) -> asyncio.Future[Any]:
6897
return self.future
@@ -95,7 +124,15 @@ def __init__(
95124
) -> None:
96125
self.coro: LF = coro
97126
self.reconnect: bool = reconnect
98-
self.loop: asyncio.AbstractEventLoop | None = loop
127+
128+
if loop is None:
129+
try:
130+
loop = asyncio.get_running_loop()
131+
except RuntimeError:
132+
loop = asyncio.new_event_loop()
133+
134+
self.loop = loop
135+
99136
self.name: str = f'pycord-ext-task ({id(self):#x}): {coro.__qualname__}' if name in (None, MISSING) else name
100137
self.count: int | None = count
101138
self._current_loop = 0
@@ -147,53 +184,67 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
147184
if name.endswith("_loop"):
148185
setattr(self, f"_{name}_running", False)
149186

150-
def _create_task(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
151-
if self.loop is None:
152-
meth = asyncio.create_task
153-
else:
154-
meth = self.loop.create_task
155-
return meth(self._loop(*args, **kwargs), name=self.name)
156-
157187
def _try_sleep_until(self, dt: datetime.datetime):
158188
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
159189
return self._handle.wait()
160190

191+
def _rel_time(self) -> bool:
192+
return self._time is MISSING
193+
194+
def _expl_time(self) -> bool:
195+
return self._time is not MISSING
196+
161197
async def _loop(self, *args: Any, **kwargs: Any) -> None:
162198
backoff = ExponentialBackoff()
163199
await self._call_loop_function("before_loop")
164200
self._last_iteration_failed = False
165-
if self._time is not MISSING:
166-
# the time index should be prepared every time the internal loop is started
167-
self._prepare_time_index()
201+
if self._expl_time():
168202
self._next_iteration = self._get_next_sleep_time()
169203
else:
170204
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
205+
171206
try:
172-
await self._try_sleep_until(self._next_iteration)
207+
if self._stop_next_iteration:
208+
return
209+
173210
while True:
211+
if self._expl_time():
212+
await self._try_sleep_until(self._next_iteration)
174213
if not self._last_iteration_failed:
175214
self._last_iteration = self._next_iteration
176215
self._next_iteration = self._get_next_sleep_time()
216+
217+
while self._expl_time() and self._next_iteration <= self._last_iteration:
218+
_log.warning(
219+
'Task %s woke up at %s, which was before expected (%s). Sleeping again to fix it...',
220+
self.coro.__name__,
221+
discord.utils.utcnow(),
222+
self._next_iteration,
223+
)
224+
await self._try_sleep_until(self._next_iteration)
225+
self._next_iteration = self._get_next_sleep_time()
177226
try:
178227
await self.coro(*args, **kwargs)
179228
self._last_iteration_failed = False
180-
backoff = ExponentialBackoff()
181-
except self._valid_exception:
229+
except self._valid_exception as exc:
182230
self._last_iteration_failed = True
183231
if not self.reconnect:
184232
raise
185-
await asyncio.sleep(backoff.delay())
186-
else:
187-
await self._try_sleep_until(self._next_iteration)
188233

234+
delay = backoff.delay()
235+
_log.warning(
236+
'Received an exception which was in the valid exception set. Task will run again in %s.2f seconds',
237+
self.coro.__name__,
238+
delay,
239+
exc_info=exc,
240+
)
241+
await asyncio.sleep(delay)
242+
else:
189243
if self._stop_next_iteration:
190244
return
191245

192-
now = datetime.datetime.now(datetime.timezone.utc)
193-
if now > self._next_iteration:
194-
self._next_iteration = now
195-
if self._time is not MISSING:
196-
self._prepare_time_index(now)
246+
if self._rel_time():
247+
await self._try_sleep_until(self._next_iteration)
197248

198249
self._current_loop += 1
199250
if self._current_loop == self.count:
@@ -208,7 +259,8 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
208259
raise exc
209260
finally:
210261
await self._call_loop_function("after_loop")
211-
self._handle.cancel()
262+
if self._handle:
263+
self._handle.cancel()
212264
self._is_being_cancelled = False
213265
self._current_loop = 0
214266
self._stop_next_iteration = False
@@ -226,8 +278,8 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
226278
time=self._time,
227279
count=self.count,
228280
reconnect=self.reconnect,
229-
loop=self.loop,
230281
name=self.name,
282+
loop=self.loop,
231283
)
232284
copy._injected = obj
233285
copy._before_loop = self._before_loop
@@ -340,7 +392,7 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
340392
if self._injected is not None:
341393
args = (self._injected, *args)
342394

343-
self._task = self._create_task(*args, **kwargs)
395+
self._task = self.loop.create_task(self._loop(*args, **kwargs), name=self.name)
344396
return self._task
345397

346398
def stop(self) -> None:
@@ -574,66 +626,51 @@ def error(self, coro: ET) -> ET:
574626
self._error = coro # type: ignore
575627
return coro
576628

577-
def _get_next_sleep_time(self) -> datetime.datetime:
629+
def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
578630
if self._sleep is not MISSING:
579631
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
580632

581-
if self._time_index >= len(self._time):
582-
self._time_index = 0
583-
if self._current_loop == 0:
584-
# if we're at the last index on the first iteration, we need to sleep until tomorrow
585-
return datetime.datetime.combine(
586-
datetime.datetime.now(self._time[0].tzinfo or datetime.timezone.utc)
587-
+ datetime.timedelta(days=1),
588-
self._time[0],
589-
)
633+
if now is MISSING:
634+
now = datetime.datetime.now(datetime.timezone.utc)
590635

591-
next_time = self._time[self._time_index]
592-
593-
if self._current_loop == 0:
594-
self._time_index += 1
595-
if (
596-
next_time
597-
> datetime.datetime.now(
598-
next_time.tzinfo or datetime.timezone.utc
599-
).timetz()
600-
):
601-
return datetime.datetime.combine(
602-
datetime.datetime.now(next_time.tzinfo or datetime.timezone.utc),
603-
next_time,
604-
)
605-
else:
606-
return datetime.datetime.combine(
607-
datetime.datetime.now(next_time.tzinfo or datetime.timezone.utc)
608-
+ datetime.timedelta(days=1),
609-
next_time,
610-
)
636+
index = self._start_time_relative_to(now)
611637

612-
next_date = cast(
613-
datetime.datetime, self._last_iteration.astimezone(next_time.tzinfo)
614-
)
615-
if next_time < next_date.timetz():
616-
next_date += datetime.timedelta(days=1)
638+
if index is None:
639+
time = self._time[0]
640+
tomorrow = now.astimezone(time.tzinfo) + datetime.timedelta(days=1)
641+
date = tomorrow.date()
642+
else:
643+
time = self._time[index]
644+
date = now.astimezone(time.tzinfo).date()
645+
646+
dt = datetime.datetime.combine(date, time, tzinfo=time.tzinfo)
617647

618-
self._time_index += 1
619-
return datetime.datetime.combine(next_date, next_time)
648+
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
649+
return dt
650+
651+
if is_imaginary(dt):
652+
tomorrow = dt + datetime.timedelta(days=1)
653+
yesterday = dt - datetime.timedelta(days=1)
654+
return dt + (tomorrow.utcoffset() - yesterday.utcoffset()) # type: ignore
655+
elif is_ambiguous(dt):
656+
return dt.replace(fold=1)
657+
else:
658+
return dt
620659

621-
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
660+
def _start_time_relative_to(self, now: datetime.datetime) -> int | None:
622661
# now kwarg should be a datetime.datetime representing the time "now"
623662
# to calculate the next time index from
624663

625664
# pre-condition: self._time is set
626-
time_now = (
627-
now
628-
if now is not MISSING
629-
else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
630-
)
631665
for idx, time in enumerate(self._time):
632-
if time >= time_now.astimezone(time.tzinfo).timetz():
633-
self._time_index = idx
634-
break
666+
# Convert the current time to the target timezone
667+
# e.g. 18:00 UTC -> 03:00 UTC+9
668+
# Then compare the time instances to see if they're the same
669+
start = now.astimezone(time.tzinfo)
670+
if time >= start.timetz():
671+
return idx
635672
else:
636-
self._time_index = 0
673+
return None
637674

638675
def _get_time_parameter(
639676
self,
@@ -780,9 +817,6 @@ def loop(
780817
one used in :meth:`discord.Client.connect`.
781818
loop: Optional[:class:`asyncio.AbstractEventLoop`]
782819
The loop to use to register the task, defaults to ``None``.
783-
784-
.. versionchanged:: 2.7
785-
This can now be ``None``
786820
name: Optional[:class:`str`]
787821
The name to create the task with, defaults to ``None``.
788822
@@ -806,8 +840,8 @@ def decorator(func: LF) -> Loop[LF]:
806840
count=count,
807841
time=time,
808842
reconnect=reconnect,
809-
loop=loop,
810843
name=name,
844+
loop=loop,
811845
)
812846

813847
return decorator

0 commit comments

Comments
 (0)