1313from pydantic import ValidationError
1414from servicelib .logging_utils import log_context
1515
16- from ...exceptions .errors import ConfigurationError
17- from .models import TaskContext , TaskID , TaskState , TaskStatus , TaskUUID
16+ from .models import (
17+ TaskContext ,
18+ TaskData ,
19+ TaskState ,
20+ TaskStatus ,
21+ TaskStore ,
22+ TaskUUID ,
23+ build_task_id ,
24+ )
1825
1926_logger = logging .getLogger (__name__ )
2027
21- _CELERY_INSPECT_TASK_STATUSES : Final [tuple [str , ...]] = (
22- "active" ,
23- "scheduled" ,
24- "revoked" ,
25- )
26- _CELERY_TASK_META_PREFIX : Final [str ] = "celery-task-meta-"
2728_CELERY_STATES_MAPPING : Final [dict [str , TaskState ]] = {
2829 "PENDING" : TaskState .PENDING ,
2930 "STARTED" : TaskState .PENDING ,
3435 "FAILURE" : TaskState .ERROR ,
3536 "ERROR" : TaskState .ERROR ,
3637}
37- _CELERY_TASK_ID_KEY_SEPARATOR : Final [str ] = ":"
38- _CELERY_TASK_ID_KEY_ENCODING = "utf-8"
3938
4039_MIN_PROGRESS_VALUE = 0.0
4140_MAX_PROGRESS_VALUE = 1.0
4241
4342
44- def _build_context_prefix (task_context : TaskContext ) -> list [str ]:
45- return [f"{ task_context [key ]} " for key in sorted (task_context )]
46-
47-
48- def _build_task_id_prefix (task_context : TaskContext ) -> str :
49- return _CELERY_TASK_ID_KEY_SEPARATOR .join (_build_context_prefix (task_context ))
50-
51-
52- def _build_task_id (task_context : TaskContext , task_uuid : TaskUUID ) -> TaskID :
53- return _CELERY_TASK_ID_KEY_SEPARATOR .join (
54- [_build_task_id_prefix (task_context ), f"{ task_uuid } " ]
55- )
56-
57-
5843@dataclass
5944class CeleryTaskQueueClient :
6045 _celery_app : Celery
46+ _task_store : TaskStore
6147
62- @make_async ()
63- def send_task (
48+ async def send_task (
6449 self , task_name : str , * , task_context : TaskContext , ** task_params
6550 ) -> TaskUUID :
6651 with log_context (
@@ -69,8 +54,11 @@ def send_task(
6954 msg = f"Submit { task_name = } : { task_context = } { task_params = } " ,
7055 ):
7156 task_uuid = uuid4 ()
72- task_id = _build_task_id (task_context , task_uuid )
57+ task_id = build_task_id (task_context , task_uuid )
7358 self ._celery_app .send_task (task_name , task_id = task_id , kwargs = task_params )
59+ await self ._task_store .set_task (
60+ task_id , TaskData (status = TaskState .PENDING .name )
61+ )
7462 return task_uuid
7563
7664 @staticmethod
@@ -79,25 +67,25 @@ def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
7967 with log_context (
8068 _logger ,
8169 logging .DEBUG ,
82- msg = f"Abort task { task_uuid = } : { task_context = } " ,
70+ msg = f"Abort task: { task_context = } { task_uuid = } " ,
8371 ):
84- task_id = _build_task_id (task_context , task_uuid )
72+ task_id = build_task_id (task_context , task_uuid )
8573 AbortableAsyncResult (task_id ).abort ()
8674
8775 @make_async ()
8876 def get_task_result (self , task_context : TaskContext , task_uuid : TaskUUID ) -> Any :
8977 with log_context (
9078 _logger ,
9179 logging .DEBUG ,
92- msg = f"Get task { task_uuid = } : { task_context = } result " ,
80+ msg = f"Get task result : { task_context = } { task_uuid = } " ,
9381 ):
94- task_id = _build_task_id (task_context , task_uuid )
82+ task_id = build_task_id (task_context , task_uuid )
9583 return self ._celery_app .AsyncResult (task_id ).result
9684
9785 def _get_progress_report (
9886 self , task_context : TaskContext , task_uuid : TaskUUID
9987 ) -> ProgressReport :
100- task_id = _build_task_id (task_context , task_uuid )
88+ task_id = build_task_id (task_context , task_uuid )
10189 result = self ._celery_app .AsyncResult (task_id ).result
10290 state = self ._get_state (task_context , task_uuid )
10391 if result and state == TaskState .RUNNING :
@@ -117,64 +105,28 @@ def _get_progress_report(
117105 )
118106
119107 def _get_state (self , task_context : TaskContext , task_uuid : TaskUUID ) -> TaskState :
120- task_id = _build_task_id (task_context , task_uuid )
108+ task_id = build_task_id (task_context , task_uuid )
121109 return _CELERY_STATES_MAPPING [self ._celery_app .AsyncResult (task_id ).state ]
122110
123111 @make_async ()
124112 def get_task_status (
125113 self , task_context : TaskContext , task_uuid : TaskUUID
126114 ) -> TaskStatus :
127- return TaskStatus (
128- task_uuid = task_uuid ,
129- task_state = self ._get_state (task_context , task_uuid ),
130- progress_report = self ._get_progress_report (task_context , task_uuid ),
131- )
132-
133- def _get_completed_task_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
134- search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix (task_context )
135- backend_client = self ._celery_app .backend .client
136- if hasattr (backend_client , "keys" ):
137- if keys := backend_client .keys (f"{ search_key } *" ):
138- return {
139- TaskUUID (
140- f"{ key .decode (_CELERY_TASK_ID_KEY_ENCODING ).removeprefix (search_key + _CELERY_TASK_ID_KEY_SEPARATOR )} "
141- )
142- for key in keys
143- }
144- return set ()
145- if hasattr (backend_client , "cache" ):
146- # NOTE: backend used in testing. It is a dict-like object
147- found_keys = set ()
148- for key in backend_client .cache :
149- str_key = key .decode (_CELERY_TASK_ID_KEY_ENCODING )
150- if str_key .startswith (search_key ):
151- found_keys .add (
152- TaskUUID (
153- f"{ str_key .removeprefix (search_key + _CELERY_TASK_ID_KEY_SEPARATOR )} "
154- )
155- )
156- return found_keys
157- msg = f"Unsupported backend { self ._celery_app .backend .__class__ .__name__ } "
158- raise ConfigurationError (msg = msg )
159-
160- @make_async ()
161- def get_task_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
162- task_uuids = self ._get_completed_task_uuids (task_context )
163-
164- task_id_prefix = _build_task_id_prefix (task_context )
165- inspect = self ._celery_app .control .inspect ()
166- for task_inspect_status in _CELERY_INSPECT_TASK_STATUSES :
167- tasks = getattr (inspect , task_inspect_status )() or {}
168-
169- task_uuids .update (
170- TaskUUID (
171- task_info ["id" ].removeprefix (
172- task_id_prefix + _CELERY_TASK_ID_KEY_SEPARATOR
173- )
174- )
175- for tasks_per_worker in tasks .values ()
176- for task_info in tasks_per_worker
177- if "id" in task_info
115+ with log_context (
116+ _logger ,
117+ logging .DEBUG ,
118+ msg = f"Getting task status: { task_context = } { task_uuid = } " ,
119+ ):
120+ return TaskStatus (
121+ task_uuid = task_uuid ,
122+ task_state = self ._get_state (task_context , task_uuid ),
123+ progress_report = self ._get_progress_report (task_context , task_uuid ),
178124 )
179125
180- return task_uuids
126+ async def get_task_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
127+ with log_context (
128+ _logger ,
129+ logging .DEBUG ,
130+ msg = f"Getting task uuids: { task_context = } " ,
131+ ):
132+ return await self ._task_store .get_task_uuids (task_context )
0 commit comments