Skip to content

Commit 20088b0

Browse files
committed
fix yet another test
1 parent f4bae45 commit 20088b0

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

services/api-server/tests/unit/api_functions/celery/test_functions_celery.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -655,19 +655,27 @@ def _default_side_effect(
655655
"get_function", return_value=fake_registered_project_function
656656
)
657657
mock_handler_in_functions_rpc_interface(
658-
"find_cached_function_jobs", return_value=[]
658+
"find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect
659659
)
660660

661661
mock_handler_in_projects_rpc_interface("mark_project_as_job", return_value=None)
662662

663663
_generated_function_job_ids: list[FunctionJobID] = []
664664

665665
async def _register_function_job_side_effect(
666-
generated_function_job_ids: list[FunctionJobID], *args, **kwargs
666+
generated_function_job_ids: list,
667+
user_id: UserID,
668+
function_jobs: FunctionJobList,
669+
product_name: ProductName,
667670
):
668-
uid = FunctionJobID(_faker.uuid4())
669-
generated_function_job_ids.append(uid)
670-
return fake_registered_project_function_job.model_copy(update={"uid": uid})
671+
registered_jobs = []
672+
for _ in function_jobs:
673+
uid = FunctionJobID(_faker.uuid4())
674+
generated_function_job_ids.append(uid)
675+
registered_jobs.append(
676+
fake_registered_project_function_job.model_copy(update={"uid": uid})
677+
)
678+
return registered_jobs
671679

672680
mock_handler_in_functions_rpc_interface(
673681
"register_function_job",
@@ -726,10 +734,17 @@ async def _register_function_job_collection_side_effect(*args, **kwargs):
726734
assert (
727735
job_collection.job_ids == _generated_function_job_ids
728736
), "Job ID did not preserve order or were incorrectly propagated"
729-
celery_task_ids = {
730-
elm.kwargs["registered_function_job_patch"].job_creation_task_id
731-
for elm in patch_mock.call_args_list
732-
}
737+
738+
celery_task_ids = set()
739+
for args in patch_mock.call_args_list:
740+
inputs = args.kwargs["registered_function_job_patch_inputs"]
741+
celery_task_ids = celery_task_ids.union(
742+
{
743+
input_.patch.job_creation_task_id
744+
for input_ in inputs
745+
if input_.patch.job_creation_task_id
746+
}
747+
)
733748
assert len(celery_task_ids) == len(_inputs)
734749
for task_id in celery_task_ids:
735750
await _wait_for_task_result(client, auth, f"{task_id}")

0 commit comments

Comments
 (0)