@@ -655,19 +655,27 @@ def _default_side_effect(
655655 "get_function" , return_value = fake_registered_project_function
656656 )
657657 mock_handler_in_functions_rpc_interface (
658- "find_cached_function_jobs" , return_value = []
658+ "find_cached_function_jobs" , side_effect = _find_cached_function_jobs_side_effect
659659 )
660660
661661 mock_handler_in_projects_rpc_interface ("mark_project_as_job" , return_value = None )
662662
663663 _generated_function_job_ids : list [FunctionJobID ] = []
664664
665665 async def _register_function_job_side_effect (
666- generated_function_job_ids : list [FunctionJobID ], * args , ** kwargs
666+ generated_function_job_ids : list ,
667+ user_id : UserID ,
668+ function_jobs : FunctionJobList ,
669+ product_name : ProductName ,
667670 ):
668- uid = FunctionJobID (_faker .uuid4 ())
669- generated_function_job_ids .append (uid )
670- return fake_registered_project_function_job .model_copy (update = {"uid" : uid })
671+ registered_jobs = []
672+ for _ in function_jobs :
673+ uid = FunctionJobID (_faker .uuid4 ())
674+ generated_function_job_ids .append (uid )
675+ registered_jobs .append (
676+ fake_registered_project_function_job .model_copy (update = {"uid" : uid })
677+ )
678+ return registered_jobs
671679
672680 mock_handler_in_functions_rpc_interface (
673681 "register_function_job" ,
@@ -726,10 +734,17 @@ async def _register_function_job_collection_side_effect(*args, **kwargs):
726734 assert (
727735 job_collection .job_ids == _generated_function_job_ids
728736 ), "Job ID did not preserve order or were incorrectly propagated"
729- celery_task_ids = {
730- elm .kwargs ["registered_function_job_patch" ].job_creation_task_id
731- for elm in patch_mock .call_args_list
732- }
737+
738+ celery_task_ids = set ()
739+ for args in patch_mock .call_args_list :
740+ inputs = args .kwargs ["registered_function_job_patch_inputs" ]
741+ celery_task_ids = celery_task_ids .union (
742+ {
743+ input_ .patch .job_creation_task_id
744+ for input_ in inputs
745+ if input_ .patch .job_creation_task_id
746+ }
747+ )
733748 assert len (celery_task_ids ) == len (_inputs )
734749 for task_id in celery_task_ids :
735750 await _wait_for_task_result (client , auth , f"{ task_id } " )
0 commit comments