11from dataclasses import dataclass
2- from typing import overload
2+ from typing import Final , overload
33
44import jsonschema
55from common_library .exclude import as_dict_exclude_none
3838from models_library .rpc_pagination import PageLimitInt
3939from models_library .users import UserID
4040from pydantic import ValidationError
41+ from servicelib .celery .models import TaskUUID
42+ from servicelib .celery .task_manager import TaskManager
4143from simcore_service_api_server ._service_functions import FunctionService
4244from simcore_service_api_server .services_rpc .storage import StorageService
4345from sqlalchemy .ext .asyncio import AsyncEngine
4446
4547from ._service_jobs import JobService
48+ from .api .routes .tasks import _get_task_filter
4649from .exceptions .function_errors import (
4750 FunctionJobCacheNotFoundError ,
48- FunctionJobProjectMissingError ,
4951)
5052from .models .api_resources import JobLinks
5153from .models .domain .functions import PreRegisteredFunctionJobData
5254from .models .schemas .jobs import JobInputs , JobPricingSpecification
5355from .services_http .webserver import AuthSession
5456from .services_rpc .wb_api_server import WbApiRpcClient
5557
58+ _JOB_CREATION_TASK_STATUS_PREFIX : Final [str ] = "JOB_CREATION_TASK_STATUS_"
59+ _JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS : Final [str ] = (
60+ f"{ _JOB_CREATION_TASK_STATUS_PREFIX } NOT_YET_SCHEDULED"
61+ )
62+
5663
5764def join_inputs (
5865 default_inputs : FunctionInputs | None ,
@@ -68,6 +75,23 @@ def join_inputs(
6875 return {** default_inputs , ** function_inputs }
6976
7077
78+ async def _celery_task_status (
79+ job_creation_task_id : TaskID | None ,
80+ task_manager : TaskManager ,
81+ user_id : UserID ,
82+ product_name : ProductName ,
83+ ) -> FunctionJobStatus :
84+ if job_creation_task_id is None :
85+ return FunctionJobStatus (status = _JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS )
86+ task_filter = _get_task_filter (user_id , product_name )
87+ task_status = await task_manager .get_task_status (
88+ task_uuid = TaskUUID (job_creation_task_id ), task_filter = task_filter
89+ )
90+ return FunctionJobStatus (
91+ status = f"{ _JOB_CREATION_TASK_STATUS_PREFIX } { task_status .task_state } "
92+ )
93+
94+
7195@dataclass (frozen = True , kw_only = True )
7296class FunctionJobService :
7397 user_id : UserID
@@ -159,7 +183,10 @@ async def validate_function_inputs(
159183 )
160184
161185 async def inspect_function_job (
162- self , function : RegisteredFunction , function_job : RegisteredFunctionJob
186+ self ,
187+ function : RegisteredFunction ,
188+ function_job : RegisteredFunctionJob ,
189+ task_manager : TaskManager ,
163190 ) -> FunctionJobStatus :
164191 """Raises FunctionJobProjectNotRegisteredError if no project is associated with job"""
165192 stored_job_status = await self ._web_rpc_client .get_function_job_status (
@@ -176,15 +203,25 @@ async def inspect_function_job(
176203 and function_job .function_class == FunctionClass .PROJECT
177204 ):
178205 if function_job .project_job_id is None :
179- raise FunctionJobProjectMissingError
206+ return await _celery_task_status (
207+ job_creation_task_id = function_job .job_creation_task_id ,
208+ task_manager = task_manager ,
209+ user_id = self .user_id ,
210+ product_name = self .product_name ,
211+ )
180212 job_status = await self ._job_service .inspect_study_job (
181213 job_id = function_job .project_job_id ,
182214 )
183215 elif (function .function_class == FunctionClass .SOLVER ) and (
184216 function_job .function_class == FunctionClass .SOLVER
185217 ):
186218 if function_job .solver_job_id is None :
187- raise FunctionJobProjectMissingError
219+ return await _celery_task_status (
220+ job_creation_task_id = function_job .job_creation_task_id ,
221+ task_manager = task_manager ,
222+ user_id = self .user_id ,
223+ product_name = self .product_name ,
224+ )
188225 job_status = await self ._job_service .inspect_solver_job (
189226 solver_key = function .solver_key ,
190227 version = function .solver_version ,
@@ -225,6 +262,7 @@ async def get_cached_function_job(
225262 * ,
226263 function : RegisteredFunction ,
227264 job_inputs : JobInputs ,
265+ task_manager : TaskManager ,
228266 ) -> RegisteredFunctionJob :
229267 """
230268 N.B. this function checks access rights
@@ -266,6 +304,7 @@ async def get_cached_function_job(
266304 job_status = await self .inspect_function_job (
267305 function = function ,
268306 function_job = cached_function_job ,
307+ task_manager = task_manager ,
269308 )
270309 if job_status .status == RunningState .SUCCESS :
271310 return cached_function_job
@@ -476,18 +515,15 @@ async def function_job_outputs(
476515 product_name : ProductName ,
477516 stored_job_outputs : FunctionOutputs | None ,
478517 async_pg_engine : AsyncEngine ,
518+ task_manager : TaskManager ,
479519 ) -> FunctionOutputs :
480520
481521 if stored_job_outputs is not None :
482522 return stored_job_outputs
483523
484- try :
485- job_status = await self .inspect_function_job (
486- function = function ,
487- function_job = function_job ,
488- )
489- except FunctionJobProjectMissingError :
490- return None
524+ job_status = await self .inspect_function_job (
525+ function = function , function_job = function_job , task_manager = task_manager
526+ )
491527
492528 if job_status .status != RunningState .SUCCESS :
493529 return None
0 commit comments