@@ -92,6 +92,11 @@ async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> No
9292 task_id = build_task_id (task_context , task_uuid )
9393 return await self ._abort_task (task_id )
9494
95+ @make_async ()
96+ def _get_result (self , task_context : TaskContext , task_uuid : TaskUUID ) -> Any :
97+ task_id = build_task_id (task_context , task_uuid )
98+ return self ._celery_app .AsyncResult (task_id ).result
99+
95100 async def get_task_result (
96101 self , task_context : TaskContext , task_uuid : TaskUUID
97102 ) -> Any :
@@ -109,12 +114,8 @@ async def get_task_result(
109114 await self ._task_store .remove (task_id )
110115 return result
111116
112- def _get_progress_report (
113- self , task_context : TaskContext , task_uuid : TaskUUID
114- ) -> ProgressReport :
115- task_id = build_task_id (task_context , task_uuid )
116- result = self ._celery_app .AsyncResult (task_id ).result
117- state = self ._get_state (task_context , task_uuid )
117+ @staticmethod
118+ async def _get_progress_report (state , result ) -> ProgressReport :
118119 if result and state == TaskState .RUNNING :
119120 with contextlib .suppress (ValidationError ):
120121 # avoids exception if result is not a ProgressReport (or overwritten by a Celery's state update)
@@ -144,10 +145,12 @@ async def get_task_status(
144145 logging .DEBUG ,
145146 msg = f"Getting task status: { task_context = } { task_uuid = } " ,
146147 ):
148+ state = await self ._get_state (task_context , task_uuid )
149+ result = await self ._get_result (task_context , task_uuid )
147150 return TaskStatus (
148151 task_uuid = task_uuid ,
149- task_state = await self . _get_state ( task_context , task_uuid ) ,
150- progress_report = self ._get_progress_report (task_context , task_uuid ),
152+ task_state = state ,
153+ progress_report = await self ._get_progress_report (state , result ),
151154 )
152155
153156 async def get_task_uuids (self , task_context : TaskContext ) -> set [TaskUUID ]:
0 commit comments