Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ async def patch_registered_function_job(
async def run_function(
self,
*,
job_creation_task_id: TaskID | None,
function: RegisteredFunction,
pre_registered_function_job_data: PreRegisteredFunctionJobData,
pricing_spec: JobPricingSpecification | None,
Expand Down Expand Up @@ -434,7 +433,7 @@ async def run_function(
product_name=self.product_name,
function_job_id=pre_registered_function_job_data.function_job_id,
function_class=FunctionClass.PROJECT,
job_creation_task_id=job_creation_task_id,
job_creation_task_id=None,
project_job_id=study_job.id,
)

Expand All @@ -459,37 +458,14 @@ async def run_function(
product_name=self.product_name,
function_job_id=pre_registered_function_job_data.function_job_id,
function_class=FunctionClass.SOLVER,
job_creation_task_id=job_creation_task_id,
job_creation_task_id=None,
solver_job_id=solver_job.id,
)

raise UnsupportedFunctionClassError(
function_class=function.function_class,
)

async def map_function(
self,
*,
job_creation_task_id: TaskID | None,
function: RegisteredFunction,
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
job_links: JobLinks,
pricing_spec: JobPricingSpecification | None,
x_simcore_parent_project_uuid: ProjectID | None,
x_simcore_parent_node_id: NodeID | None,
) -> None:

for data in pre_registered_function_job_data_list:
await self.run_function(
job_creation_task_id=job_creation_task_id,
function=function,
pre_registered_function_job_data=data,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)

async def function_job_outputs(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@

from ..._service_function_jobs import FunctionJobService
from ..._service_functions import FunctionService
from ...celery_worker.worker_tasks.functions_tasks import function_map as map_task
from ...celery_worker.worker_tasks.functions_tasks import (
run_function as run_function_task,
)
from ...exceptions.function_errors import FunctionJobCacheNotFoundError
from ...models.domain.functions import PreRegisteredFunctionJobData
from ...models.pagination import Page, PaginationParams
from ...models.schemas.errors import ErrorGet
from ...models.schemas.jobs import JobPricingSpecification
Expand Down Expand Up @@ -454,81 +452,23 @@ async def map_function(
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
) -> RegisteredFunctionJobCollection:

task_manager = get_task_manager(request.app)
parent_project_uuid = (
x_simcore_parent_project_uuid
if isinstance(x_simcore_parent_project_uuid, ProjectID)
else None
)
parent_node_id = (
x_simcore_parent_node_id
if isinstance(x_simcore_parent_node_id, NodeID)
else None
)
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
job_links = await function_service.get_function_job_links(to_run_function, url_for)

job_inputs_list = [
await function_jobs_service.create_function_job_inputs(
function=to_run_function, function_inputs=function_inputs
)
for function_inputs in function_inputs_list
]

job_ids: list[FunctionJobID] = []
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData] = []

for job_inputs in job_inputs_list:
try:
cached_job = await function_jobs_service.get_cached_function_job(
function=to_run_function,
job_inputs=job_inputs,
)
job_ids.append(cached_job.uid)
except FunctionJobCacheNotFoundError:
data = await function_jobs_service.pre_register_function_job(
function=to_run_function,
job_inputs=job_inputs,
)
pre_registered_function_job_data_list.append(data)
job_ids.append(data.function_job_id)

# run map in celery task
job_filter = AsyncJobFilter(
user_id=user_identity.user_id,
product_name=user_identity.product_name,
client_name=ASYNC_JOB_CLIENT_NAME,
)
task_filter = TaskFilter.model_validate(job_filter.model_dump())
task_name = map_task.__name__

task_uuid = await task_manager.submit_task(
TaskMetadata(
name=task_name,
ephemeral=True,
queue=TasksQueue.API_WORKER_QUEUE,
),
task_filter=task_filter,
user_identity=user_identity,
function=to_run_function,
pre_registered_function_job_data_list=pre_registered_function_job_data_list,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=parent_project_uuid,
x_simcore_parent_node_id=parent_node_id,
)

# patch pre-registered function jobs
for data in pre_registered_function_job_data_list:
await function_jobs_service.patch_registered_function_job(
user_id=user_identity.user_id,
product_name=user_identity.product_name,
function_job_id=data.function_job_id,
function_class=to_run_function.function_class,
job_creation_task_id=TaskID(task_uuid),
for function_inputs in function_inputs_list:
job = await run_function(
request=request,
user_identity=user_identity,
to_run_function=to_run_function,
url_for=url_for,
function_inputs=function_inputs,
function_service=function_service,
function_job_service=function_jobs_service,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)
job_ids.append(job.uid)

function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(pre_registered_function_job_data_list)} inputs"
function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs"
return await web_api_rpc_client.register_function_job_collection(
function_job_collection=FunctionJobCollection(
title="Function job collection of function map",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,10 @@ async def run_function(
)

return await function_job_service.run_function(
job_creation_task_id=task_id,
function=function,
pre_registered_function_job_data=pre_registered_function_job_data,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)


async def function_map(
task: Task,
task_id: TaskID,
*,
user_identity: Identity,
function: RegisteredFunction,
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
job_links: JobLinks,
pricing_spec: JobPricingSpecification | None,
x_simcore_parent_project_uuid: ProjectID | None,
x_simcore_parent_node_id: NodeID | None,
) -> None:
assert task_id # nosec
app = get_app_server(task.app).app
function_job_service = await _assemble_function_job_service(
app=app, user_identity=user_identity
)

return await function_job_service.map_function(
job_creation_task_id=task_id,
function=function,
pre_registered_function_job_data_list=pre_registered_function_job_data_list,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ...models.api_resources import JobLinks
from ...models.domain.functions import PreRegisteredFunctionJobData
from ...models.schemas.jobs import JobInputs, JobPricingSpecification
from .functions_tasks import function_map, run_function
from .functions_tasks import run_function

_logger = logging.getLogger(__name__)

Expand All @@ -43,4 +43,3 @@ def setup_worker_tasks(app: Celery) -> None:

with log_context(_logger, logging.INFO, msg="worker task registration"):
register_task(app, run_function)
register_task(app, function_map)
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ async def _patch_registered_function_job_side_effect(
registered_function_job_patch = kwargs["registered_function_job_patch"]
assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch)
job_creation_task_id = registered_function_job_patch.job_creation_task_id
assert job_creation_task_id is not None
uid = kwargs["function_job_uuid"]
return mock_registered_project_function_job.model_copy(
update={"job_creation_task_id": job_creation_task_id}
update={"job_creation_task_id": job_creation_task_id, "uid": uid}
)


Expand Down Expand Up @@ -551,3 +551,132 @@ def _default_side_effect(
].job_creation_task_id
await wait_for_task_result(client, auth, f"{task_id}")
assert side_effect_checks["headers_checked"] is True


@pytest.mark.parametrize("capture", ["run_study_function_parent_info.json"])
@pytest.mark.parametrize("mocked_app_dependencies", [None])
async def test_map_function(
app: FastAPI,
with_api_server_celery_worker: TestWorkController,
client: AsyncClient,
mock_handler_in_functions_rpc_interface: Callable[
[str, Any, Exception | None, Callable | None], MockType
],
mock_registered_project_function: RegisteredProjectFunction,
mock_registered_project_function_job: RegisteredFunctionJob,
auth: httpx.BasicAuth,
user_id: UserID,
mocked_webserver_rest_api_base: respx.MockRouter,
mocked_directorv2_rest_api_base: respx.MockRouter,
mocked_webserver_rpc_api: dict[str, MockType],
create_respx_mock_from_capture,
project_tests_dir: Path,
capture: str,
) -> None:

def _default_side_effect(
request: httpx.Request,
path_params: dict[str, Any],
capture: HttpApiCallCaptureModel,
) -> Any:
return capture.response_body

create_respx_mock_from_capture(
respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base],
capture_path=project_tests_dir / "mocks" / capture,
side_effects_callbacks=[_default_side_effect] * 50,
)

mock_handler_in_functions_rpc_interface(
"get_function_user_permissions",
FunctionUserAccessRights(
user_id=user_id,
execute=True,
read=True,
write=True,
),
None,
None,
)
mock_handler_in_functions_rpc_interface(
"get_function", mock_registered_project_function, None, None
)
mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None)

_generated_function_job_ids: list[FunctionJobID] = []

async def _register_function_job_side_effect(
generated_function_job_ids: list[FunctionJobID], *args, **kwargs
):
uid = FunctionJobID(_faker.uuid4())
generated_function_job_ids.append(uid)
return mock_registered_project_function_job.model_copy(update={"uid": uid})

mock_handler_in_functions_rpc_interface(
"register_function_job",
None,
None,
partial(_register_function_job_side_effect, _generated_function_job_ids),
)
mock_handler_in_functions_rpc_interface(
"get_functions_user_api_access_rights",
FunctionUserApiAccessRights(
user_id=user_id,
execute_functions=True,
write_functions=True,
read_functions=True,
),
None,
None,
)

async def _register_function_job_collection_side_effect(*args, **kwargs):
job_collection = kwargs["function_job_collection"]
return RegisteredFunctionJobCollection(
uid=FunctionJobID(_faker.uuid4()),
title="Test Collection",
description="A test function job collection",
job_ids=job_collection.job_ids,
created_at=datetime.datetime.now(datetime.UTC),
)

mock_handler_in_functions_rpc_interface(
"register_function_job_collection",
None,
None,
_register_function_job_collection_side_effect,
)

patch_mock = mock_handler_in_functions_rpc_interface(
"patch_registered_function_job",
None,
None,
partial(
_patch_registered_function_job_side_effect,
mock_registered_project_function_job,
),
)

_inputs = [{}, {}]
response = await client.post(
f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map",
json=_inputs,
auth=auth,
headers={
X_SIMCORE_PARENT_PROJECT_UUID: "null",
X_SIMCORE_PARENT_NODE_ID: "null",
},
)
assert response.status_code == status.HTTP_200_OK

job_collection = FunctionJobCollection.model_validate(response.json())
assert (
job_collection.job_ids == _generated_function_job_ids
), "Job ID did not preserve order or were incorrectly propagated"
celery_task_ids = {
elm.kwargs["registered_function_job_patch"].job_creation_task_id
for elm in patch_mock.call_args_list
}
assert len(celery_task_ids) == len(_inputs)
for task_id in celery_task_ids:
await wait_for_task_result(client, auth, f"{task_id}")
Loading