@@ -80,17 +80,14 @@ class Timeout:
80
80
# The purpose is to time out as sson as possible
81
81
# without waiting for the next await expression.
82
82
83
- __slots__ = ("_deadline" , "_loop" , "_state" , "_task" , " _timeout_handler" )
83
+ __slots__ = ("_deadline" , "_loop" , "_state" , "_timeout_handler" )
84
84
85
85
def __init__ (
86
86
self , deadline : Optional [float ], loop : asyncio .AbstractEventLoop
87
87
) -> None :
88
88
self ._loop = loop
89
89
self ._state = _State .INIT
90
90
91
- task = _current_task (self ._loop )
92
- self ._task = task
93
-
94
91
self ._timeout_handler = None # type: Optional[asyncio.Handle]
95
92
if deadline is None :
96
93
self ._deadline = None # type: Optional[float]
@@ -180,22 +177,30 @@ def update(self, deadline: float) -> None:
180
177
if self ._timeout_handler is not None :
181
178
self ._timeout_handler .cancel ()
182
179
self ._deadline = deadline
180
+ if self ._state != _State .INIT :
181
+ self ._reschedule ()
182
+
183
+ def _reschedule (self ) -> None :
184
+ assert self ._state == _State .ENTER
185
+ deadline = self ._deadline
186
+ if deadline is None :
187
+ return
188
+
183
189
now = self ._loop .time ()
190
+ if self ._timeout_handler is not None :
191
+ self ._timeout_handler .cancel ()
192
+
193
+ task = _current_task (self ._loop )
184
194
if deadline <= now :
185
- self ._timeout_handler = None
186
- if self ._state == _State .INIT :
187
- raise asyncio .TimeoutError
188
- else :
189
- # state is ENTER
190
- raise asyncio .CancelledError
191
- self ._timeout_handler = self ._loop .call_at (
192
- deadline , self ._on_timeout , self ._task
193
- )
195
+ self ._timeout_handler = self ._loop .call_soon (self ._on_timeout , task )
196
+ else :
197
+ self ._timeout_handler = self ._loop .call_at (deadline , self ._on_timeout , task )
194
198
195
199
def _do_enter (self ) -> None :
196
200
if self ._state != _State .INIT :
197
201
raise RuntimeError (f"invalid state { self ._state .value } " )
198
202
self ._state = _State .ENTER
203
+ self ._reschedule ()
199
204
200
205
def _do_exit (self , exc_type : Type [BaseException ]) -> None :
201
206
if exc_type is asyncio .CancelledError and self ._state == _State .TIMEOUT :
@@ -209,6 +214,8 @@ def _do_exit(self, exc_type: Type[BaseException]) -> None:
209
214
def _on_timeout (self , task : "asyncio.Task[None]" ) -> None :
210
215
task .cancel ()
211
216
self ._state = _State .TIMEOUT
217
+ # drop the reference early
218
+ self ._timeout_handler = None
212
219
213
220
214
221
if sys .version_info >= (3 , 7 ):
0 commit comments