1414
1515from .models import (
1616 TaskContext ,
17- TaskID ,
1817 TaskInfoStore ,
1918 TaskMetadata ,
2019 TaskState ,
@@ -46,12 +45,11 @@ async def send_task(
4645 with log_context (
4746 _logger ,
4847 logging .DEBUG ,
49- msg = f"Submit { task_name = } : { task_context = } { task_params = } " ,
48+ msg = f"Submit { task_metadata . name = } : { task_context = } { task_params = } " ,
5049 ):
5150 task_uuid = uuid4 ()
52- task_metadata = task_metadata or TaskMetadata ()
5351 self ._celery_app .send_task (
54- task_name ,
52+ task_metadata . name ,
5553 task_id = build_task_id (task_context , task_uuid ),
5654 kwargs = task_params ,
5755 queue = task_metadata .queue .value ,
@@ -68,17 +66,18 @@ async def send_task(
6866 return task_uuid
6967
7068 @make_async ()
71- def _abort_task (self , task_id : TaskID ) -> None :
72- AbortableAsyncResult (task_id , app = self ._celery_app ).abort ()
69+ def _abort_task (self , task_context : TaskContext , task_uuid : TaskUUID ) -> None :
70+ AbortableAsyncResult (
71+ build_task_id (task_context , task_uuid ), app = self ._celery_app
72+ ).abort ()
7373
7474 async def abort_task (self , task_context : TaskContext , task_uuid : TaskUUID ) -> None :
7575 with log_context (
7676 _logger ,
7777 logging .DEBUG ,
7878 msg = f"Abort task: { task_context = } { task_uuid = } " ,
7979 ):
80- task_id = build_task_id (task_context , task_uuid )
81- await self ._abort_task (task_id )
80+ await self ._abort_task (task_context , task_uuid )
8281
8382 async def get_task_result (
8483 self , task_context : TaskContext , task_uuid : TaskUUID
@@ -92,16 +91,18 @@ async def get_task_result(
9291 async_result = self ._celery_app .AsyncResult (task_id )
9392 result = async_result .result
9493 if async_result .ready ():
95- task_metadata = await self ._task_store .get_metadata (task_id )
94+ task_metadata = await self ._task_store .get_metadata (
95+ task_context , task_uuid
96+ )
9697 if task_metadata is not None and task_metadata .ephemeral :
97- await self ._task_store .remove (task_id )
98+ await self ._task_store .remove (task_context , task_uuid )
9899 return result
99100
100101 async def _get_progress_report (
101- self , task_id : TaskID , state : TaskState
102+ self , task_context : TaskContext , task_uuid : TaskUUID , state : TaskState
102103 ) -> ProgressReport :
103104 if state in (TaskState .STARTED , TaskState .RETRY , TaskState .ABORTED ):
104- progress = await self ._task_store .get_progress (task_id )
105+ progress = await self ._task_store .get_progress (task_context , task_uuid )
105106 if progress is not None :
106107 return progress
107108 if state in (
0 commit comments