77
88import inspect
99from collections .abc import Callable
10+ from functools import partial
1011from pathlib import Path
1112from typing import Any
1213
2324from fastapi import FastAPI , status
2425from httpx import AsyncClient , BasicAuth , HTTPStatusError
2526from models_library .api_schemas_long_running_tasks .tasks import (
26- TaskGet ,
2727 TaskResult ,
2828 TaskStatus ,
2929)
@@ -146,6 +146,18 @@ def _(celery_app: Celery) -> None:
146146 return _
147147
148148
149+ async def _patch_registered_function_job_side_effect (
150+ mock_registered_project_function_job : RegisteredFunctionJob , * args , ** kwargs
151+ ):
152+ registered_function_job_patch = kwargs ["registered_function_job_patch" ]
153+ assert isinstance (registered_function_job_patch , RegisteredProjectFunctionJobPatch )
154+ job_creation_task_id = registered_function_job_patch .job_creation_task_id
155+ assert job_creation_task_id is not None
156+ return mock_registered_project_function_job .model_copy (
157+ update = {"job_creation_task_id" : job_creation_task_id }
158+ )
159+
160+
149161@pytest .mark .parametrize ("register_celery_tasks" , [_register_fake_run_function_task ()])
150162@pytest .mark .parametrize ("add_worker_tasks" , [False ])
151163async def test_with_fake_run_function (
@@ -210,19 +222,14 @@ async def test_with_fake_run_function(
210222 "register_function_job" , mock_registered_project_function_job , None , None
211223 )
212224
213- async def _patch_side_effect (* args , ** kwargs ):
214- registered_function_job_patch = kwargs ["registered_function_job_patch" ]
215- assert isinstance (
216- registered_function_job_patch , RegisteredProjectFunctionJobPatch
217- )
218- job_creation_task_id = registered_function_job_patch .job_creation_task_id
219- assert job_creation_task_id is not None
220- return mock_registered_project_function_job .model_copy (
221- update = {"job_creation_task_id" : job_creation_task_id }
222- )
223-
224225 mock_handler_in_functions_rpc_interface (
225- "patch_registered_function_job" , None , None , _patch_side_effect
226+ "patch_registered_function_job" ,
227+ None ,
228+ None ,
229+ partial (
230+ _patch_registered_function_job_side_effect ,
231+ mock_registered_project_function_job ,
232+ ),
226233 )
227234
228235 headers = {}
@@ -313,7 +320,9 @@ async def test_run_project_function_parent_info(
313320 app : FastAPI ,
314321 with_api_server_celery_worker : TestWorkController ,
315322 client : AsyncClient ,
316- mock_handler_in_functions_rpc_interface : Callable [[str , Any ], None ],
323+ mock_handler_in_functions_rpc_interface : Callable [
324+ [str , Any , Exception | None , Callable | None ], None
325+ ],
317326 mock_registered_project_function : RegisteredProjectFunction ,
318327 mock_registered_project_function_job : RegisteredFunctionJob ,
319328 auth : httpx .BasicAuth ,
@@ -358,13 +367,15 @@ def _default_side_effect(
358367 read = True ,
359368 write = True ,
360369 ),
370+ None ,
371+ None ,
361372 )
362373 mock_handler_in_functions_rpc_interface (
363- "get_function" , mock_registered_project_function
374+ "get_function" , mock_registered_project_function , None , None
364375 )
365- mock_handler_in_functions_rpc_interface ("find_cached_function_jobs" , [])
376+ mock_handler_in_functions_rpc_interface ("find_cached_function_jobs" , [], None , None )
366377 mock_handler_in_functions_rpc_interface (
367- "register_function_job" , mock_registered_project_function_job
378+ "register_function_job" , mock_registered_project_function_job , None , None
368379 )
369380 mock_handler_in_functions_rpc_interface (
370381 "get_functions_user_api_access_rights" ,
@@ -374,6 +385,17 @@ def _default_side_effect(
374385 write_functions = True ,
375386 read_functions = True ,
376387 ),
388+ None ,
389+ None ,
390+ )
391+ mock_handler_in_functions_rpc_interface (
392+ "patch_registered_function_job" ,
393+ None ,
394+ None ,
395+ partial (
396+ _patch_registered_function_job_side_effect ,
397+ mock_registered_project_function_job ,
398+ ),
377399 )
378400
379401 headers = {}
@@ -390,6 +412,9 @@ def _default_side_effect(
390412 )
391413 assert response .status_code == expected_status_code
392414 if response .status_code == status .HTTP_200_OK :
393- task = TaskGet .model_validate (response .json ())
394- result = await poll_task_until_done (client , auth , task .task_id )
415+ function_job = RegisteredProjectFunctionJob .model_validate (response .json ())
416+ celery_task_id = function_job .job_creation_task_id
417+ assert celery_task_id is not None
418+ # Poll until task completion and get result
419+ result = await poll_task_until_done (client , auth , celery_task_id )
395420 RegisteredProjectFunctionJob .model_validate (result .result )
0 commit comments