11import contextlib
2- import json
32import logging
43from typing import Any , Final
54from uuid import uuid4
1211from models_library .progress_bar import ProgressReport
1312from pydantic import ValidationError
1413from servicelib .logging_utils import log_context
15- from servicelib .redis ._client import RedisClientSDK
1614
17- from .models import TaskContext , TaskID , TaskState , TaskStatus , TaskUUID
15+ from .models import (
16+ TaskContext ,
17+ TaskData ,
18+ TaskState ,
19+ TaskStatus ,
20+ TaskStore ,
21+ TaskUUID ,
22+ build_task_id ,
23+ )
1824
1925_logger = logging .getLogger (__name__ )
2026
21- _CELERY_TASK_META_PREFIX : Final [str ] = "celery-task-meta-"
2227_CELERY_STATES_MAPPING : Final [dict [str , TaskState ]] = {
2328 "PENDING" : TaskState .PENDING ,
2429 "STARTED" : TaskState .PENDING ,
2934 "FAILURE" : TaskState .ERROR ,
3035 "ERROR" : TaskState .ERROR ,
3136}
32- _CELERY_TASK_ID_KEY_SEPARATOR : Final [str ] = ":"
33- _CELERY_TASK_SCAN_COUNT_PER_BATCH : Final [int ] = 10000
3437
3538_MIN_PROGRESS_VALUE = 0.0
3639_MAX_PROGRESS_VALUE = 100.0
3740
3841
39- def _build_context_prefix (task_context : TaskContext ) -> list [str ]:
40- return [f"{ task_context [key ]} " for key in sorted (task_context )]
41-
42-
43- def _build_task_id_prefix (task_context : TaskContext ) -> str :
44- return _CELERY_TASK_ID_KEY_SEPARATOR .join (_build_context_prefix (task_context ))
45-
46-
47- def _build_task_id (task_context : TaskContext , task_uuid : TaskUUID ) -> TaskID :
48- return _CELERY_TASK_ID_KEY_SEPARATOR .join (
49- [_build_task_id_prefix (task_context ), f"{ task_uuid } " ]
50- )
51-
52-
5342class CeleryTaskQueueClient :
54- def __init__ (self , celery_app : Celery , redis_client_sdk : RedisClientSDK ) -> None :
43+ def __init__ (self , celery_app : Celery , task_store : TaskStore ) -> None :
5544 self ._celery_app = celery_app
56- self ._redis_client_sdk = redis_client_sdk
45+ self ._task_store = task_store
5746
5847 async def send_task (
5948 self , task_name : str , * , task_context : TaskContext , ** task_params
6049 ) -> TaskUUID :
6150 task_uuid = uuid4 ()
62- task_id = _build_task_id (task_context , task_uuid )
51+ task_id = build_task_id (task_context , task_uuid )
6352 with log_context (
6453 _logger ,
6554 logging .DEBUG ,
6655 msg = f"Submitting task { task_name } : { task_id = } { task_params = } " ,
6756 ):
6857 self ._celery_app .send_task (task_name , task_id = task_id , kwargs = task_params )
69- await self ._redis_client_sdk .redis .set (
70- _CELERY_TASK_META_PREFIX + task_id ,
71- json .dumps (
72- {
73- "status" : "PENDING" ,
74- }
75- ),
58+ await self ._task_store .set_task (
59+ task_id , TaskData (status = TaskState .PENDING .name )
7660 )
7761 return task_uuid
7862
7963 @make_async ()
8064 def abort_task ( # pylint: disable=R6301
8165 self , task_context : TaskContext , task_uuid : TaskUUID
8266 ) -> None :
83- task_id = _build_task_id (task_context , task_uuid )
67+ task_id = build_task_id (task_context , task_uuid )
8468 _logger .info ("Aborting task %s" , task_id )
8569 AbortableAsyncResult (task_id ).abort ()
8670
8771 @make_async ()
8872 def get_task_result (self , task_context : TaskContext , task_uuid : TaskUUID ) -> Any :
89- task_id = _build_task_id (task_context , task_uuid )
73+ task_id = build_task_id (task_context , task_uuid )
9074 return self ._celery_app .AsyncResult (task_id ).result
9175
9276 def _get_progress_report (
9377 self , task_context : TaskContext , task_uuid : TaskUUID
9478 ) -> ProgressReport :
95- task_id = _build_task_id (task_context , task_uuid )
79+ task_id = build_task_id (task_context , task_uuid )
9680 result = self ._celery_app .AsyncResult (task_id ).result
9781 state = self ._get_state (task_context , task_uuid )
9882 if result and state == TaskState .RUNNING :
@@ -108,7 +92,7 @@ def _get_progress_report(
10892 return ProgressReport (actual_value = _MIN_PROGRESS_VALUE )
10993
11094 def _get_state (self , task_context : TaskContext , task_uuid : TaskUUID ) -> TaskState :
111- task_id = _build_task_id (task_context , task_uuid )
95+ task_id = build_task_id (task_context , task_uuid )
11296 return _CELERY_STATES_MAPPING [self ._celery_app .AsyncResult (task_id ).state ]
11397
11498 @make_async ()
@@ -122,14 +106,4 @@ def get_task_status(
122106 )
123107
124108 async def get_task_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
125- search_key = (
126- _CELERY_TASK_META_PREFIX
127- + _build_task_id_prefix (task_context )
128- + _CELERY_TASK_ID_KEY_SEPARATOR
129- )
130- keys = set ()
131- async for key in self ._redis_client_sdk .redis .scan_iter (
132- match = search_key + "*" , count = _CELERY_TASK_SCAN_COUNT_PER_BATCH
133- ):
134- keys .add (TaskUUID (f"{ key } " .removeprefix (search_key )))
135- return keys
109+ return await self ._task_store .get_task_uuids (task_context )
0 commit comments