11from dataclasses import dataclass
2- from typing import overload
2+ from typing import Annotated
33
44import jsonschema
55from common_library .exclude import as_dict_exclude_none
1414 ProjectFunctionJob ,
1515 RegisteredFunction ,
1616 RegisteredFunctionJob ,
17- RegisteredFunctionJobPatch ,
1817 RegisteredProjectFunctionJobPatch ,
1918 RegisteredSolverFunctionJobPatch ,
2019 SolverFunctionJob ,
21- SolverJobID ,
22- TaskID ,
2320)
2421from models_library .functions_errors import (
2522 FunctionInputsValidationError ,
2623 UnsupportedFunctionClassError ,
2724)
2825from models_library .products import ProductName
29- from models_library .projects import ProjectID
3026from models_library .projects_nodes_io import NodeID
3127from models_library .rest_pagination import PageMetaInfoLimitOffset , PageOffsetInt
3228from models_library .rpc_pagination import PageLimitInt
3329from models_library .users import UserID
34- from pydantic import TypeAdapter , ValidationError
30+ from pydantic import Field , TypeAdapter , ValidationError , validate_call
3531from simcore_service_api_server ._service_functions import FunctionService
3632from simcore_service_api_server .services_rpc .storage import StorageService
3733
3834from ._service_jobs import JobService
3935from .models .api_resources import JobLinks
40- from .models .domain .functions import PreRegisteredFunctionJobData
36+ from .models .domain .functions import (
37+ PreRegisteredFunctionJobData ,
38+ ProjectFunctionJobPatch ,
39+ SolverFunctionJobPatch ,
40+ )
4141from .models .schemas .jobs import JobInputs , JobPricingSpecification
4242from .services_http .webserver import AuthSession
4343from .services_rpc .wb_api_server import WbApiRpcClient
@@ -202,81 +202,51 @@ async def pre_register_function_job(
202202 for job , input_ in zip (jobs , job_inputs )
203203 ]
204204
205- @overload
206- async def patch_registered_function_job (
207- self ,
208- * ,
209- user_id : UserID ,
210- product_name : ProductName ,
211- function_job_id : FunctionJobID ,
212- function_class : FunctionClass ,
213- job_creation_task_id : TaskID | None ,
214- ) -> RegisteredFunctionJob : ...
215-
216- @overload
217- async def patch_registered_function_job (
218- self ,
219- * ,
220- user_id : UserID ,
221- product_name : ProductName ,
222- function_job_id : FunctionJobID ,
223- function_class : FunctionClass ,
224- job_creation_task_id : TaskID | None ,
225- project_job_id : ProjectID | None ,
226- ) -> RegisteredFunctionJob : ...
227-
228- @overload
229- async def patch_registered_function_job (
230- self ,
231- * ,
232- user_id : UserID ,
233- product_name : ProductName ,
234- function_job_id : FunctionJobID ,
235- function_class : FunctionClass ,
236- job_creation_task_id : TaskID | None ,
237- solver_job_id : SolverJobID | None ,
238- ) -> RegisteredFunctionJob : ...
239-
205+ @validate_call
240206 async def patch_registered_function_job (
241207 self ,
242208 * ,
243209 user_id : UserID ,
244210 product_name : ProductName ,
245- function_job_id : FunctionJobID ,
246- function_class : FunctionClass ,
247- job_creation_task_id : TaskID | None ,
248- project_job_id : ProjectID | None = None ,
249- solver_job_id : SolverJobID | None = None ,
250- ) -> RegisteredFunctionJob :
251- # Only allow one of project_job_id or solver_job_id depending on function_class
252- patch : RegisteredFunctionJobPatch
253- if function_class == FunctionClass .PROJECT :
254- patch = RegisteredProjectFunctionJobPatch (
255- title = None ,
256- description = None ,
257- inputs = None ,
258- outputs = None ,
259- job_creation_task_id = job_creation_task_id ,
260- project_job_id = project_job_id ,
261- )
262- elif function_class == FunctionClass .SOLVER :
263- patch = RegisteredSolverFunctionJobPatch (
264- title = None ,
265- description = None ,
266- inputs = None ,
267- outputs = None ,
268- job_creation_task_id = job_creation_task_id ,
269- solver_job_id = solver_job_id ,
270- )
271- else :
272- raise UnsupportedFunctionClassError (
273- function_class = function_class ,
274- )
211+ patches : Annotated [
212+ list [ProjectFunctionJobPatch ] | list [SolverFunctionJobPatch ],
213+ Field (max_length = 50 , min_length = 1 ),
214+ ],
215+ ) -> list [RegisteredFunctionJob ]:
216+ patch_inputs = []
217+ for patch in patches :
218+ if patch .function_class == FunctionClass .PROJECT :
219+ assert isinstance (patch , ProjectFunctionJobPatch ) # nosec
220+ patch_inputs .append (
221+ RegisteredProjectFunctionJobPatch (
222+ title = None ,
223+ description = None ,
224+ inputs = None ,
225+ outputs = None ,
226+ job_creation_task_id = patch .job_creation_task_id ,
227+ project_job_id = patch .project_job_id ,
228+ )
229+ )
230+ elif patch .function_class == FunctionClass .SOLVER :
231+ assert isinstance (patch , SolverFunctionJobPatch ) # nosec
232+ patch_inputs .append (
233+ RegisteredSolverFunctionJobPatch (
234+ title = None ,
235+ description = None ,
236+ inputs = None ,
237+ outputs = None ,
238+ job_creation_task_id = patch .job_creation_task_id ,
239+ solver_job_id = patch .solver_job_id ,
240+ )
241+ )
242+ else :
243+ raise UnsupportedFunctionClassError (
244+ function_class = patch .function_class ,
245+ )
275246 return await self ._web_rpc_client .patch_registered_function_job (
276247 user_id = user_id ,
277248 product_name = product_name ,
278- function_job_id = function_job_id ,
279- registered_function_job_patch = patch ,
249+ registered_function_job_patch_inputs = patch_inputs ,
280250 )
281251
282252 async def run_function (
@@ -305,14 +275,19 @@ async def run_function(
305275 job_id = study_job .id ,
306276 pricing_spec = pricing_spec ,
307277 )
308- return await self .patch_registered_function_job (
278+ registered_jobs = await self .patch_registered_function_job (
309279 user_id = self .user_id ,
310280 product_name = self .product_name ,
311- function_job_id = pre_registered_function_job_data .function_job_id ,
312- function_class = FunctionClass .PROJECT ,
313- job_creation_task_id = None ,
314- project_job_id = study_job .id ,
281+ patches = [
282+ ProjectFunctionJobPatch (
283+ function_job_id = pre_registered_function_job_data .function_job_id ,
284+ job_creation_task_id = None ,
285+ project_job_id = study_job .id ,
286+ )
287+ ],
315288 )
289+ assert len (registered_jobs ) == 1
290+ return registered_jobs [0 ]
316291
317292 if function .function_class == FunctionClass .SOLVER :
318293 solver_job = await self ._job_service .create_solver_job (
@@ -330,14 +305,19 @@ async def run_function(
330305 job_id = solver_job .id ,
331306 pricing_spec = pricing_spec ,
332307 )
333- return await self .patch_registered_function_job (
308+ registered_jobs = await self .patch_registered_function_job (
334309 user_id = self .user_id ,
335310 product_name = self .product_name ,
336- function_job_id = pre_registered_function_job_data .function_job_id ,
337- function_class = FunctionClass .SOLVER ,
338- job_creation_task_id = None ,
339- solver_job_id = solver_job .id ,
311+ patches = [
312+ SolverFunctionJobPatch (
313+ function_job_id = pre_registered_function_job_data .function_job_id ,
314+ job_creation_task_id = None ,
315+ solver_job_id = solver_job .id ,
316+ )
317+ ],
340318 )
319+ assert len (registered_jobs ) == 1
320+ return registered_jobs [0 ]
341321
342322 raise UnsupportedFunctionClassError (
343323 function_class = function .function_class ,
0 commit comments