Skip to content

Commit ff8b50a

Browse files
committed
fix tests
1 parent bf6bb7e commit ff8b50a

File tree

1 file changed

+44
-19
lines changed

1 file changed

+44
-19
lines changed

services/api-server/tests/unit/api_functions/celery/test_functions.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import inspect
99
from collections.abc import Callable
10+
from functools import partial
1011
from pathlib import Path
1112
from typing import Any
1213

@@ -23,7 +24,6 @@
2324
from fastapi import FastAPI, status
2425
from httpx import AsyncClient, BasicAuth, HTTPStatusError
2526
from models_library.api_schemas_long_running_tasks.tasks import (
26-
TaskGet,
2727
TaskResult,
2828
TaskStatus,
2929
)
@@ -146,6 +146,18 @@ def _(celery_app: Celery) -> None:
146146
return _
147147

148148

149+
async def _patch_registered_function_job_side_effect(
150+
mock_registered_project_function_job: RegisteredFunctionJob, *args, **kwargs
151+
):
152+
registered_function_job_patch = kwargs["registered_function_job_patch"]
153+
assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch)
154+
job_creation_task_id = registered_function_job_patch.job_creation_task_id
155+
assert job_creation_task_id is not None
156+
return mock_registered_project_function_job.model_copy(
157+
update={"job_creation_task_id": job_creation_task_id}
158+
)
159+
160+
149161
@pytest.mark.parametrize("register_celery_tasks", [_register_fake_run_function_task()])
150162
@pytest.mark.parametrize("add_worker_tasks", [False])
151163
async def test_with_fake_run_function(
@@ -210,19 +222,14 @@ async def test_with_fake_run_function(
210222
"register_function_job", mock_registered_project_function_job, None, None
211223
)
212224

213-
async def _patch_side_effect(*args, **kwargs):
214-
registered_function_job_patch = kwargs["registered_function_job_patch"]
215-
assert isinstance(
216-
registered_function_job_patch, RegisteredProjectFunctionJobPatch
217-
)
218-
job_creation_task_id = registered_function_job_patch.job_creation_task_id
219-
assert job_creation_task_id is not None
220-
return mock_registered_project_function_job.model_copy(
221-
update={"job_creation_task_id": job_creation_task_id}
222-
)
223-
224225
mock_handler_in_functions_rpc_interface(
225-
"patch_registered_function_job", None, None, _patch_side_effect
226+
"patch_registered_function_job",
227+
None,
228+
None,
229+
partial(
230+
_patch_registered_function_job_side_effect,
231+
mock_registered_project_function_job,
232+
),
226233
)
227234

228235
headers = {}
@@ -313,7 +320,9 @@ async def test_run_project_function_parent_info(
313320
app: FastAPI,
314321
with_api_server_celery_worker: TestWorkController,
315322
client: AsyncClient,
316-
mock_handler_in_functions_rpc_interface: Callable[[str, Any], None],
323+
mock_handler_in_functions_rpc_interface: Callable[
324+
[str, Any, Exception | None, Callable | None], None
325+
],
317326
mock_registered_project_function: RegisteredProjectFunction,
318327
mock_registered_project_function_job: RegisteredFunctionJob,
319328
auth: httpx.BasicAuth,
@@ -358,13 +367,15 @@ def _default_side_effect(
358367
read=True,
359368
write=True,
360369
),
370+
None,
371+
None,
361372
)
362373
mock_handler_in_functions_rpc_interface(
363-
"get_function", mock_registered_project_function
374+
"get_function", mock_registered_project_function, None, None
364375
)
365-
mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [])
376+
mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None)
366377
mock_handler_in_functions_rpc_interface(
367-
"register_function_job", mock_registered_project_function_job
378+
"register_function_job", mock_registered_project_function_job, None, None
368379
)
369380
mock_handler_in_functions_rpc_interface(
370381
"get_functions_user_api_access_rights",
@@ -374,6 +385,17 @@ def _default_side_effect(
374385
write_functions=True,
375386
read_functions=True,
376387
),
388+
None,
389+
None,
390+
)
391+
mock_handler_in_functions_rpc_interface(
392+
"patch_registered_function_job",
393+
None,
394+
None,
395+
partial(
396+
_patch_registered_function_job_side_effect,
397+
mock_registered_project_function_job,
398+
),
377399
)
378400

379401
headers = {}
@@ -390,6 +412,9 @@ def _default_side_effect(
390412
)
391413
assert response.status_code == expected_status_code
392414
if response.status_code == status.HTTP_200_OK:
393-
task = TaskGet.model_validate(response.json())
394-
result = await poll_task_until_done(client, auth, task.task_id)
415+
function_job = RegisteredProjectFunctionJob.model_validate(response.json())
416+
celery_task_id = function_job.job_creation_task_id
417+
assert celery_task_id is not None
418+
# Poll until task completion and get result
419+
result = await poll_task_until_done(client, auth, celery_task_id)
395420
RegisteredProjectFunctionJob.model_validate(result.result)

0 commit comments

Comments
 (0)