28
28
import asyncio
29
29
import datetime
30
30
import inspect
31
+ import logging
31
32
import sys
32
33
import traceback
33
34
from collections .abc import Sequence
43
44
44
45
T = TypeVar ("T" )
45
46
_func = Callable [..., Awaitable [Any ]]
47
+ _log = logging .getLogger (__name__ )
46
48
LF = TypeVar ("LF" , bound = _func )
47
49
FT = TypeVar ("FT" , bound = _func )
48
50
ET = TypeVar ("ET" , bound = Callable [[Any , BaseException ], Awaitable [Any ]])
49
51
50
52
53
+ def is_ambiguous (dt : datetime .datetime ) -> bool :
54
+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
55
+ return False
56
+
57
+ before = dt .replace (fold = 0 )
58
+ after = dt .replace (fold = 1 )
59
+
60
+ same_offset = before .utcoffset () == after .utcoffset ()
61
+ same_dst = before .dst () == after .dst ()
62
+ return not (same_offset and same_dst )
63
+
64
+
65
+ def is_imaginary (dt : datetime .datetime ) -> bool :
66
+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
67
+ return False
68
+
69
+ tz = dt .tzinfo
70
+ dt = dt .replace (tzinfo = None )
71
+ roundtrip = dt .replace (tzinfo = tz ).astimezone (datetime .timezone .utc ).astimezone (tz ).replace (tzinfo = None )
72
+ return dt != roundtrip
73
+
74
+
51
75
class SleepHandle :
52
76
__slots__ = ("future" , "loop" , "handle" )
53
77
54
78
def __init__ (
55
79
self , dt : datetime .datetime , * , loop : asyncio .AbstractEventLoop
56
80
) -> None :
57
- self .loop = loop
58
- self .future = future = loop .create_future ()
81
+ self .loop : asyncio . AbstractEventLoop = loop
82
+ self .future : asyncio . Future [ None ] = loop .create_future ()
59
83
relative_delta = discord .utils .compute_timedelta (dt )
60
- self .handle = loop .call_later (relative_delta , future .set_result , True )
84
+ self .handle = loop .call_later (relative_delta , self ._safe_result , self .future )
85
+
86
+ @staticmethod
87
+ def _safe_result (future : asyncio .Future ) -> None :
88
+ if not future .done ():
89
+ future .set_result (None )
61
90
62
91
def recalculate (self , dt : datetime .datetime ) -> None :
63
92
self .handle .cancel ()
64
93
relative_delta = discord .utils .compute_timedelta (dt )
65
- self .handle = self .loop .call_later (relative_delta , self .future . set_result , True )
94
+ self .handle = self .loop .call_later (relative_delta , self ._safe_result , self . future )
66
95
67
96
def wait (self ) -> asyncio .Future [Any ]:
68
97
return self .future
@@ -95,7 +124,15 @@ def __init__(
95
124
) -> None :
96
125
self .coro : LF = coro
97
126
self .reconnect : bool = reconnect
98
- self .loop : asyncio .AbstractEventLoop | None = loop
127
+
128
+ if loop is None :
129
+ try :
130
+ loop = asyncio .get_running_loop ()
131
+ except RuntimeError :
132
+ loop = asyncio .new_event_loop ()
133
+
134
+ self .loop = loop
135
+
99
136
self .name : str = f'pycord-ext-task ({ id (self ):#x} ): { coro .__qualname__ } ' if name in (None , MISSING ) else name
100
137
self .count : int | None = count
101
138
self ._current_loop = 0
@@ -147,53 +184,67 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
147
184
if name .endswith ("_loop" ):
148
185
setattr (self , f"_{ name } _running" , False )
149
186
150
- def _create_task (self , * args : Any , ** kwargs : Any ) -> asyncio .Task [None ]:
151
- if self .loop is None :
152
- meth = asyncio .create_task
153
- else :
154
- meth = self .loop .create_task
155
- return meth (self ._loop (* args , ** kwargs ), name = self .name )
156
-
157
187
def _try_sleep_until (self , dt : datetime .datetime ):
158
188
self ._handle = SleepHandle (dt = dt , loop = asyncio .get_running_loop ())
159
189
return self ._handle .wait ()
160
190
191
+ def _rel_time (self ) -> bool :
192
+ return self ._time is MISSING
193
+
194
+ def _expl_time (self ) -> bool :
195
+ return self ._time is not MISSING
196
+
161
197
async def _loop (self , * args : Any , ** kwargs : Any ) -> None :
162
198
backoff = ExponentialBackoff ()
163
199
await self ._call_loop_function ("before_loop" )
164
200
self ._last_iteration_failed = False
165
- if self ._time is not MISSING :
166
- # the time index should be prepared every time the internal loop is started
167
- self ._prepare_time_index ()
201
+ if self ._expl_time ():
168
202
self ._next_iteration = self ._get_next_sleep_time ()
169
203
else :
170
204
self ._next_iteration = datetime .datetime .now (datetime .timezone .utc )
205
+
171
206
try :
172
- await self ._try_sleep_until (self ._next_iteration )
207
+ if self ._stop_next_iteration :
208
+ return
209
+
173
210
while True :
211
+ if self ._expl_time ():
212
+ await self ._try_sleep_until (self ._next_iteration )
174
213
if not self ._last_iteration_failed :
175
214
self ._last_iteration = self ._next_iteration
176
215
self ._next_iteration = self ._get_next_sleep_time ()
216
+
217
+ while self ._expl_time () and self ._next_iteration <= self ._last_iteration :
218
+ _log .warning (
219
+ 'Task %s woke up at %s, which was before expected (%s). Sleeping again to fix it...' ,
220
+ self .coro .__name__ ,
221
+ discord .utils .utcnow (),
222
+ self ._next_iteration ,
223
+ )
224
+ await self ._try_sleep_until (self ._next_iteration )
225
+ self ._next_iteration = self ._get_next_sleep_time ()
177
226
try :
178
227
await self .coro (* args , ** kwargs )
179
228
self ._last_iteration_failed = False
180
- backoff = ExponentialBackoff ()
181
- except self ._valid_exception :
229
+ except self ._valid_exception as exc :
182
230
self ._last_iteration_failed = True
183
231
if not self .reconnect :
184
232
raise
185
- await asyncio .sleep (backoff .delay ())
186
- else :
187
- await self ._try_sleep_until (self ._next_iteration )
188
233
234
+ delay = backoff .delay ()
235
+ _log .warning (
236
+ 'Received an exception which was in the valid exception set. Task will run again in %s.2f seconds' ,
237
+ self .coro .__name__ ,
238
+ delay ,
239
+ exc_info = exc ,
240
+ )
241
+ await asyncio .sleep (delay )
242
+ else :
189
243
if self ._stop_next_iteration :
190
244
return
191
245
192
- now = datetime .datetime .now (datetime .timezone .utc )
193
- if now > self ._next_iteration :
194
- self ._next_iteration = now
195
- if self ._time is not MISSING :
196
- self ._prepare_time_index (now )
246
+ if self ._rel_time ():
247
+ await self ._try_sleep_until (self ._next_iteration )
197
248
198
249
self ._current_loop += 1
199
250
if self ._current_loop == self .count :
@@ -208,7 +259,8 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
208
259
raise exc
209
260
finally :
210
261
await self ._call_loop_function ("after_loop" )
211
- self ._handle .cancel ()
262
+ if self ._handle :
263
+ self ._handle .cancel ()
212
264
self ._is_being_cancelled = False
213
265
self ._current_loop = 0
214
266
self ._stop_next_iteration = False
@@ -226,8 +278,8 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
226
278
time = self ._time ,
227
279
count = self .count ,
228
280
reconnect = self .reconnect ,
229
- loop = self .loop ,
230
281
name = self .name ,
282
+ loop = self .loop ,
231
283
)
232
284
copy ._injected = obj
233
285
copy ._before_loop = self ._before_loop
@@ -340,7 +392,7 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
340
392
if self ._injected is not None :
341
393
args = (self ._injected , * args )
342
394
343
- self ._task = self ._create_task ( * args , ** kwargs )
395
+ self ._task = self .loop . create_task ( self . _loop ( * args , ** kwargs ), name = self . name )
344
396
return self ._task
345
397
346
398
def stop (self ) -> None :
@@ -574,66 +626,51 @@ def error(self, coro: ET) -> ET:
574
626
self ._error = coro # type: ignore
575
627
return coro
576
628
577
- def _get_next_sleep_time (self ) -> datetime .datetime :
629
+ def _get_next_sleep_time (self , now : datetime . datetime = MISSING ) -> datetime .datetime :
578
630
if self ._sleep is not MISSING :
579
631
return self ._last_iteration + datetime .timedelta (seconds = self ._sleep )
580
632
581
- if self ._time_index >= len (self ._time ):
582
- self ._time_index = 0
583
- if self ._current_loop == 0 :
584
- # if we're at the last index on the first iteration, we need to sleep until tomorrow
585
- return datetime .datetime .combine (
586
- datetime .datetime .now (self ._time [0 ].tzinfo or datetime .timezone .utc )
587
- + datetime .timedelta (days = 1 ),
588
- self ._time [0 ],
589
- )
633
+ if now is MISSING :
634
+ now = datetime .datetime .now (datetime .timezone .utc )
590
635
591
- next_time = self ._time [self ._time_index ]
592
-
593
- if self ._current_loop == 0 :
594
- self ._time_index += 1
595
- if (
596
- next_time
597
- > datetime .datetime .now (
598
- next_time .tzinfo or datetime .timezone .utc
599
- ).timetz ()
600
- ):
601
- return datetime .datetime .combine (
602
- datetime .datetime .now (next_time .tzinfo or datetime .timezone .utc ),
603
- next_time ,
604
- )
605
- else :
606
- return datetime .datetime .combine (
607
- datetime .datetime .now (next_time .tzinfo or datetime .timezone .utc )
608
- + datetime .timedelta (days = 1 ),
609
- next_time ,
610
- )
636
+ index = self ._start_time_relative_to (now )
611
637
612
- next_date = cast (
613
- datetime .datetime , self ._last_iteration .astimezone (next_time .tzinfo )
614
- )
615
- if next_time < next_date .timetz ():
616
- next_date += datetime .timedelta (days = 1 )
638
+ if index is None :
639
+ time = self ._time [0 ]
640
+ tomorrow = now .astimezone (time .tzinfo ) + datetime .timedelta (days = 1 )
641
+ date = tomorrow .date ()
642
+ else :
643
+ time = self ._time [index ]
644
+ date = now .astimezone (time .tzinfo ).date ()
645
+
646
+ dt = datetime .datetime .combine (date , time , tzinfo = time .tzinfo )
617
647
618
- self ._time_index += 1
619
- return datetime .datetime .combine (next_date , next_time )
648
+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
649
+ return dt
650
+
651
+ if is_imaginary (dt ):
652
+ tomorrow = dt + datetime .timedelta (days = 1 )
653
+ yesterday = dt - datetime .timedelta (days = 1 )
654
+ return dt + (tomorrow .utcoffset () - yesterday .utcoffset ()) # type: ignore
655
+ elif is_ambiguous (dt ):
656
+ return dt .replace (fold = 1 )
657
+ else :
658
+ return dt
620
659
621
- def _prepare_time_index (self , now : datetime .datetime = MISSING ) -> None :
660
+ def _start_time_relative_to (self , now : datetime .datetime ) -> int | None :
622
661
# now kwarg should be a datetime.datetime representing the time "now"
623
662
# to calculate the next time index from
624
663
625
664
# pre-condition: self._time is set
626
- time_now = (
627
- now
628
- if now is not MISSING
629
- else datetime .datetime .now (datetime .timezone .utc ).replace (microsecond = 0 )
630
- )
631
665
for idx , time in enumerate (self ._time ):
632
- if time >= time_now .astimezone (time .tzinfo ).timetz ():
633
- self ._time_index = idx
634
- break
666
+ # Convert the current time to the target timezone
667
+ # e.g. 18:00 UTC -> 03:00 UTC+9
668
+ # Then compare the time instances to see if they're the same
669
+ start = now .astimezone (time .tzinfo )
670
+ if time >= start .timetz ():
671
+ return idx
635
672
else :
636
- self . _time_index = 0
673
+ return None
637
674
638
675
def _get_time_parameter (
639
676
self ,
@@ -780,9 +817,6 @@ def loop(
780
817
one used in :meth:`discord.Client.connect`.
781
818
loop: Optional[:class:`asyncio.AbstractEventLoop`]
782
819
The loop to use to register the task, defaults to ``None``.
783
-
784
- .. versionchanged:: 2.7
785
- This can now be ``None``
786
820
name: Optional[:class:`str`]
787
821
The name to create the task with, defaults to ``None``.
788
822
@@ -806,8 +840,8 @@ def decorator(func: LF) -> Loop[LF]:
806
840
count = count ,
807
841
time = time ,
808
842
reconnect = reconnect ,
809
- loop = loop ,
810
843
name = name ,
844
+ loop = loop ,
811
845
)
812
846
813
847
return decorator
0 commit comments