11from dataclasses import dataclass
2+ from typing import NamedTuple , overload
23
34import jsonschema
45from common_library .exclude import as_dict_exclude_none
1617 RegisteredFunction ,
1718 RegisteredFunctionJob ,
1819 RegisteredFunctionJobCollection ,
20+ RegisteredProjectFunctionJobPatch ,
21+ RegisteredSolverFunctionJobPatch ,
1922 SolverFunctionJob ,
23+ SolverJobID ,
24+ TaskID ,
2025)
2126from models_library .functions_errors import (
2227 FunctionExecuteAccessDeniedError ,
2328 FunctionInputsValidationError ,
29+ FunctionJobCacheNotFoundError ,
2430 FunctionsExecuteApiAccessDeniedError ,
2531 UnsupportedFunctionClassError ,
2632 UnsupportedFunctionFunctionJobClassCombinationError ,
4349from .services_rpc .wb_api_server import WbApiRpcClient
4450
4551
52+ class RegisteredFunctionJobData (NamedTuple ):
53+ function_job_id : FunctionJobID
54+ job_inputs : JobInputs
55+
56+
4657def join_inputs (
4758 default_inputs : FunctionInputs | None ,
4859 function_inputs : FunctionInputs | None ,
@@ -162,12 +173,29 @@ async def inspect_function_job(
162173 job_status = new_job_status ,
163174 )
164175
165- async def run_function_pre_check (
176+ async def create_function_job_inputs (
166177 self ,
167178 * ,
168179 function : RegisteredFunction ,
169180 function_inputs : FunctionInputs ,
170181 ) -> JobInputs :
182+ joined_inputs = join_inputs (
183+ function .default_inputs ,
184+ function_inputs ,
185+ )
186+ return JobInputs (
187+ values = joined_inputs or {},
188+ )
189+
190+ async def get_cached_function_job (
191+ self ,
192+ * ,
193+ function : RegisteredFunction ,
194+ function_inputs : FunctionInputs ,
195+ job_inputs : JobInputs ,
196+ ) -> RegisteredFunctionJob :
197+ """raises FunctionJobCacheNotFoundError if no cached job is found"""
198+
171199 user_api_access_rights = (
172200 await self ._web_rpc_client .get_functions_user_api_access_rights (
173201 user_id = self .user_id , product_name = self .product_name
@@ -190,21 +218,137 @@ async def run_function_pre_check(
190218 function_id = function .uid ,
191219 )
192220
193- joined_inputs = join_inputs (
194- function .default_inputs ,
195- function_inputs ,
196- )
221+ if cached_function_jobs := await self ._web_rpc_client .find_cached_function_jobs (
222+ function_id = function .uid ,
223+ inputs = job_inputs .values ,
224+ user_id = self .user_id ,
225+ product_name = self .product_name ,
226+ ):
227+ for cached_function_job in cached_function_jobs :
228+ job_status = await self .inspect_function_job (
229+ function = function ,
230+ function_job = cached_function_job ,
231+ )
232+ if job_status .status == RunningState .SUCCESS :
233+ return cached_function_job
234+
235+ raise FunctionJobCacheNotFoundError ()
236+
237+ async def create_registered_function_job (
238+ self ,
239+ * ,
240+ function : RegisteredFunction ,
241+ function_inputs : FunctionInputs ,
242+ job_inputs : JobInputs ,
243+ ) -> FunctionJobID :
197244
198245 if function .input_schema is not None :
199246 is_valid , validation_str = await self .validate_function_inputs (
200247 function_id = function .uid ,
201- inputs = joined_inputs ,
248+ inputs = job_inputs . values ,
202249 )
203250 if not is_valid :
204251 raise FunctionInputsValidationError (error = validation_str )
205252
206- return JobInputs (
207- values = joined_inputs or {},
253+ if function .function_class == FunctionClass .PROJECT :
254+ job = await self ._web_rpc_client .register_function_job (
255+ function_job = ProjectFunctionJob (
256+ function_uid = function .uid ,
257+ title = f"Function job of function { function .uid } " ,
258+ description = function .description ,
259+ inputs = job_inputs .values ,
260+ outputs = None ,
261+ project_job_id = None ,
262+ job_creation_task_id = None ,
263+ ),
264+ user_id = self .user_id ,
265+ product_name = self .product_name ,
266+ )
267+
268+ elif function .function_class == FunctionClass .SOLVER :
269+ job = await self ._web_rpc_client .register_function_job (
270+ function_job = SolverFunctionJob (
271+ function_uid = function .uid ,
272+ title = f"Function job of function { function .uid } " ,
273+ description = function .description ,
274+ inputs = job_inputs .values ,
275+ outputs = None ,
276+ solver_job_id = None ,
277+ job_creation_task_id = None ,
278+ ),
279+ user_id = self .user_id ,
280+ product_name = self .product_name ,
281+ )
282+ else :
283+ raise UnsupportedFunctionClassError (
284+ function_class = function .function_class ,
285+ )
286+
287+ return job .uid
288+
289+ @overload
290+ async def patch_registered_function_job (
291+ self ,
292+ * ,
293+ user_id : UserID ,
294+ product_name : ProductName ,
295+ function_job_id : FunctionJobID ,
296+ function_class : FunctionClass ,
297+ job_creation_task_id : TaskID | None ,
298+ project_job_id : ProjectID | None ,
299+ ) -> RegisteredFunctionJob : ...
300+
301+ @overload
302+ async def patch_registered_function_job (
303+ self ,
304+ * ,
305+ user_id : UserID ,
306+ product_name : ProductName ,
307+ function_job_id : FunctionJobID ,
308+ function_class : FunctionClass ,
309+ job_creation_task_id : TaskID | None ,
310+ solver_job_id : SolverJobID | None ,
311+ ) -> RegisteredFunctionJob : ...
312+
313+ async def patch_registered_function_job (
314+ self ,
315+ * ,
316+ user_id : UserID ,
317+ product_name : ProductName ,
318+ function_job_id : FunctionJobID ,
319+ function_class : FunctionClass ,
320+ job_creation_task_id : TaskID | None ,
321+ project_job_id : ProjectID | None = None ,
322+ solver_job_id : SolverJobID | None = None ,
323+ ) -> RegisteredFunctionJob :
324+ # Only allow one of project_job_id or solver_job_id depending on function_class
325+ if function_class == FunctionClass .PROJECT :
326+ patch = RegisteredProjectFunctionJobPatch (
327+ title = None ,
328+ description = None ,
329+ inputs = None ,
330+ outputs = None ,
331+ job_creation_task_id = job_creation_task_id ,
332+ project_job_id = project_job_id ,
333+ )
334+ elif function_class == FunctionClass .SOLVER :
335+ patch = RegisteredSolverFunctionJobPatch (
336+ title = None ,
337+ description = None ,
338+ inputs = None ,
339+ outputs = None ,
340+ job_creation_task_id = job_creation_task_id ,
341+ solver_job_id = solver_job_id ,
342+ )
343+ else :
344+ raise UnsupportedFunctionClassError (
345+ function_class = function_class ,
346+ )
347+ return await self ._web_rpc_client .patch_registered_function_job (
348+ user_id = user_id ,
349+ product_name = product_name ,
350+ function_job_id = function_job_id ,
351+ registered_function_job_patch = patch ,
208352 )
209353
210354 async def run_function (
@@ -217,20 +361,7 @@ async def run_function(
217361 x_simcore_parent_project_uuid : NodeID | None ,
218362 x_simcore_parent_node_id : NodeID | None ,
219363 ) -> RegisteredFunctionJob :
220-
221- if cached_function_jobs := await self ._web_rpc_client .find_cached_function_jobs (
222- function_id = function .uid ,
223- inputs = job_inputs .values ,
224- user_id = self .user_id ,
225- product_name = self .product_name ,
226- ):
227- for cached_function_job in cached_function_jobs :
228- job_status = await self .inspect_function_job (
229- function = function ,
230- function_job = cached_function_job ,
231- )
232- if job_status .status == RunningState .SUCCESS :
233- return cached_function_job
364+ """N.B. this function does not check access rights. Use get_cached_function_job for that"""
234365
235366 if function .function_class == FunctionClass .PROJECT :
236367 study_job = await self ._job_service .create_studies_job (
@@ -306,7 +437,7 @@ async def map_function(
306437 ) -> RegisteredFunctionJobCollection :
307438
308439 job_inputs = [
309- await self .run_function_pre_check (
440+ await self .create_registered_function_job (
310441 function = function ,
311442 function_inputs = inputs ,
312443 )
@@ -335,3 +466,25 @@ async def map_function(
335466 user_id = self .user_id ,
336467 product_name = self .product_name ,
337468 )
469+ function_jobs = [
470+ await self .run_function (
471+ function = function ,
472+ job_inputs = inputs ,
473+ pricing_spec = pricing_spec ,
474+ job_links = job_links ,
475+ x_simcore_parent_project_uuid = x_simcore_parent_project_uuid ,
476+ x_simcore_parent_node_id = x_simcore_parent_node_id ,
477+ )
478+ for inputs in job_inputs
479+ ]
480+
481+ function_job_collection_description = f"Function job collection of map of function { function .uid } with { len (function_inputs_list )} inputs"
482+ return await self ._web_rpc_client .register_function_job_collection (
483+ function_job_collection = FunctionJobCollection (
484+ title = "Function job collection of function map" ,
485+ description = function_job_collection_description ,
486+ job_ids = [function_job .uid for function_job in function_jobs ],
487+ ),
488+ user_id = self .user_id ,
489+ product_name = self .product_name ,
490+ )
0 commit comments