Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ async def patch_registered_function_job(
async def run_function(
self,
*,
job_creation_task_id: TaskID | None,
function: RegisteredFunction,
pre_registered_function_job_data: PreRegisteredFunctionJobData,
pricing_spec: JobPricingSpecification | None,
Expand Down Expand Up @@ -434,7 +433,7 @@ async def run_function(
product_name=self.product_name,
function_job_id=pre_registered_function_job_data.function_job_id,
function_class=FunctionClass.PROJECT,
job_creation_task_id=job_creation_task_id,
job_creation_task_id=None,
project_job_id=study_job.id,
)

Expand All @@ -459,37 +458,14 @@ async def run_function(
product_name=self.product_name,
function_job_id=pre_registered_function_job_data.function_job_id,
function_class=FunctionClass.SOLVER,
job_creation_task_id=job_creation_task_id,
job_creation_task_id=None,
solver_job_id=solver_job.id,
)

raise UnsupportedFunctionClassError(
function_class=function.function_class,
)

async def map_function(
    self,
    *,
    job_creation_task_id: TaskID | None,
    function: RegisteredFunction,
    pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
    job_links: JobLinks,
    pricing_spec: JobPricingSpecification | None,
    x_simcore_parent_project_uuid: ProjectID | None,
    x_simcore_parent_node_id: NodeID | None,
) -> None:
    """Run ``function`` once for each pre-registered function job.

    The jobs are started one after another, in list order. Every call to
    ``run_function`` receives the same task id, job links, pricing
    specification and parent project/node identifiers; only the
    pre-registered job data changes between calls. An exception raised by
    one run propagates to the caller, so later entries are not started.
    """

    async def _run_single(job_data: PreRegisteredFunctionJobData) -> None:
        # Everything except ``job_data`` is shared by the whole batch.
        await self.run_function(
            job_creation_task_id=job_creation_task_id,
            function=function,
            pre_registered_function_job_data=job_data,
            pricing_spec=pricing_spec,
            job_links=job_links,
            x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
            x_simcore_parent_node_id=x_simcore_parent_node_id,
        )

    for job_data in pre_registered_function_job_data_list:
        await _run_single(job_data)

async def function_job_outputs(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
from models_library.users import UserID
from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue
from servicelib.fastapi.dependencies import get_reverse_url_mapper
from servicelib.utils import limited_gather

from ..._service_function_jobs import FunctionJobService
from ..._service_functions import FunctionService
from ...celery_worker.worker_tasks.functions_tasks import function_map as map_task
from ...celery_worker.worker_tasks.functions_tasks import (
run_function as run_function_task,
)
from ...exceptions.function_errors import FunctionJobCacheNotFoundError
from ...models.domain.functions import PreRegisteredFunctionJobData
from ...models.pagination import Page, PaginationParams
from ...models.schemas.errors import ErrorGet
from ...models.schemas.jobs import JobPricingSpecification
Expand Down Expand Up @@ -454,86 +453,42 @@ async def map_function(
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
) -> RegisteredFunctionJobCollection:

task_manager = get_task_manager(request.app)
parent_project_uuid = (
x_simcore_parent_project_uuid
if isinstance(x_simcore_parent_project_uuid, ProjectID)
else None
)
parent_node_id = (
x_simcore_parent_node_id
if isinstance(x_simcore_parent_node_id, NodeID)
else None
)
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
job_links = await function_service.get_function_job_links(to_run_function, url_for)

job_inputs_list = [
await function_jobs_service.create_function_job_inputs(
function=to_run_function, function_inputs=function_inputs
async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID:
result = await run_function(
request=request,
user_identity=user_identity,
to_run_function=to_run_function,
url_for=url_for,
function_inputs=function_inputs,
function_service=function_service,
function_job_service=function_jobs_service,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)
for function_inputs in function_inputs_list
]

job_ids: list[FunctionJobID] = []
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData] = []

for job_inputs in job_inputs_list:
try:
cached_job = await function_jobs_service.get_cached_function_job(
function=to_run_function,
job_inputs=job_inputs,
)
job_ids.append(cached_job.uid)
except FunctionJobCacheNotFoundError:
data = await function_jobs_service.pre_register_function_job(
function=to_run_function,
job_inputs=job_inputs,
)
pre_registered_function_job_data_list.append(data)
job_ids.append(data.function_job_id)

# run map in celery task
job_filter = AsyncJobFilter(
user_id=user_identity.user_id,
product_name=user_identity.product_name,
client_name=ASYNC_JOB_CLIENT_NAME,
)
task_filter = TaskFilter.model_validate(job_filter.model_dump())
task_name = map_task.__name__

task_uuid = await task_manager.submit_task(
TaskMetadata(
name=task_name,
ephemeral=True,
queue=TasksQueue.API_WORKER_QUEUE,
),
task_filter=task_filter,
user_identity=user_identity,
function=to_run_function,
pre_registered_function_job_data_list=pre_registered_function_job_data_list,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=parent_project_uuid,
x_simcore_parent_node_id=parent_node_id,
return result.uid

# Run all tasks concurrently, allowing them to complete even if some fail
results = await limited_gather(
*[
_run_single_function(function_inputs)
for function_inputs in function_inputs_list
],
reraise=False,
limit=10,
)

# patch pre-registered function jobs
for data in pre_registered_function_job_data_list:
await function_jobs_service.patch_registered_function_job(
user_id=user_identity.user_id,
product_name=user_identity.product_name,
function_job_id=data.function_job_id,
function_class=to_run_function.function_class,
job_creation_task_id=TaskID(task_uuid),
)
# Check if any tasks raised exceptions and raise the first one found
for result in results:
if isinstance(result, BaseException):
raise result

function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(pre_registered_function_job_data_list)} inputs"
# At this point, all results are FunctionJobID since we've checked for exceptions
function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs"
return await web_api_rpc_client.register_function_job_collection(
function_job_collection=FunctionJobCollection(
title="Function job collection of function map",
description=function_job_collection_description,
job_ids=job_ids,
job_ids=results, # type: ignore
),
user_id=user_identity.user_id,
product_name=user_identity.product_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,10 @@ async def run_function(
)

return await function_job_service.run_function(
job_creation_task_id=task_id,
function=function,
pre_registered_function_job_data=pre_registered_function_job_data,
pricing_spec=pricing_spec,
job_links=job_links,
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
x_simcore_parent_node_id=x_simcore_parent_node_id,
)


async def function_map(
    task: Task,
    task_id: TaskID,
    *,
    user_identity: Identity,
    function: RegisteredFunction,
    pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
    job_links: JobLinks,
    pricing_spec: JobPricingSpecification | None,
    x_simcore_parent_project_uuid: ProjectID | None,
    x_simcore_parent_node_id: NodeID | None,
) -> None:
    """Worker-task entry point that maps ``function`` over pre-registered jobs.

    Builds a function-job service for ``user_identity`` from the application
    attached to ``task`` and delegates to its ``map_function``, passing
    ``task_id`` as the job-creation task id of every job in the list.
    """
    assert task_id  # nosec

    service = await _assemble_function_job_service(
        app=get_app_server(task.app).app,
        user_identity=user_identity,
    )
    return await service.map_function(
        job_creation_task_id=task_id,
        function=function,
        pre_registered_function_job_data_list=pre_registered_function_job_data_list,
        pricing_spec=pricing_spec,
        job_links=job_links,
        x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
        x_simcore_parent_node_id=x_simcore_parent_node_id,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ...models.api_resources import JobLinks
from ...models.domain.functions import PreRegisteredFunctionJobData
from ...models.schemas.jobs import JobInputs, JobPricingSpecification
from .functions_tasks import function_map, run_function
from .functions_tasks import run_function

_logger = logging.getLogger(__name__)

Expand All @@ -43,4 +43,3 @@ def setup_worker_tasks(app: Celery) -> None:

with log_context(_logger, logging.INFO, msg="worker task registration"):
register_task(app, run_function)
register_task(app, function_map)
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ async def _patch_registered_function_job_side_effect(
registered_function_job_patch = kwargs["registered_function_job_patch"]
assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch)
job_creation_task_id = registered_function_job_patch.job_creation_task_id
assert job_creation_task_id is not None
uid = kwargs["function_job_uuid"]
return mock_registered_project_function_job.model_copy(
update={"job_creation_task_id": job_creation_task_id}
update={"job_creation_task_id": job_creation_task_id, "uid": uid}
)


Expand Down Expand Up @@ -551,3 +551,135 @@ def _default_side_effect(
].job_creation_task_id
await wait_for_task_result(client, auth, f"{task_id}")
assert side_effect_checks["headers_checked"] is True


@pytest.mark.parametrize("mocked_app_dependencies", [None])
async def test_map_function(
    app: FastAPI,
    with_api_server_celery_worker: TestWorkController,
    client: AsyncClient,
    mock_handler_in_functions_rpc_interface: Callable[
        [str, Any, Exception | None, Callable | None], MockType
    ],
    mock_registered_project_function: RegisteredProjectFunction,
    mock_registered_project_function_job: RegisteredFunctionJob,
    auth: httpx.BasicAuth,
    user_id: UserID,
    mocked_webserver_rest_api_base: respx.MockRouter,
    mocked_directorv2_rest_api_base: respx.MockRouter,
    mocked_webserver_rpc_api: dict[str, MockType],
    create_respx_mock_from_capture,
    project_tests_dir: Path,
) -> None:
    """POST two inputs to the ``:map`` endpoint and check the resulting collection.

    Verifies that the returned job ids match, in order, the ids handed out by
    the mocked ``register_function_job`` handler, and that the patch handler
    saw one distinct ``job_creation_task_id`` per input.
    """

    # --- arrange: REST mocks replay the captured responses unchanged ---
    def _return_captured_body(
        request: httpx.Request,
        path_params: dict[str, Any],
        capture: HttpApiCallCaptureModel,
    ) -> Any:
        return capture.response_body

    capture_file = project_tests_dir / "mocks" / "run_study_function_parent_info.json"
    create_respx_mock_from_capture(
        respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base],
        capture_path=capture_file,
        side_effects_callbacks=[_return_captured_body] * 50,
    )

    # --- arrange: functions RPC handlers ---
    function_access_rights = FunctionUserAccessRights(
        user_id=user_id,
        execute=True,
        read=True,
        write=True,
    )
    mock_handler_in_functions_rpc_interface(
        "get_function_user_permissions", function_access_rights, None, None
    )
    mock_handler_in_functions_rpc_interface(
        "get_function", mock_registered_project_function, None, None
    )
    # An empty list here means the cache lookup finds nothing.
    mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None)

    # Ids are recorded in registration order so the response can be compared.
    registered_job_ids: list[FunctionJobID] = []

    async def _register_and_record(
        collected_ids: list[FunctionJobID], *args, **kwargs
    ):
        new_uid = FunctionJobID(_faker.uuid4())
        collected_ids.append(new_uid)
        return mock_registered_project_function_job.model_copy(
            update={"uid": new_uid}
        )

    mock_handler_in_functions_rpc_interface(
        "register_function_job",
        None,
        None,
        partial(_register_and_record, registered_job_ids),
    )

    api_access_rights = FunctionUserApiAccessRights(
        user_id=user_id,
        execute_functions=True,
        write_functions=True,
        read_functions=True,
    )
    mock_handler_in_functions_rpc_interface(
        "get_functions_user_api_access_rights", api_access_rights, None, None
    )

    async def _echo_collection(*args, **kwargs):
        # Hand back a registered collection carrying the submitted job ids.
        submitted = kwargs["function_job_collection"]
        return RegisteredFunctionJobCollection(
            uid=FunctionJobID(_faker.uuid4()),
            title="Test Collection",
            description="A test function job collection",
            job_ids=submitted.job_ids,
            created_at=datetime.datetime.now(datetime.UTC),
        )

    mock_handler_in_functions_rpc_interface(
        "register_function_job_collection",
        None,
        None,
        _echo_collection,
    )

    patch_handler_mock = mock_handler_in_functions_rpc_interface(
        "patch_registered_function_job",
        None,
        None,
        partial(
            _patch_registered_function_job_side_effect,
            mock_registered_project_function_job,
        ),
    )

    # --- act ---
    map_inputs = [{}, {}]
    map_url = f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map"
    parent_headers = {
        X_SIMCORE_PARENT_PROJECT_UUID: "null",
        X_SIMCORE_PARENT_NODE_ID: "null",
    }
    response = await client.post(
        map_url,
        json=map_inputs,
        auth=auth,
        headers=parent_headers,
    )

    # --- assert ---
    assert response.status_code == status.HTTP_200_OK
    returned_collection = FunctionJobCollection.model_validate(response.json())
    assert (
        returned_collection.job_ids == registered_job_ids
    ), "Job ID did not preserve order or were incorrectly propagated"

    # A set collapses duplicates, so equal sizes means one distinct task per input.
    creation_task_ids = {
        patch_call.kwargs["registered_function_job_patch"].job_creation_task_id
        for patch_call in patch_handler_mock.call_args_list
    }
    assert len(creation_task_ids) == len(map_inputs)
    for creation_task_id in creation_task_ids:
        await wait_for_task_result(client, auth, f"{creation_task_id}")
Loading