88
99from ..models import (
1010 TaskContext ,
11- TaskID ,
1211 TaskMetadata ,
1312 TaskUUID ,
1413 build_task_id ,
2524_logger = logging .getLogger (__name__ )
2625
2726
28- def _build_key (task_id : TaskID ) -> str :
29- return _CELERY_TASK_INFO_PREFIX + task_id
27+ def _build_key (task_context : TaskContext , task_uuid : TaskUUID | None = None ) -> str :
28+ if task_uuid is None :
29+ return _CELERY_TASK_INFO_PREFIX + build_task_id_prefix (task_context )
30+ return _CELERY_TASK_INFO_PREFIX + build_task_id (task_context , task_uuid )
3031
3132
3233class RedisTaskInfoStore :
@@ -40,7 +41,7 @@ async def create(
4041 task_metadata : TaskMetadata ,
4142 expiry : timedelta ,
4243 ) -> None :
43- task_key = _build_key (build_task_id ( task_context , task_uuid ) )
44+ task_key = _build_key (task_context , task_uuid )
4445 await self ._redis_client_sdk .redis .hset (
4546 name = task_key ,
4647 key = _CELERY_TASK_METADATA_KEY ,
@@ -52,28 +53,24 @@ async def create(
5253 )
5354
5455 async def exists (self , task_context : TaskContext , task_uuid : TaskUUID ) -> bool :
55- n = await self ._redis_client_sdk .redis .exists (_build_key (build_task_id ( task_context , task_uuid ) )) # type: ignore
56+ n = await self ._redis_client_sdk .redis .exists (_build_key (task_context , task_uuid )) # type: ignore
5657 assert isinstance (n , int ) # nosec
5758 return n > 0
5859
5960 async def get_metadata (
6061 self , task_context : TaskContext , task_uuid : TaskUUID
6162 ) -> TaskMetadata | None :
62- result = await self ._redis_client_sdk .redis .hget (_build_key (build_task_id ( task_context , task_uuid ) ), _CELERY_TASK_METADATA_KEY ) # type: ignore
63+ result = await self ._redis_client_sdk .redis .hget (_build_key (task_context , task_uuid ), _CELERY_TASK_METADATA_KEY ) # type: ignore
6364 return TaskMetadata .model_validate_json (result ) if result else None
6465
6566 async def get_progress (
6667 self , task_context : TaskContext , task_uuid : TaskUUID
6768 ) -> ProgressReport | None :
68- result = await self ._redis_client_sdk .redis .hget (_build_key (build_task_id ( task_context , task_uuid ) ), _CELERY_TASK_PROGRESS_KEY ) # type: ignore
69+ result = await self ._redis_client_sdk .redis .hget (_build_key (task_context , task_uuid ), _CELERY_TASK_PROGRESS_KEY ) # type: ignore
6970 return ProgressReport .model_validate_json (result ) if result else None
7071
7172 async def get_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
72- search_key = (
73- _CELERY_TASK_INFO_PREFIX
74- + build_task_id_prefix (task_context )
75- + _CELERY_TASK_ID_KEY_SEPARATOR
76- )
73+ search_key = _build_key (task_context ) + _CELERY_TASK_ID_KEY_SEPARATOR
7774 keys = set ()
7875 async for key in self ._redis_client_sdk .redis .scan_iter (
7976 match = search_key + "*" , count = _CELERY_TASK_SCAN_COUNT_PER_BATCH
@@ -87,13 +84,15 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
8784 keys .add (TaskUUID (_key .removeprefix (search_key )))
8885 return keys
8986
90- async def remove (self , task_id : TaskID ) -> None :
91- await self ._redis_client_sdk .redis .delete (_build_key (task_id ))
92- AsyncResult (task_id ).forget ()
87+ async def remove (self , task_context : TaskContext , task_uuid : TaskUUID ) -> None :
88+ await self ._redis_client_sdk .redis .delete (_build_key (task_context , task_uuid )) # type: ignore
89+ AsyncResult (build_task_id ( task_context , task_uuid ) ).forget ()
9390
94- async def set_progress (self , task_id : TaskID , report : ProgressReport ) -> None :
91+ async def set_progress (
92+ self , task_context : TaskContext , task_uuid : TaskUUID , report : ProgressReport
93+ ) -> None :
9594 await self ._redis_client_sdk .redis .hset (
96- name = _build_key (task_id ),
95+ name = _build_key (task_context , task_uuid ),
9796 key = _CELERY_TASK_PROGRESS_KEY ,
9897 value = report .model_dump_json (),
9998 ) # type: ignore
0 commit comments