1010from common_library .async_tools import make_async
1111from models_library .progress_bar import ProgressReport
1212from servicelib .celery .models import (
13+ TASK_QUEUE_DEFAULT ,
1314 Task ,
1415 TaskContext ,
1516 TaskID ,
1617 TaskInfoStore ,
1718 TaskMetadata ,
19+ TaskName ,
20+ TaskQueue ,
1821 TaskState ,
1922 TaskStatus ,
2023 TaskUUID ,
@@ -39,63 +42,69 @@ class CeleryTaskManager:
3942
4043 async def send_task (
4144 self ,
42- task_metadata : TaskMetadata ,
45+ name : TaskName ,
46+ context : TaskContext ,
4347 * ,
44- task_context : TaskContext ,
45- ** task_params ,
48+ is_ephemeral : bool = False ,
49+ queue : TaskQueue = TASK_QUEUE_DEFAULT ,
50+ ** params ,
4651 ) -> TaskUUID :
4752 with log_context (
4853 _logger ,
4954 logging .DEBUG ,
50- msg = f"Send { task_metadata . name = } : { task_context = } { task_params = } " ,
55+ msg = f"Send { name = } : { context = } { params = } " ,
5156 ):
5257 task_uuid = uuid4 ()
53- task_id = build_task_id (task_context , task_uuid )
58+ task_id = build_task_id (context , task_uuid )
5459 self ._celery_app .send_task (
55- task_metadata . name ,
60+ name ,
5661 task_id = task_id ,
57- kwargs = {"task_id" : task_id } | task_params ,
58- queue = task_metadata . queue ,
62+ kwargs = {"task_id" : task_id } | params ,
63+ queue = queue ,
5964 )
6065
6166 expiry = (
6267 self ._celery_settings .CELERY_EPHEMERAL_RESULT_EXPIRES
63- if task_metadata . ephemeral
68+ if is_ephemeral
6469 else self ._celery_settings .CELERY_RESULT_EXPIRES
6570 )
6671 await self ._task_info_store .create_task (
67- task_id , task_metadata , expiry = expiry
72+ task_id ,
73+ TaskMetadata (
74+ name = name ,
75+ ephemeral = is_ephemeral ,
76+ queue = queue ,
77+ ),
78+ expiry = expiry ,
6879 )
6980 return task_uuid
7081
7182 @make_async ()
7283 def _abort_task (self , task_id : TaskID ) -> None :
7384 AbortableAsyncResult (task_id , app = self ._celery_app ).abort ()
7485
75- async def cancel_task (self , task_context : TaskContext , task_uuid : TaskUUID ) -> None :
86+ async def cancel_task (self , context : TaskContext , task_uuid : TaskUUID ) -> None :
7687 with log_context (
7788 _logger ,
7889 logging .DEBUG ,
79- msg = f"task cancellation: { task_context = } { task_uuid = } " ,
90+ msg = f"task cancellation: { context = } { task_uuid = } " ,
8091 ):
81- task_id = build_task_id (task_context , task_uuid )
82- if not (await self .get_task_status (task_context , task_uuid )).is_done :
92+ task_id = build_task_id (context , task_uuid )
93+ if not (await self .get_task_status (context , task_uuid )).is_done :
8394 await self ._abort_task (task_id )
8495 await self ._task_info_store .remove_task (task_id )
8596
8697 @make_async ()
8798 def _forget_task (self , task_id : TaskID ) -> None :
8899 AbortableAsyncResult (task_id , app = self ._celery_app ).forget ()
89100
90- async def get_task_result (
91- self , task_context : TaskContext , task_uuid : TaskUUID
92- ) -> Any :
101+ async def get_task_result (self , context : TaskContext , task_uuid : TaskUUID ) -> Any :
93102 with log_context (
94103 _logger ,
95104 logging .DEBUG ,
96- msg = f"Get task result: { task_context = } { task_uuid = } " ,
105+ msg = f"Get task result: { context = } { task_uuid = } " ,
97106 ):
98- task_id = build_task_id (task_context , task_uuid )
107+ task_id = build_task_id (context , task_uuid )
99108 async_result = self ._celery_app .AsyncResult (task_id )
100109 result = async_result .result
101110 if async_result .ready ():
@@ -105,15 +114,15 @@ async def get_task_result(
105114 await self ._task_info_store .remove_task (task_id )
106115 return result
107116
108- async def _get_task_progress_report (
109- self , task_context : TaskContext , task_uuid : TaskUUID , task_state : TaskState
117+ async def _get_progress_report (
118+ self , context : TaskContext , task_uuid : TaskUUID , state : TaskState
110119 ) -> ProgressReport :
111- if task_state in (TaskState .STARTED , TaskState .RETRY , TaskState .ABORTED ):
112- task_id = build_task_id (task_context , task_uuid )
120+ if state in (TaskState .STARTED , TaskState .RETRY , TaskState .ABORTED ):
121+ task_id = build_task_id (context , task_uuid )
113122 progress = await self ._task_info_store .get_task_progress (task_id )
114123 if progress is not None :
115124 return progress
116- if task_state in (
125+ if state in (
117126 TaskState .SUCCESS ,
118127 TaskState .FAILURE ,
119128 ):
@@ -131,30 +140,30 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState:
131140 return TaskState (self ._celery_app .AsyncResult (task_id ).state )
132141
133142 async def get_task_status (
134- self , task_context : TaskContext , task_uuid : TaskUUID
143+ self , context : TaskContext , task_uuid : TaskUUID
135144 ) -> TaskStatus :
136145 with log_context (
137146 _logger ,
138147 logging .DEBUG ,
139- msg = f"Getting task status: { task_context = } { task_uuid = } " ,
148+ msg = f"Getting task status: { context = } { task_uuid = } " ,
140149 ):
141- task_id = build_task_id (task_context , task_uuid )
150+ task_id = build_task_id (context , task_uuid )
142151 task_state = await self ._get_task_celery_state (task_id )
143152 return TaskStatus (
144153 task_uuid = task_uuid ,
145154 task_state = task_state ,
146- progress_report = await self ._get_task_progress_report (
147- task_context , task_uuid , task_state
155+ progress_report = await self ._get_progress_report (
156+ context , task_uuid , task_state
148157 ),
149158 )
150159
151- async def list_tasks (self , task_context : TaskContext ) -> list [Task ]:
160+ async def list_tasks (self , context : TaskContext ) -> list [Task ]:
152161 with log_context (
153162 _logger ,
154163 logging .DEBUG ,
155- msg = f"Listing tasks: { task_context = } " ,
164+ msg = f"Listing tasks: { context = } " ,
156165 ):
157- return await self ._task_info_store .list_tasks (task_context )
166+ return await self ._task_info_store .list_tasks (context )
158167
159168 async def set_task_progress (self , task_id : TaskID , report : ProgressReport ) -> None :
160169 await self ._task_info_store .set_task_progress (
0 commit comments