1515from servicelib .background_task import create_periodic_task
1616from servicelib .logging_utils import log_catch
1717
18+ from ._store .base import BaseStore
19+ from ._store .in_memory import InMemoryStore
1820from .errors import (
1921 TaskAlreadyRunningError ,
2022 TaskCancelledError ,
3840
3941RegisteredTaskName : TypeAlias = str
4042Namespace : TypeAlias = str
41- TrackedTaskGroupDict : TypeAlias = dict [TaskId , TrackedTask ]
4243TaskContext : TypeAlias = dict [str , Any ]
4344
4445
@@ -68,28 +69,28 @@ async def _await_task(task: asyncio.Task) -> None:
6869 await task
6970
7071
71- def _get_tasks_to_remove (
72- tracked_tasks : TrackedTaskGroupDict ,
72+ async def _get_tasks_to_remove (
73+ tracked_tasks : BaseStore ,
7374 stale_task_detect_timeout_s : PositiveFloat ,
7475) -> list [TaskId ]:
7576 utc_now = datetime .datetime .now (tz = datetime .UTC )
7677
7778 tasks_to_remove : list [TaskId ] = []
7879
79- for task_id , tracked_task in tracked_tasks .items ():
80+ for tracked_task in await tracked_tasks .list ():
8081 if tracked_task .fire_and_forget :
8182 continue
8283
8384 if tracked_task .last_status_check is None :
8485 # the task just added or never received a poll request
8586 elapsed_from_start = (utc_now - tracked_task .started ).seconds
8687 if elapsed_from_start > stale_task_detect_timeout_s :
87- tasks_to_remove .append (task_id )
88+ tasks_to_remove .append (tracked_task . task_id )
8889 else :
8990 # the task status was already queried by the client
9091 elapsed_from_last_poll = (utc_now - tracked_task .last_status_check ).seconds
9192 if elapsed_from_last_poll > stale_task_detect_timeout_s :
92- tasks_to_remove .append (task_id )
93+ tasks_to_remove .append (tracked_task . task_id )
9394 return tasks_to_remove
9495
9596
@@ -103,10 +104,11 @@ def __init__(
103104 stale_task_check_interval : datetime .timedelta ,
104105 stale_task_detect_timeout : datetime .timedelta ,
105106 namespace : Namespace = _DEFAULT_NAMESPACE ,
107+ # TODO: inject a Redis connection
106108 ):
107109 self .namespace = namespace
108110 # Task groups: Every taskname maps to multiple asyncio.Task within TrackedTask model
109- self ._tracked_tasks : TrackedTaskGroupDict = {}
111+ self ._tracked_tasks : BaseStore = InMemoryStore ()
110112
111113 self .stale_task_check_interval = stale_task_check_interval
112114 self .stale_task_detect_timeout_s : PositiveFloat = (
@@ -125,7 +127,7 @@ async def setup(self) -> None:
125127 async def teardown (self ) -> None :
126128 task_ids_to_remove : deque [TaskId ] = deque ()
127129
128- for tracked_task in self ._tracked_tasks .values ():
130+ for tracked_task in await self ._tracked_tasks .list ():
129131 task_ids_to_remove .append (tracked_task .task_id )
130132
131133 for task_id in task_ids_to_remove :
@@ -155,7 +157,7 @@ async def _stale_tasks_monitor_worker(self) -> None:
155157 # Since we own the client, we assume (for now) this
156158 # will not be the case.
157159
158- tasks_to_remove = _get_tasks_to_remove (
160+ tasks_to_remove = await _get_tasks_to_remove (
159161 self ._tracked_tasks , self .stale_task_detect_timeout_s
160162 )
161163
@@ -169,25 +171,28 @@ async def _stale_tasks_monitor_worker(self) -> None:
169171 _logger .warning (
170172 "Removing stale task '%s' with status '%s'" ,
171173 task_id ,
172- self .get_task_status (task_id , with_task_context = None ).model_dump_json (),
174+ (
175+ await self .get_task_status (task_id , with_task_context = None )
176+ ).model_dump_json (),
173177 )
174178 await self .remove_task (
175179 task_id , with_task_context = None , reraise_errors = False
176180 )
177181
178- def list_tasks (self , with_task_context : TaskContext | None ) -> list [TaskBase ]:
182+ async def list_tasks (self , with_task_context : TaskContext | None ) -> list [TaskBase ]:
179183 if not with_task_context :
180184 return [
181- TaskBase (task_id = task .task_id ) for task in self ._tracked_tasks .values ()
185+ TaskBase (task_id = task .task_id )
186+ for task in (await self ._tracked_tasks .list ())
182187 ]
183188
184189 return [
185190 TaskBase (task_id = task .task_id )
186- for task in self ._tracked_tasks .values ( )
191+ for task in ( await self ._tracked_tasks .list () )
187192 if task .task_context == with_task_context
188193 ]
189194
190- def _add_task (
195+ async def _add_task (
191196 self ,
192197 task : asyncio .Task ,
193198 task_progress : TaskProgress ,
@@ -204,24 +209,24 @@ def _add_task(
204209 task_context = task_context ,
205210 fire_and_forget = fire_and_forget ,
206211 )
207- self ._tracked_tasks [ task_id ] = tracked_task
212+ await self ._tracked_tasks . set ( task_id , tracked_task )
208213
209214 return tracked_task
210215
211- def _get_tracked_task (
216+ async def _get_tracked_task (
212217 self , task_id : TaskId , with_task_context : TaskContext | None
213218 ) -> TrackedTask :
214- if task_id not in self ._tracked_tasks :
215- raise TaskNotFoundError (task_id = task_id )
219+ task = await self ._tracked_tasks .get (task_id )
216220
217- task = self ._tracked_tasks [task_id ]
221+ if task is None :
222+ raise TaskNotFoundError (task_id = task_id )
218223
219224 if with_task_context and task .task_context != with_task_context :
220225 raise TaskNotFoundError (task_id = task_id )
221226
222227 return task
223228
224- def get_task_status (
229+ async def get_task_status (
225230 self , task_id : TaskId , with_task_context : TaskContext | None
226231 ) -> TaskStatus :
227232 """
@@ -230,7 +235,9 @@ def get_task_status(
230235
231236 raises TaskNotFoundError if the task cannot be found
232237 """
233- tracked_task : TrackedTask = self ._get_tracked_task (task_id , with_task_context )
238+ tracked_task : TrackedTask = await self ._get_tracked_task (
239+ task_id , with_task_context
240+ )
234241 tracked_task .last_status_check = datetime .datetime .now (tz = datetime .UTC )
235242
236243 task = tracked_task .task
@@ -244,7 +251,7 @@ def get_task_status(
244251 }
245252 )
246253
247- def get_task_result (
254+ async def get_task_result (
248255 self , task_id : TaskId , with_task_context : TaskContext | None
249256 ) -> Any :
250257 """
@@ -254,7 +261,7 @@ def get_task_result(
254261 raises TaskCancelledError if the task was cancelled
255262 raises TaskNotCompletedError if the task is not completed
256263 """
257- tracked_task = self ._get_tracked_task (task_id , with_task_context )
264+ tracked_task = await self ._get_tracked_task (task_id , with_task_context )
258265
259266 try :
260267 return tracked_task .task .result ()
@@ -273,7 +280,7 @@ async def cancel_task(
273280
274281 raises TaskNotFoundError if the task cannot be found
275282 """
276- tracked_task = self ._get_tracked_task (task_id , with_task_context )
283+ tracked_task = await self ._get_tracked_task (task_id , with_task_context )
277284 await self ._cancel_tracked_task (tracked_task .task , task_id , reraise_errors = True )
278285
279286 @staticmethod
@@ -317,7 +324,7 @@ async def remove_task(
317324 ) -> None :
318325 """cancels and removes task"""
319326 try :
320- tracked_task = self ._get_tracked_task (task_id , with_task_context )
327+ tracked_task = await self ._get_tracked_task (task_id , with_task_context )
321328 except TaskNotFoundError :
322329 if reraise_errors :
323330 raise
@@ -327,13 +334,13 @@ async def remove_task(
327334 tracked_task .task , task_id , reraise_errors = reraise_errors
328335 )
329336 finally :
330- del self ._tracked_tasks [ task_id ]
337+ await self ._tracked_tasks . delete ( task_id )
331338
332339 def _get_task_id (self , task_name : str , * , is_unique : bool ) -> TaskId :
333340 unique_part = "unique" if is_unique else f"{ uuid4 ()} "
334341 return f"{ self .namespace } .{ task_name } .{ unique_part } "
335342
336- def start_task (
343+ async def start_task (
337344 self ,
338345 registered_task_name : RegisteredTaskName ,
339346 * ,
@@ -358,9 +365,10 @@ def start_task(
358365 task_id = self ._get_task_id (task_name , is_unique = unique )
359366
360367 # only one unique task can be running
361- if unique and task_id in self ._tracked_tasks :
368+ queried_task = await self ._tracked_tasks .get (task_id )
369+ if unique and queried_task is not None :
362370 raise TaskAlreadyRunningError (
363- task_name = task_name , managed_task = self . _tracked_tasks [ task_id ]
371+ task_name = task_name , managed_task = queried_task
364372 )
365373
366374 task_progress = TaskProgress .create (task_id = task_id )
@@ -377,7 +385,7 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
377385 _progress_task (task_progress , task ), name = task_name
378386 )
379387
380- tracked_task = self ._add_task (
388+ tracked_task = await self ._add_task (
381389 task = async_task ,
382390 task_progress = task_progress ,
383391 task_context = task_context or {},
0 commit comments