@@ -122,7 +122,9 @@ def __init__(
122122 self .redis_settings = redis_settings
123123
124124 self ._stale_tasks_monitor_task : asyncio .Task | None = None
125+ self ._started_event_stale_tasks_monitor_task = asyncio .Event ()
125126 self ._cancelled_tasks_removal_task : asyncio .Task | None = None
127+ self ._started_event_cancelled_tasks_removal_task = asyncio .Event ()
126128 self .redis_client_sdk : RedisClientSDK | None = None
127129
128130 async def setup (self ) -> None :
@@ -142,11 +144,13 @@ async def setup(self) -> None:
142144 interval = self .stale_task_check_interval ,
143145 task_name = f"{ __name__ } .{ self ._stale_tasks_monitor_worker .__name__ } " ,
144146 )
147+ await self ._started_event_stale_tasks_monitor_task .wait ()
145148 self ._cancelled_tasks_removal_task = create_periodic_task (
146149 task = self ._cancelled_tasks_removal_worker ,
147150 interval = _CANCEL_TASKS_CHECK_INTERVAL ,
148151 task_name = f"{ __name__ } .{ self ._cancelled_tasks_removal_worker .__name__ } " ,
149152 )
153+ await self ._started_event_cancelled_tasks_removal_task .wait ()
150154
151155 async def teardown (self ) -> None :
152156 for tracked_task in await self ._tasks_data .list_tasks_data ():
@@ -157,15 +161,11 @@ async def teardown(self) -> None:
157161
158162 if self ._stale_tasks_monitor_task :
159163 with log_catch (_logger , reraise = False ):
160- await cancel_wait_task (
161- self ._stale_tasks_monitor_task , max_delay = _CANCEL_TASK_TIMEOUT
162- )
164+ await cancel_wait_task (self ._stale_tasks_monitor_task )
163165
164166 if self ._cancelled_tasks_removal_task :
165167 with log_catch (_logger , reraise = False ):
166- await cancel_wait_task (
167- self ._cancelled_tasks_removal_task , max_delay = _CANCEL_TASK_TIMEOUT
168- )
168+ await cancel_wait_task (self ._cancelled_tasks_removal_task )
169169
170170 if self .redis_client_sdk is not None :
171171 await self .redis_client_sdk .shutdown ()
@@ -189,6 +189,8 @@ async def _stale_tasks_monitor_worker(self) -> None:
189189 # Since we own the client, we assume (for now) this
190190 # will not be the case.
191191
192+ self ._started_event_stale_tasks_monitor_task .set ()
193+
192194 tasks_to_remove = await _get_tasks_to_remove (
193195 self ._tasks_data , self .stale_task_detect_timeout_s
194196 )
@@ -216,8 +218,10 @@ async def _cancelled_tasks_removal_worker(self) -> None:
216218 tasks can be cancelled by the client, but they can run in differente processes
217219 once there is an entry in the cancelled store, attempt to cancel the task
218220 """
221+ self ._started_event_cancelled_tasks_removal_task .set ()
219222
220- for task_id , task_context in (await self ._tasks_data .get_cancelled ()).items ():
223+ cancelled_tasks = await self ._tasks_data .get_cancelled ()
224+ for task_id , task_context in cancelled_tasks .items ():
221225 await self .remove_task (task_id , task_context )
222226
223227 async def list_tasks (self , with_task_context : TaskContext | None ) -> list [TaskBase ]:
0 commit comments