55import logging
66import traceback
77import urllib .parse
8- from collections import deque
98from contextlib import suppress
109from typing import Any , ClassVar , Final , Protocol , TypeAlias
1110from uuid import uuid4
@@ -73,10 +72,10 @@ async def _await_task(task: asyncio.Task) -> None:
7372async def _get_tasks_to_remove (
7473 tracked_tasks : BaseStore ,
7574 stale_task_detect_timeout_s : PositiveFloat ,
76- ) -> list [tuple [TaskId , TaskContext | None ]]:
75+ ) -> list [tuple [TaskId , TaskContext ]]:
7776 utc_now = datetime .datetime .now (tz = datetime .UTC )
7877
79- tasks_to_remove : list [tuple [TaskId , TaskContext | None ]] = []
78+ tasks_to_remove : list [tuple [TaskId , TaskContext ]] = []
8079
8180 for tracked_task in await tracked_tasks .list_tasks_data ():
8281 if tracked_task .fire_and_forget :
@@ -142,14 +141,12 @@ async def setup(self) -> None:
142141 )
143142
144143 async def teardown (self ) -> None :
145- task_ids_to_remove : deque [TaskId ] = deque ()
146144
147145 for tracked_task in await self ._tasks_data .list_tasks_data ():
148- task_ids_to_remove .append (tracked_task .task_id )
149-
150- for task_id in task_ids_to_remove :
151146 # when closing we do not care about pending errors
152- await self .remove_task (task_id , None , reraise_errors = False )
147+ await self .remove_task (
148+ tracked_task .task_id , tracked_task .task_context , reraise_errors = False
149+ )
153150
154151 if self ._stale_tasks_monitor_task :
155152 with log_catch (_logger , reraise = False ):
@@ -248,7 +245,7 @@ async def _add_task(
248245 return task_data
249246
250247 async def _get_tracked_task (
251- self , task_id : TaskId , with_task_context : TaskContext | None
248+ self , task_id : TaskId , with_task_context : TaskContext
252249 ) -> TaskData :
253250 task_data = await self ._tasks_data .get_task_data (task_id )
254251
@@ -261,7 +258,7 @@ async def _get_tracked_task(
261258 return task_data
262259
263260 async def get_task_status (
264- self , task_id : TaskId , with_task_context : TaskContext | None
261+ self , task_id : TaskId , with_task_context : TaskContext
265262 ) -> TaskStatus :
266263 """
267264 returns: the status of the task, along with updates
@@ -285,7 +282,7 @@ async def get_task_status(
285282 )
286283
287284 async def get_task_result (
288- self , task_id : TaskId , with_task_context : TaskContext | None
285+ self , task_id : TaskId , with_task_context : TaskContext
289286 ) -> Any :
290287 """
291288 returns: the result of the task
@@ -306,7 +303,7 @@ async def get_task_result(
306303 raise TaskCancelledError (task_id = task_id ) from exc
307304
308305 async def cancel_task (
309- self , task_id : TaskId , with_task_context : TaskContext | None
306+ self , task_id : TaskId , with_task_context : TaskContext
310307 ) -> None :
311308 """
312309 cancels the task
@@ -354,7 +351,7 @@ async def _cancel_tracked_task(
354351 async def remove_task (
355352 self ,
356353 task_id : TaskId ,
357- with_task_context : TaskContext | None ,
354+ with_task_context : TaskContext ,
358355 * ,
359356 reraise_errors : bool = True ,
360357 ) -> None :
@@ -382,7 +379,7 @@ def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
382379 async def _update_progress (
383380 self ,
384381 task_id : TaskId ,
385- task_context : TaskContext | None ,
382+ task_context : TaskContext ,
386383 task_progress : TaskProgress ,
387384 ) -> None :
388385 tracked_data = await self ._get_tracked_task (task_id , task_context )
@@ -420,10 +417,11 @@ async def start_task(
420417 task_name = task_name , managed_task = queried_task
421418 )
422419
420+ context_to_use = task_context or {}
423421 task_progress = TaskProgress .create (task_id = task_id )
424422 # set update callback
425423 task_progress .set_update_callback (
426- functools .partial (self ._update_progress , task_id , task_context )
424+ functools .partial (self ._update_progress , task_id , context_to_use )
427425 )
428426
429427 # bind the task with progress 0 and 1
@@ -441,7 +439,7 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
441439 tracked_task = await self ._add_task (
442440 task = async_task ,
443441 task_progress = task_progress ,
444- task_context = task_context or {} ,
442+ task_context = context_to_use ,
445443 fire_and_forget = fire_and_forget ,
446444 task_id = task_id ,
447445 )
@@ -452,10 +450,10 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
452450__all__ : tuple [str , ...] = (
453451 "TaskAlreadyRunningError" ,
454452 "TaskCancelledError" ,
453+ "TaskData" ,
455454 "TaskId" ,
456455 "TaskProgress" ,
457456 "TaskProtocol" ,
458457 "TaskStatus" ,
459458 "TasksManager" ,
460- "TaskData" ,
461459)
0 commit comments