|
26 | 26 | from models_library.functions import ( |
27 | 27 | FunctionClass, |
28 | 28 | FunctionID, |
| 29 | + FunctionInputsList, |
29 | 30 | FunctionJobCollection, |
30 | 31 | FunctionJobID, |
| 32 | + FunctionJobList, |
| 33 | + FunctionJobStatus, |
31 | 34 | FunctionUserAccessRights, |
32 | 35 | FunctionUserApiAccessRights, |
33 | 36 | RegisteredFunction, |
34 | 37 | RegisteredFunctionJob, |
35 | 38 | RegisteredFunctionJobCollection, |
36 | 39 | RegisteredProjectFunction, |
37 | 40 | RegisteredProjectFunctionJob, |
38 | | - RegisteredProjectFunctionJobPatch, |
| 41 | + RegisteredProjectFunctionJobPatchInputList, |
| 42 | + RegisteredSolverFunctionJobPatchInputList, |
39 | 43 | ) |
| 44 | +from models_library.products import ProductName |
40 | 45 | from models_library.projects import ProjectID |
41 | 46 | from models_library.users import UserID |
42 | 47 | from pytest_mock import MockType |
@@ -151,15 +156,48 @@ def _(celery_app: Celery) -> None: |
151 | 156 |
|
152 | 157 |
|
153 | 158 | async def _patch_registered_function_job_side_effect( |
154 | | - mock_registered_project_function_job: RegisteredFunctionJob, *args, **kwargs |
| 159 | + mock_registered_project_function_job: RegisteredFunctionJob, |
| 160 | + product_name: ProductName, |
| 161 | + user_id: UserID, |
| 162 | + registered_function_job_patch_inputs: ( |
| 163 | + RegisteredProjectFunctionJobPatchInputList |
| 164 | + | RegisteredSolverFunctionJobPatchInputList |
| 165 | + ), |
155 | 166 | ): |
156 | | - registered_function_job_patch = kwargs["registered_function_job_patch"] |
157 | | - assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch) |
158 | | - job_creation_task_id = registered_function_job_patch.job_creation_task_id |
159 | | - uid = kwargs["function_job_uuid"] |
160 | | - return mock_registered_project_function_job.model_copy( |
161 | | - update={"job_creation_task_id": job_creation_task_id, "uid": uid} |
162 | | - ) |
| 167 | + return [ |
| 168 | + mock_registered_project_function_job.model_copy( |
| 169 | + update={ |
| 170 | + "job_creation_task_id": patch.patch.job_creation_task_id, |
| 171 | + "uid": patch.uid, |
| 172 | + } |
| 173 | + ) |
| 174 | + for patch in registered_function_job_patch_inputs |
| 175 | + ] |
| 176 | + |
| 177 | + |
| 178 | +async def _find_cached_function_jobs_side_effect( |
| 179 | + *, |
| 180 | + user_id: UserID, |
| 181 | + product_name: ProductName, |
| 182 | + function_id: FunctionID, |
| 183 | + inputs: FunctionInputsList, |
| 184 | + status_filter: list[FunctionJobStatus] | None, |
| 185 | +): |
| 186 | + return [None] * len(inputs) |
| 187 | + |
| 188 | + |
| 189 | +async def _register_function_job_side_effect( |
| 190 | + registered_function_job: RegisteredFunctionJob, |
| 191 | + user_id: UserID, |
| 192 | + function_jobs: FunctionJobList, |
| 193 | + product_name: ProductName, |
| 194 | +): |
| 195 | + return [ |
| 196 | + registered_function_job.model_copy( |
| 197 | + update={"uid": FunctionJobID(_faker.uuid4())} |
| 198 | + ) |
| 199 | + for _ in function_jobs |
| 200 | + ] |
163 | 201 |
|
164 | 202 |
|
165 | 203 | @pytest.mark.parametrize("register_celery_tasks", [_register_fake_run_function_task()]) |
@@ -216,10 +254,14 @@ async def test_with_fake_run_function( |
216 | 254 | "get_function", return_value=fake_registered_project_function |
217 | 255 | ) |
218 | 256 | mock_handler_in_functions_rpc_interface( |
219 | | - "find_cached_function_jobs", return_value=[] |
| 257 | + "find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect |
220 | 258 | ) |
221 | 259 | mock_handler_in_functions_rpc_interface( |
222 | | - "register_function_job", return_value=fake_registered_project_function_job |
| 260 | + "register_function_job", |
| 261 | + side_effect=partial( |
| 262 | + _register_function_job_side_effect, |
| 263 | + fake_registered_project_function_job, |
| 264 | + ), |
223 | 265 | ) |
224 | 266 |
|
225 | 267 | mock_handler_in_functions_rpc_interface( |
|
0 commit comments