diff --git a/packages/celery-library/src/celery_library/backends/redis.py b/packages/celery-library/src/celery_library/backends/redis.py index d9d7c37a1685..cc19becbbcf5 100644 --- a/packages/celery-library/src/celery_library/backends/redis.py +++ b/packages/celery-library/src/celery_library/backends/redis.py @@ -6,12 +6,12 @@ from models_library.progress_bar import ProgressReport from pydantic import ValidationError from servicelib.celery.models import ( + WILDCARD, + ExecutionMetadata, + OwnerMetadata, Task, - TaskFilter, TaskID, TaskInfoStore, - TaskMetadata, - Wildcard, ) from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types @@ -35,7 +35,7 @@ def __init__(self, redis_client_sdk: RedisClientSDK) -> None: async def create_task( self, task_id: TaskID, - task_metadata: TaskMetadata, + execution_metadata: ExecutionMetadata, expiry: timedelta, ) -> None: task_key = _build_key(task_id) @@ -43,7 +43,7 @@ async def create_task( self._redis_client_sdk.redis.hset( name=task_key, key=_CELERY_TASK_METADATA_KEY, - value=task_metadata.model_dump_json(), + value=execution_metadata.model_dump_json(), ) ) await self._redis_client_sdk.redis.expire( @@ -51,7 +51,7 @@ async def create_task( expiry, ) - async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: + async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None: raw_result = await handle_redis_returns_union_types( self._redis_client_sdk.redis.hget( _build_key(task_id), _CELERY_TASK_METADATA_KEY @@ -61,7 +61,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: return None try: - return TaskMetadata.model_validate_json(raw_result) + return ExecutionMetadata.model_validate_json(raw_result) except ValidationError as exc: _logger.debug( "Failed to deserialize task metadata for task %s: %s", task_id, f"{exc}" @@ -85,9 +85,9 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ) return None - async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: - search_key = _CELERY_TASK_INFO_PREFIX + task_filter.create_task_id( - task_uuid=Wildcard() + async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: + search_key = _CELERY_TASK_INFO_PREFIX + owner_metadata.model_dump_task_id( + task_uuid=WILDCARD ) keys: list[str] = [] @@ -112,11 +112,11 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: continue with contextlib.suppress(ValidationError): - task_metadata = TaskMetadata.model_validate_json(raw_metadata) + execution_metadata = ExecutionMetadata.model_validate_json(raw_metadata) tasks.append( Task( - uuid=TaskFilter.get_task_uuid(key), - metadata=task_metadata, + uuid=OwnerMetadata.get_task_uuid(key), + metadata=execution_metadata, ) ) diff --git a/packages/celery-library/src/celery_library/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py index e0b077077f12..9af35a588d2c 100644 --- a/packages/celery-library/src/celery_library/rpc/_async_jobs.py +++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py @@ -4,7 +4,6 @@ from celery.exceptions import CeleryError # type: ignore[import-untyped] from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, AsyncJobId, AsyncJobResult, @@ -17,7 +16,7 @@ JobNotDoneError, JobSchedulerError, ) -from servicelib.celery.models import TaskFilter, TaskState +from servicelib.celery.models import OwnerMetadata, TaskState from servicelib.celery.task_manager import TaskManager from 
servicelib.logging_utils import log_catch from servicelib.rabbitmq import RPCRouter @@ -34,14 +33,13 @@ @router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError)) async def cancel( - task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter + task_manager: TaskManager, job_id: AsyncJobId, owner_metadata: OwnerMetadata ): assert task_manager # nosec - assert job_filter # nosec - task_filter = TaskFilter.model_validate(job_filter.model_dump()) + assert owner_metadata # nosec try: await task_manager.cancel_task( - task_filter=task_filter, + owner_metadata=owner_metadata, task_uuid=job_id, ) except TaskNotFoundError as exc: @@ -52,15 +50,14 @@ async def cancel( @router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError)) async def status( - task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter + task_manager: TaskManager, job_id: AsyncJobId, owner_metadata: OwnerMetadata ) -> AsyncJobStatus: assert task_manager # nosec - assert job_filter # nosec + assert owner_metadata # nosec - task_filter = TaskFilter.model_validate(job_filter.model_dump()) try: task_status = await task_manager.get_task_status( - task_filter=task_filter, + owner_metadata=owner_metadata, task_uuid=job_id, ) except TaskNotFoundError as exc: @@ -85,23 +82,21 @@ async def status( ) ) async def result( - task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter + task_manager: TaskManager, job_id: AsyncJobId, owner_metadata: OwnerMetadata ) -> AsyncJobResult: assert task_manager # nosec assert job_id # nosec - assert job_filter # nosec - - task_filter = TaskFilter.model_validate(job_filter.model_dump()) + assert owner_metadata # nosec try: _status = await task_manager.get_task_status( - task_filter=task_filter, + owner_metadata=owner_metadata, task_uuid=job_id, ) if not _status.is_done: raise JobNotDoneError(job_id=job_id) _result = await task_manager.get_task_result( - task_filter=task_filter, + owner_metadata=owner_metadata, task_uuid=job_id, ) except TaskNotFoundError as exc: @@ -134,13 +129,12 @@ async def result( @router.expose(reraise_if_error_type=(JobSchedulerError,)) async def list_jobs( - task_manager: TaskManager, job_filter: AsyncJobFilter + task_manager: TaskManager, owner_metadata: OwnerMetadata ) -> list[AsyncJobGet]: assert task_manager # nosec - task_filter = TaskFilter.model_validate(job_filter.model_dump()) try: tasks = await task_manager.list_tasks( - task_filter=task_filter, + owner_metadata=owner_metadata, ) except CeleryError as exc: raise JobSchedulerError(exc=f"{exc}") from exc diff --git a/packages/celery-library/src/celery_library/task_manager.py b/packages/celery-library/src/celery_library/task_manager.py index 24c04ca95f3e..9697fd6e4fb3 100644 --- a/packages/celery-library/src/celery_library/task_manager.py +++ b/packages/celery-library/src/celery_library/task_manager.py @@ -9,11 +9,11 @@ from models_library.progress_bar import ProgressReport from servicelib.celery.models import ( TASK_DONE_STATES, + ExecutionMetadata, + OwnerMetadata, Task, - TaskFilter, TaskID, TaskInfoStore, - TaskMetadata, TaskState, TaskStatus, TaskUUID, @@ -39,34 +39,34 @@ class CeleryTaskManager: async def submit_task( self, - task_metadata: TaskMetadata, + execution_metadata: ExecutionMetadata, *, - task_filter: TaskFilter, + owner_metadata: OwnerMetadata, **task_params, ) -> TaskUUID: with log_context( _logger, logging.DEBUG, - msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}", + msg=f"Submit {execution_metadata.name=}: 
{owner_metadata=} {task_params=}", ): task_uuid = uuid4() - task_id = task_filter.create_task_id(task_uuid=task_uuid) + task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid) expiry = ( self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES - if task_metadata.ephemeral + if execution_metadata.ephemeral else self._celery_settings.CELERY_RESULT_EXPIRES ) try: await self._task_info_store.create_task( - task_id, task_metadata, expiry=expiry + task_id, execution_metadata, expiry=expiry ) self._celery_app.send_task( - task_metadata.name, + execution_metadata.name, task_id=task_id, kwargs={"task_id": task_id} | task_params, - queue=task_metadata.queue.value, + queue=execution_metadata.queue.value, ) except CeleryError as exc: try: @@ -78,20 +78,22 @@ async def submit_task( exc_info=True, ) raise TaskSubmissionError( - task_name=task_metadata.name, + task_name=execution_metadata.name, task_id=task_id, task_params=task_params, ) from exc return task_uuid - async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None: + async def cancel_task( + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID + ) -> None: with log_context( _logger, logging.DEBUG, - msg=f"task cancellation: {task_filter=} {task_uuid=}", + msg=f"task cancellation: {owner_metadata=} {task_uuid=}", ): - task_id = task_filter.create_task_id(task_uuid=task_uuid) + task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) @@ -106,14 +108,14 @@ def _forget_task(self, task_id: TaskID) -> None: self._celery_app.AsyncResult(task_id).forget() async def get_task_result( - self, task_filter: TaskFilter, task_uuid: TaskUUID + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID ) -> Any: with log_context( _logger, logging.DEBUG, - msg=f"Get task result: {task_filter=} {task_uuid=}", + msg=f"Get task result: {owner_metadata=} {task_uuid=}", ): - task_id = task_filter.create_task_id(task_uuid=task_uuid) + task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) @@ -149,14 +151,14 @@ def _get_task_celery_state(self, task_id: TaskID) -> TaskState: return TaskState(self._celery_app.AsyncResult(task_id).state) async def get_task_status( - self, task_filter: TaskFilter, task_uuid: TaskUUID + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID ) -> TaskStatus: with log_context( _logger, logging.DEBUG, - msg=f"Getting task status: {task_filter=} {task_uuid=}", + msg=f"Getting task status: {owner_metadata=} {task_uuid=}", ): - task_id = task_filter.create_task_id(task_uuid=task_uuid) + task_id = owner_metadata.model_dump_task_id(task_uuid=task_uuid) if not await self.task_exists(task_id): raise TaskNotFoundError(task_id=task_id) @@ -169,13 +171,13 @@ async def get_task_status( ), ) - async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: + async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: with log_context( _logger, logging.DEBUG, - msg=f"Listing tasks: {task_filter=}", + msg=f"Listing tasks: {owner_metadata=}", ): - return await self._task_info_store.list_tasks(task_filter) + return await self._task_info_store.list_tasks(owner_metadata) async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None: await self._task_info_store.set_task_progress( diff --git a/packages/celery-library/tests/unit/test_async_jobs.py b/packages/celery-library/tests/unit/test_async_jobs.py index 
1622866f2424..4fddea2b698a 100644 --- a/packages/celery-library/tests/unit/test_async_jobs.py +++ b/packages/celery-library/tests/unit/test_async_jobs.py @@ -16,7 +16,6 @@ from common_library.errors_classes import OsparcErrorMixin from faker import Faker from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, ) from models_library.api_schemas_rpc_async_jobs.exceptions import ( @@ -27,7 +26,7 @@ from models_library.rabbitmq_basic_types import RPCNamespace from models_library.users import UserID from pydantic import TypeAdapter -from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata +from servicelib.celery.models import ExecutionMetadata, OwnerMetadata, TaskID from servicelib.celery.task_manager import TaskManager from servicelib.rabbitmq import RabbitMQRPCClient, RPCRouter from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs @@ -79,12 +78,11 @@ def product_name(faker: Faker) -> ProductName: @router.expose() async def rpc_sync_job( - task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any + task_manager: TaskManager, *, owner_metadata: OwnerMetadata, **kwargs: Any ) -> AsyncJobGet: task_name = sync_job.__name__ - task_filter = TaskFilter.model_validate(job_filter.model_dump()) task_uuid = await task_manager.submit_task( - TaskMetadata(name=task_name), task_filter=task_filter, **kwargs + ExecutionMetadata(name=task_name), owner_metadata=owner_metadata, **kwargs ) return AsyncJobGet(job_id=task_uuid, job_name=task_name) @@ -92,12 +90,11 @@ async def rpc_sync_job( @router.expose() async def rpc_async_job( - task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any + task_manager: TaskManager, *, owner_metadata: OwnerMetadata, **kwargs: Any ) -> AsyncJobGet: task_name = async_job.__name__ - task_filter = TaskFilter.model_validate(job_filter.model_dump()) task_uuid = await task_manager.submit_task( - TaskMetadata(name=task_name), task_filter=task_filter, **kwargs + ExecutionMetadata(name=task_name), owner_metadata=owner_metadata, **kwargs ) return AsyncJobGet(job_id=task_uuid, job_name=task_name) @@ -158,18 +155,18 @@ async def _start_task_via_rpc( user_id: UserID, product_name: ProductName, **kwargs: Any, -) -> tuple[AsyncJobGet, AsyncJobFilter]: - job_filter = AsyncJobFilter( - user_id=user_id, product_name=product_name, client_name="pytest_client" +) -> tuple[AsyncJobGet, OwnerMetadata]: + owner_metadata = OwnerMetadata( + user_id=user_id, product_name=product_name, owner="pytest_client" ) async_job_get = await async_jobs.submit( rabbitmq_rpc_client=client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, method_name=rpc_task_name, - job_filter=job_filter, + owner_metadata=owner_metadata, **kwargs, ) - return async_job_get, job_filter + return async_job_get, owner_metadata @pytest.fixture @@ -197,7 +194,7 @@ async def _wait_for_job( rpc_client: RabbitMQRPCClient, *, async_job_get: AsyncJobGet, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, stop_after: timedelta = timedelta(seconds=5), ) -> None: @@ -212,7 +209,7 @@ async def _wait_for_job( rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) assert ( result.done is True @@ -246,7 +243,7 @@ async def test_async_jobs_workflow( exposed_rpc_start: str, payload: Any, ): - async_job_get, job_filter = await _start_task_via_rpc( + async_job_get, owner_metadata = await _start_task_via_rpc( async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, 
user_id=user_id, @@ -258,21 +255,21 @@ async def test_async_jobs_workflow( jobs = await async_jobs.list_jobs( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, - job_filter=job_filter, + owner_metadata=owner_metadata, ) assert len(jobs) > 0 await _wait_for_job( async_jobs_rabbitmq_rpc_client, async_job_get=async_job_get, - job_filter=job_filter, + owner_metadata=owner_metadata, ) async_job_result = await async_jobs.result( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) assert async_job_result.result == payload @@ -291,7 +288,7 @@ async def test_async_jobs_cancel( product_name: ProductName, exposed_rpc_start: str, ): - async_job_get, job_filter = await _start_task_via_rpc( + async_job_get, owner_metadata = await _start_task_via_rpc( async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, user_id=user_id, @@ -304,13 +301,13 @@ async def test_async_jobs_cancel( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) jobs = await async_jobs.list_jobs( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, - job_filter=job_filter, + owner_metadata=owner_metadata, ) assert async_job_get.job_id not in [job.job_id for job in jobs] @@ -319,7 +316,7 @@ async def test_async_jobs_cancel( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) with pytest.raises(JobMissingError): @@ -327,7 +324,7 @@ async def test_async_jobs_cancel( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) @@ -357,7 +354,7 @@ async def test_async_jobs_raises( exposed_rpc_start: str, error: Exception, ): - async_job_get, job_filter = await _start_task_via_rpc( + async_job_get, owner_metadata = await _start_task_via_rpc( async_jobs_rabbitmq_rpc_client, rpc_task_name=exposed_rpc_start, user_id=user_id, @@ -369,7 +366,7 @@ async def test_async_jobs_raises( await _wait_for_job( async_jobs_rabbitmq_rpc_client, async_job_get=async_job_get, - job_filter=job_filter, + owner_metadata=owner_metadata, stop_after=timedelta(minutes=1), ) @@ -378,7 +375,7 @@ async def test_async_jobs_raises( async_jobs_rabbitmq_rpc_client, rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE, job_id=async_job_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) assert exc.value.exc_type == type(error).__name__ assert exc.value.exc_msg == f"{error}" diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index 757cbd3aaff2..9db28d7a5cab 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -21,9 +21,9 @@ from faker import Faker from models_library.progress_bar import ProgressReport from servicelib.celery.models import ( - TaskFilter, + ExecutionMetadata, + OwnerMetadata, TaskID, - TaskMetadata, TaskState, TaskUUID, Wildcard, @@ -39,7 +39,7 @@ pytest_simcore_ops_services_selection = [] -class MyTaskFilter(TaskFilter): +class MyOwnerMetadata(OwnerMetadata): user_id: int @@ -103,13 +103,14 @@ async def test_submitting_task_calling_async_function_results_with_success_state celery_task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - task_filter 
= MyTaskFilter(user_id=42) + + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=fake_file_processor.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, files=[f"file{n}" for n in range(5)], ) @@ -119,14 +120,16 @@ async def test_submitting_task_calling_async_function_results_with_success_state stop=stop_after_delay(30), ): with attempt: - status = await celery_task_manager.get_task_status(task_filter, task_uuid) + status = await celery_task_manager.get_task_status( + owner_metadata, task_uuid + ) assert status.task_state == TaskState.SUCCESS assert ( - await celery_task_manager.get_task_status(task_filter, task_uuid) + await celery_task_manager.get_task_status(owner_metadata, task_uuid) ).task_state == TaskState.SUCCESS assert ( - await celery_task_manager.get_task_result(task_filter, task_uuid) + await celery_task_manager.get_task_result(owner_metadata, task_uuid) ) == "archive.zip" @@ -134,13 +137,14 @@ async def test_submitting_task_with_failure_results_with_error( celery_task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - task_filter = MyTaskFilter(user_id=42) + + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=failure_task.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, ) for attempt in Retrying( @@ -151,11 +155,11 @@ async def test_submitting_task_with_failure_results_with_error( with attempt: raw_result = await celery_task_manager.get_task_result( - task_filter, task_uuid + owner_metadata, task_uuid ) assert isinstance(raw_result, TransferrableCeleryError) - raw_result = await celery_task_manager.get_task_result(task_filter, task_uuid) + raw_result = await celery_task_manager.get_task_result(owner_metadata, task_uuid) assert f"{raw_result}" == "Something strange happened: BOOM!" 
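[Note: illustrative sketch, not part of the diff. It shows how the renamed API from this PR is driven end to end — submit_task with ExecutionMetadata/OwnerMetadata, then get_task_status / get_task_result — mirroring the tests above. The helper name, the polling interval, and the reuse of the "fake_file_processor" task are assumptions taken from this test module.]

import asyncio

from servicelib.celery.models import ExecutionMetadata
from servicelib.celery.task_manager import TaskManager


async def submit_and_collect(task_manager: TaskManager, owner_metadata) -> object:
    # submit: ExecutionMetadata names the task, OwnerMetadata identifies its owner
    task_uuid = await task_manager.submit_task(
        ExecutionMetadata(name="fake_file_processor"),
        owner_metadata=owner_metadata,
        files=["file0"],
    )
    # poll until the task reaches a done state ...
    while not (await task_manager.get_task_status(owner_metadata, task_uuid)).is_done:
        await asyncio.sleep(0.5)
    # ... then fetch the result (a TransferrableCeleryError instance on failure,
    # as asserted in the test above)
    return await task_manager.get_task_result(owner_metadata, task_uuid)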
@@ -163,36 +167,38 @@ async def test_cancelling_a_running_task_aborts_and_deletes( celery_task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - task_filter = MyTaskFilter(user_id=42) + + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=dreamer_task.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, ) await asyncio.sleep(3.0) - await celery_task_manager.cancel_task(task_filter, task_uuid) + await celery_task_manager.cancel_task(owner_metadata, task_uuid) with pytest.raises(TaskNotFoundError): - await celery_task_manager.get_task_status(task_filter, task_uuid) + await celery_task_manager.get_task_status(owner_metadata, task_uuid) - assert task_uuid not in await celery_task_manager.list_tasks(task_filter) + assert task_uuid not in await celery_task_manager.list_tasks(owner_metadata) async def test_listing_task_uuids_contains_submitted_task( celery_task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - task_filter = MyTaskFilter(user_id=42) + + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=dreamer_task.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, ) for attempt in Retrying( @@ -201,64 +207,62 @@ async def test_listing_task_uuids_contains_submitted_task( stop=stop_after_delay(10), ): with attempt: - tasks = await celery_task_manager.list_tasks(task_filter) + tasks = await celery_task_manager.list_tasks(owner_metadata) assert any(task.uuid == task_uuid for task in tasks) - tasks = await celery_task_manager.list_tasks(task_filter) - assert any(task.uuid == task_uuid for task in tasks) + tasks = await celery_task_manager.list_tasks(owner_metadata) + assert any(task.uuid == task_uuid for task in tasks) async def test_filtering_listing_tasks( celery_task_manager: CeleryTaskManager, with_celery_worker: WorkController, ): - class MyFilter(TaskFilter): + class MyOwnerMetadata(OwnerMetadata): user_id: int product_name: str | Wildcard - client_app: str | Wildcard user_id = 42 + _owner = "test-owner" expected_task_uuids: set[TaskUUID] = set() - all_tasks: list[tuple[TaskUUID, MyFilter]] = [] + all_tasks: list[tuple[TaskUUID, MyOwnerMetadata]] = [] try: for _ in range(5): - task_filter = MyFilter( - user_id=user_id, - product_name=_faker.word(), - client_app=_faker.word(), + owner_metadata = MyOwnerMetadata( + user_id=user_id, product_name=_faker.word(), owner=_owner ) task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=dreamer_task.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, ) expected_task_uuids.add(task_uuid) - all_tasks.append((task_uuid, task_filter)) + all_tasks.append((task_uuid, owner_metadata)) for _ in range(3): - task_filter = MyFilter( + owner_metadata = MyOwnerMetadata( user_id=_faker.pyint(min_value=100, max_value=200), product_name=_faker.word(), - client_app=_faker.word(), + owner=_owner, ) task_uuid = await celery_task_manager.submit_task( - TaskMetadata( + ExecutionMetadata( name=dreamer_task.__name__, ), - task_filter=task_filter, + owner_metadata=owner_metadata, ) - all_tasks.append((task_uuid, task_filter)) + all_tasks.append((task_uuid, owner_metadata)) - search_filter = MyFilter( + search_owner_metadata = MyOwnerMetadata( user_id=user_id, - product_name=Wildcard(), - client_app=Wildcard(), + product_name="*", + 
owner=_owner, ) - tasks = await celery_task_manager.list_tasks(search_filter) + tasks = await celery_task_manager.list_tasks(search_owner_metadata) assert expected_task_uuids == {task.uuid for task in tasks} finally: # clean up all tasks. this should ideally be done in the fixture - for task_uuid, task_filter in all_tasks: - await celery_task_manager.cancel_task(task_filter, task_uuid) + for task_uuid, owner_metadata in all_tasks: + await celery_task_manager.cancel_task(owner_metadata, task_uuid) diff --git a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py index 6faa7b1f2b68..e71ee54bfaa8 100644 --- a/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py +++ b/packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py @@ -3,9 +3,7 @@ from pydantic import BaseModel, ConfigDict, StringConstraints -from ..products import ProductName from ..progress_bar import ProgressReport -from ..users import UserID AsyncJobId: TypeAlias = UUID AsyncJobName: TypeAlias = Annotated[ @@ -13,12 +11,6 @@ ] -class AsyncJobFilterBase(BaseModel): - """Base class for async job filters""" - - model_config = ConfigDict(extra="forbid") - - class AsyncJobStatus(BaseModel): job_id: AsyncJobId progress: ProgressReport @@ -48,26 +40,3 @@ class AsyncJobGet(BaseModel): class AsyncJobAbort(BaseModel): result: bool job_id: AsyncJobId - - -class AsyncJobFilter(AsyncJobFilterBase): - """Data for controlling access to an async job""" - - model_config = ConfigDict( - json_schema_extra={ - "examples": [ - { - "product_name": "osparc", - "user_id": 123, - "client_name": "web_client", - } - ] - }, - ) - - product_name: ProductName - user_id: UserID - client_name: Annotated[ # this is the name of the app which *submits* the async job. 
It is mainly used for filtering purposes - str, - StringConstraints(min_length=1, pattern=r"^[^\s]+$"), - ] diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py index f9e4193e7d86..eba9867b529d 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/async_jobs_server.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, AsyncJobId, AsyncJobResult, @@ -14,6 +13,7 @@ from models_library.rabbitmq_basic_types import RPCNamespace from pydantic import validate_call from pytest_mock import MockType +from servicelib.celery.models import OwnerMetadata from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient @@ -28,7 +28,7 @@ async def cancel( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> None: if self.exception is not None: raise self.exception @@ -41,7 +41,7 @@ async def status( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> AsyncJobStatus: if self.exception is not None: raise self.exception @@ -63,7 +63,7 @@ async def result( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> AsyncJobResult: if self.exception is not None: raise self.exception @@ -75,7 +75,7 @@ async def list_jobs( rabbitmq_rpc_client: RabbitMQRPCClient | MockType, *, rpc_namespace: RPCNamespace, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, filter_: str = "", ) -> list[AsyncJobGet]: if self.exception is not None: diff --git a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py index b4b3d771efc1..72dc62ca438f 100644 --- a/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py +++ b/packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py @@ -9,12 +9,13 @@ from typing import Literal from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, ) from models_library.api_schemas_webserver.storage import PathToExport +from models_library.users import UserID from pydantic import TypeAdapter, validate_call from pytest_mock import MockType +from servicelib.celery.models import OwnerMetadata from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient @@ -27,18 +28,16 @@ async def start_export_data( *, paths_to_export: list[PathToExport], export_as: Literal["path", "download_link"], - job_filter: AsyncJobFilter, - ) -> tuple[AsyncJobGet, AsyncJobFilter]: + owner_metadata: OwnerMetadata, + user_id: UserID + ) -> tuple[AsyncJobGet, OwnerMetadata]: assert rabbitmq_rpc_client - assert job_filter + assert owner_metadata assert paths_to_export assert export_as async_job_get = TypeAdapter(AsyncJobGet).validate_python( AsyncJobGet.model_json_schema()["examples"][0], ) - async_job_filter = TypeAdapter(AsyncJobFilter).validate_python( - AsyncJobFilter.model_json_schema()["examples"][0], - ) - return async_job_get, async_job_filter + return async_job_get, owner_metadata diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py index 1090b3d50823..2f37e9b70c5d 100644 --- 
a/packages/service-library/src/servicelib/celery/models.py +++ b/packages/service-library/src/servicelib/celery/models.py @@ -1,10 +1,12 @@ import datetime from enum import StrEnum -from typing import Annotated, Any, Final, Protocol, Self, TypeAlias, TypeVar +from typing import Annotated, Final, Literal, Protocol, Self, TypeAlias, TypeVar from uuid import UUID +import orjson +from common_library.json_serialization import json_dumps, json_loads from models_library.progress_bar import ProgressReport -from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator +from pydantic import BaseModel, ConfigDict, Field, StringConstraints, model_validator from pydantic.config import JsonDict ModelType = TypeVar("ModelType", bound=BaseModel) @@ -15,88 +17,96 @@ ] TaskUUID: TypeAlias = UUID _TASK_ID_KEY_DELIMITATOR: Final[str] = ":" -_WILDCARD: Final[str] = "*" -_FORBIDDEN_CHARS = (_WILDCARD, _TASK_ID_KEY_DELIMITATOR, "=") - +_FORBIDDEN_KEYS = ("*", _TASK_ID_KEY_DELIMITATOR, "=") +_FORBIDDEN_VALUES = (_TASK_ID_KEY_DELIMITATOR, "=") +AllowedTypes = ( + int | float | bool | str | None | list[str] | list[int] | list[float] | list[bool] ) -class Wildcard: - def __str__(self) -> str: - return _WILDCARD +Wildcard: TypeAlias = Literal["*"] +WILDCARD: Final[Wildcard] = "*" -class TaskFilter(BaseModel): +class OwnerMetadata(BaseModel): """ - Class for associating metadata with a celery task. The implementation is very flexible and allows "clients" to define their own metadata. + Class for associating metadata with a celery task. The implementation is very flexible and allows the task owner to define their own metadata. + This could be metadata for validating whether a user has access to a given task (e.g. user_id or product_name) or metadata for keeping track of how to handle a task, + e.g. which schema the result of the task will have. + The class exposes a filtering mechanism to list tasks using wildcards. Example usage: - class MyTaskFilter(TaskFilter): + class StorageOwnerMetadata(OwnerMetadata): user_id: int | Wildcard product_name: str | Wildcard - client_name: str + owner = APP_NAME - Listing tasks using the filter `MyTaskFilter(user_id=123, product_name=Wildcard(), client_name="my-app")` will return all tasks with - user_id 123, any product_name submitted from my-app. + Where APP_NAME is the name of the service. Listing tasks using the filter + `StorageOwnerMetadata(user_id=123, product_name=WILDCARD)` will return all tasks with + user_id 123 and any product_name, submitted by the service. If the metadata schema is known, the class allows deserializing the metadata (model_validate_task_id). I.e. one can recover the metadata from the task_id: metadata -> task_id -> metadata """ - model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + model_config = ConfigDict(extra="allow", frozen=True) + owner: Annotated[ + str, + StringConstraints(min_length=1, pattern=r"^[a-z_-]+$"), + Field( + description='Identifies the service owning the task. Should be the "APP_NAME" of the service.'
+ ), + ] @model_validator(mode="after") def _check_valid_filters(self) -> Self: for key, value in self.model_dump().items(): # forbidden keys - if any(x in key for x in _FORBIDDEN_CHARS): + if any(x in key for x in _FORBIDDEN_KEYS): raise ValueError(f"Invalid filter key: '{key}'") # forbidden values - if not isinstance(value, Wildcard) and any( - x in f"{value}" for x in _FORBIDDEN_CHARS - ): + if any(x in f"{value}" for x in _FORBIDDEN_VALUES): raise ValueError(f"Invalid filter value for key '{key}': '{value}'") - return self - def _build_task_id_prefix(self) -> str: - filter_dict = self.model_dump() - return _TASK_ID_KEY_DELIMITATOR.join( - [f"{key}={filter_dict[key]}" for key in sorted(filter_dict)] - ) + class _TypeValidationModel(BaseModel): + filters: dict[str, AllowedTypes] + + _TypeValidationModel.model_validate({"filters": self.model_dump()}) + return self - def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID: + def model_dump_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID: + data = self.model_dump(mode="json") + data.update({"task_uuid": f"{task_uuid}"}) return _TASK_ID_KEY_DELIMITATOR.join( - [ - self._build_task_id_prefix(), - f"task_uuid={task_uuid}", - ] + [f"{k}={json_dumps(v)}" for k, v in sorted(data.items())] ) @classmethod - def recreate_as_model(cls, task_id: TaskID, schema: type[ModelType]) -> ModelType: - filter_dict = cls._recreate_data(task_id) - return schema.model_validate(filter_dict) + def model_validate_task_id(cls, task_id: TaskID) -> Self: + data = cls._deserialize_task_id(task_id) + data.pop("task_uuid", None) + return cls.model_validate(data) @classmethod - def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]: - """Recreates the filter data from a task_id string - WARNING: does not validate types. For that use `recreate_model` instead - """ + def _deserialize_task_id(cls, task_id: TaskID) -> dict[str, AllowedTypes]: + key_value_pairs = [ + item.split("=") for item in task_id.split(_TASK_ID_KEY_DELIMITATOR) + ] try: - parts = task_id.split(_TASK_ID_KEY_DELIMITATOR) - return { - key: value - for part in parts[:-1] - if (key := part.split("=")[0]) and (value := part.split("=")[1]) - } - except (IndexError, ValueError) as err: + return {key: json_loads(value) for key, value in key_value_pairs} + except orjson.JSONDecodeError as err: raise ValueError(f"Invalid task_id format: {task_id}") from err @classmethod def get_task_uuid(cls, task_id: TaskID) -> TaskUUID: + data = cls._deserialize_task_id(task_id) try: - return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1]) - except (IndexError, ValueError) as err: + uuid_string = data["task_uuid"] + if not isinstance(uuid_string, str): + raise ValueError(f"Invalid task_id format: {task_id}") + return TaskUUID(uuid_string) + except ValueError as err: raise ValueError(f"Invalid task_id format: {task_id}") from err @@ -120,7 +130,7 @@ class TasksQueue(StrEnum): API_WORKER_QUEUE = "api_worker_queue" -class TaskMetadata(BaseModel): +class ExecutionMetadata(BaseModel): name: TaskName ephemeral: bool = True queue: TasksQueue = TasksQueue.DEFAULT @@ -128,7 +138,7 @@ class TaskMetadata(BaseModel): class Task(BaseModel): uuid: TaskUUID - metadata: TaskMetadata + metadata: ExecutionMetadata @staticmethod def _update_json_schema_extra(schema: JsonDict) -> None: @@ -170,17 +180,17 @@ class TaskInfoStore(Protocol): async def create_task( self, task_id: TaskID, - task_metadata: TaskMetadata, + execution_metadata: ExecutionMetadata, expiry: datetime.timedelta, ) -> None: ... 
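[Note: illustrative sketch, not part of the diff. It makes the task-id encoding introduced above concrete: model_dump_task_id joins sorted key=<json> pairs with ":", model_validate_task_id inverts it, and WILDCARD turns the string into the match pattern that the Redis backend's list_tasks uses. ExampleOwnerMetadata and its field values are hypothetical.]

from uuid import uuid4

from servicelib.celery.models import WILDCARD, OwnerMetadata, Wildcard


class ExampleOwnerMetadata(OwnerMetadata):
    user_id: int | Wildcard
    product_name: str | Wildcard


meta = ExampleOwnerMetadata(user_id=123, product_name="osparc", owner="example_service")
task_id = meta.model_dump_task_id(task_uuid=uuid4())
# -> 'owner="example_service":product_name="osparc":task_uuid="<uuid>":user_id=123'

assert ExampleOwnerMetadata.model_validate_task_id(task_id) == meta  # round-trip
task_uuid = OwnerMetadata.get_task_uuid(task_id)  # recovers the embedded UUID

# the Redis backend builds its search key the same way, with a wildcard task_uuid:
pattern = ExampleOwnerMetadata(
    user_id=123, product_name=WILDCARD, owner="example_service"
).model_dump_task_id(task_uuid=WILDCARD)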
async def task_exists(self, task_id: TaskID) -> bool: ... - async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ... + async def get_task_metadata(self, task_id: TaskID) -> ExecutionMetadata | None: ... async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ... - async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ... + async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ... async def remove_task(self, task_id: TaskID) -> None: ... diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 94c4019e0278..78722dd66454 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -3,10 +3,10 @@ from models_library.progress_bar import ProgressReport from ..celery.models import ( + ExecutionMetadata, + OwnerMetadata, Task, - TaskFilter, TaskID, - TaskMetadata, TaskStatus, TaskUUID, ) @@ -15,24 +15,28 @@ @runtime_checkable class TaskManager(Protocol): async def submit_task( - self, task_metadata: TaskMetadata, *, task_filter: TaskFilter, **task_param + self, + execution_metadata: ExecutionMetadata, + *, + owner_metadata: OwnerMetadata, + **task_param ) -> TaskUUID: ... async def cancel_task( - self, task_filter: TaskFilter, task_uuid: TaskUUID + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID ) -> None: ... async def task_exists(self, task_id: TaskID) -> bool: ... async def get_task_result( - self, task_filter: TaskFilter, task_uuid: TaskUUID + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID ) -> Any: ... async def get_task_status( - self, task_filter: TaskFilter, task_uuid: TaskUUID + self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID ) -> TaskStatus: ... - async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ... + async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ... async def set_task_progress( self, task_id: TaskID, report: ProgressReport diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py index 0c79097852fb..1d2da04185c8 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py @@ -6,7 +6,6 @@ from attr import dataclass from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, AsyncJobId, AsyncJobResult, @@ -27,6 +26,7 @@ wait_random_exponential, ) +from ....celery.models import OwnerMetadata from ....rabbitmq import RemoteMethodNotRegisteredError from ... 
import RabbitMQRPCClient @@ -41,13 +41,13 @@ async def cancel( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> None: await rabbitmq_rpc_client.request( rpc_namespace, TypeAdapter(RPCMethodName).validate_python("cancel"), job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, timeout_s=_DEFAULT_TIMEOUT_S, ) @@ -57,13 +57,13 @@ async def status( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> AsyncJobStatus: _result = await rabbitmq_rpc_client.request( rpc_namespace, TypeAdapter(RPCMethodName).validate_python("status"), job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, timeout_s=_DEFAULT_TIMEOUT_S, ) assert isinstance(_result, AsyncJobStatus) @@ -75,13 +75,13 @@ async def result( *, rpc_namespace: RPCNamespace, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> AsyncJobResult: _result = await rabbitmq_rpc_client.request( rpc_namespace, TypeAdapter(RPCMethodName).validate_python("result"), job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, timeout_s=_DEFAULT_TIMEOUT_S, ) assert isinstance(_result, AsyncJobResult) @@ -92,12 +92,12 @@ async def list_jobs( rabbitmq_rpc_client: RabbitMQRPCClient, *, rpc_namespace: RPCNamespace, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ) -> list[AsyncJobGet]: _result: list[AsyncJobGet] = await rabbitmq_rpc_client.request( rpc_namespace, TypeAdapter(RPCMethodName).validate_python("list_jobs"), - job_filter=job_filter, + owner_metadata=owner_metadata, timeout_s=_DEFAULT_TIMEOUT_S, ) return _result @@ -108,13 +108,13 @@ async def submit( *, rpc_namespace: RPCNamespace, method_name: str, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, **kwargs, ) -> AsyncJobGet: _result = await rabbitmq_rpc_client.request( rpc_namespace, TypeAdapter(RPCMethodName).validate_python(method_name), - job_filter=job_filter, + owner_metadata=owner_metadata, **kwargs, timeout_s=_DEFAULT_TIMEOUT_S, ) @@ -138,7 +138,7 @@ async def _wait_for_completion( rpc_namespace: RPCNamespace, method_name: RPCMethodName, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, client_timeout: datetime.timedelta, ) -> AsyncGenerator[AsyncJobStatus, None]: try: @@ -154,7 +154,7 @@ async def _wait_for_completion( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) yield job_status if not job_status.done: @@ -189,7 +189,7 @@ async def wait_and_get_result( rpc_namespace: RPCNamespace, method_name: str, job_id: AsyncJobId, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, client_timeout: datetime.timedelta, ) -> AsyncGenerator[AsyncJobComposedResult, None]: """when a job is already submitted this will wait for its completion @@ -201,7 +201,7 @@ async def wait_and_get_result( rpc_namespace=rpc_namespace, method_name=method_name, job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, client_timeout=client_timeout, ): assert job_status is not None # nosec @@ -215,7 +215,7 @@ async def wait_and_get_result( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ), ) except (TimeoutError, CancelledError) as error: @@ -224,7 +224,7 @@ async def wait_and_get_result( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, job_id=job_id, - 
job_filter=job_filter, + owner_metadata=owner_metadata, ) except Exception as exc: raise exc from error # NOSONAR @@ -236,7 +236,7 @@ async def submit_and_wait( *, rpc_namespace: RPCNamespace, method_name: str, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, client_timeout: datetime.timedelta, **kwargs, ) -> AsyncGenerator[AsyncJobComposedResult, None]: @@ -246,7 +246,7 @@ async def submit_and_wait( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, method_name=method_name, - job_filter=job_filter, + owner_metadata=owner_metadata, **kwargs, ) except (TimeoutError, CancelledError) as error: @@ -256,7 +256,7 @@ async def submit_and_wait( rabbitmq_rpc_client, rpc_namespace=rpc_namespace, job_id=async_job_rpc_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) except Exception as exc: raise exc from error @@ -267,7 +267,7 @@ async def submit_and_wait( rpc_namespace=rpc_namespace, method_name=method_name, job_id=async_job_rpc_get.job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, client_timeout=client_timeout, ): yield wait_and_ diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py index bd205479e5ab..c03be37d3937 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/paths.py @@ -1,14 +1,15 @@ from pathlib import Path from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, ) from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE from models_library.projects_nodes_io import LocationID from models_library.rabbitmq_basic_types import RPCMethodName +from models_library.users import UserID from pydantic import TypeAdapter +from ....celery.models import OwnerMetadata from ..._client_rpc import RabbitMQRPCClient from ..async_jobs.async_jobs import submit @@ -18,17 +19,19 @@ async def compute_path_size( *, location_id: LocationID, path: Path, - job_filter: AsyncJobFilter, -) -> tuple[AsyncJobGet, AsyncJobFilter]: + owner_metadata: OwnerMetadata, + user_id: UserID +) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client=client, rpc_namespace=STORAGE_RPC_NAMESPACE, method_name=TypeAdapter(RPCMethodName).validate_python("compute_path_size"), - job_filter=job_filter, + owner_metadata=owner_metadata, location_id=location_id, path=path, + user_id=user_id, ) - return async_job_rpc_get, job_filter + return async_job_rpc_get, owner_metadata async def delete_paths( @@ -36,14 +39,16 @@ async def delete_paths( *, location_id: LocationID, paths: set[Path], - job_filter: AsyncJobFilter, -) -> tuple[AsyncJobGet, AsyncJobFilter]: + owner_metadata: OwnerMetadata, + user_id: UserID +) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client=client, rpc_namespace=STORAGE_RPC_NAMESPACE, method_name=TypeAdapter(RPCMethodName).validate_python("delete_paths"), - job_filter=job_filter, + owner_metadata=owner_metadata, location_id=location_id, paths=paths, + user_id=user_id, ) - return async_job_rpc_get, job_filter + return async_job_rpc_get, owner_metadata diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py index 463b535667f0..31ca1d11440c 100644 --- 
a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py @@ -1,32 +1,38 @@ from typing import Literal from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, ) from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE from models_library.api_schemas_storage.storage_schemas import FoldersBody from models_library.api_schemas_webserver.storage import PathToExport from models_library.rabbitmq_basic_types import RPCMethodName +from models_library.users import UserID from pydantic import TypeAdapter +from servicelib.celery.models import OwnerMetadata from ... import RabbitMQRPCClient from ..async_jobs.async_jobs import submit async def copy_folders_from_project( - client: RabbitMQRPCClient, *, body: FoldersBody, job_filter: AsyncJobFilter -) -> tuple[AsyncJobGet, AsyncJobFilter]: + client: RabbitMQRPCClient, + *, + body: FoldersBody, + owner_metadata: OwnerMetadata, + user_id: UserID +) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client=client, rpc_namespace=STORAGE_RPC_NAMESPACE, method_name=TypeAdapter(RPCMethodName).validate_python( "copy_folders_from_project" ), - job_filter=job_filter, + owner_metadata=owner_metadata, body=body, + user_id=user_id, ) - return async_job_rpc_get, job_filter + return async_job_rpc_get, owner_metadata async def start_export_data( @@ -34,14 +40,16 @@ async def start_export_data( *, paths_to_export: list[PathToExport], export_as: Literal["path", "download_link"], - job_filter: AsyncJobFilter -) -> tuple[AsyncJobGet, AsyncJobFilter]: + owner_metadata: OwnerMetadata, + user_id: UserID +) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, method_name=TypeAdapter(RPCMethodName).validate_python("start_export_data"), - job_filter=job_filter, + owner_metadata=owner_metadata, paths_to_export=paths_to_export, export_as=export_as, + user_id=user_id, ) - return async_job_rpc_get, job_filter + return async_job_rpc_get, owner_metadata diff --git a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py index 9105283f20ee..51874400b907 100644 --- a/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py +++ b/packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py @@ -8,16 +8,18 @@ from common_library.async_tools import cancel_wait_task from faker import Faker from models_library.api_schemas_rpc_async_jobs.async_jobs import ( - AsyncJobFilter, AsyncJobGet, AsyncJobId, AsyncJobResult, AsyncJobStatus, ) from models_library.api_schemas_rpc_async_jobs.exceptions import JobMissingError +from models_library.products import ProductName from models_library.progress_bar import ProgressReport from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace +from models_library.users import UserID from pydantic import TypeAdapter +from servicelib.celery.models import OwnerMetadata from servicelib.rabbitmq import RabbitMQRPCClient, RemoteMethodNotRegisteredError from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import ( list_jobs, @@ -32,17 +34,23 @@ _ASYNC_JOB_CLIENT_NAME: Final[str] = "pytest_client_name" +class _TestOwnerMetadata(OwnerMetadata): + user_id: UserID + product_name: ProductName + owner: str = 
_ASYNC_JOB_CLIENT_NAME + + @pytest.fixture def method_name(faker: Faker) -> RPCMethodName: return TypeAdapter(RPCMethodName).validate_python(faker.word()) @pytest.fixture -def job_filter(faker: Faker) -> AsyncJobFilter: - return AsyncJobFilter( +def owner_metadata(faker: Faker) -> OwnerMetadata: + return _TestOwnerMetadata( user_id=faker.pyint(min_value=1), product_name=faker.word(), - client_name=_ASYNC_JOB_CLIENT_NAME, + owner=_ASYNC_JOB_CLIENT_NAME, ) @@ -72,9 +80,9 @@ def _get_task(self, job_id: AsyncJobId) -> asyncio.Task: raise JobMissingError(job_id=f"{job_id}") async def status( - self, job_id: AsyncJobId, job_filter: AsyncJobFilter + self, job_id: AsyncJobId, owner_metadata: OwnerMetadata ) -> AsyncJobStatus: - assert job_filter + assert owner_metadata task = self._get_task(job_id) return AsyncJobStatus( job_id=job_id, @@ -82,28 +90,30 @@ async def status( done=task.done(), ) - async def cancel(self, job_id: AsyncJobId, job_filter: AsyncJobFilter) -> None: + async def cancel( + self, job_id: AsyncJobId, owner_metadata: OwnerMetadata + ) -> None: assert job_id - assert job_filter + assert owner_metadata task = self._get_task(job_id) task.cancel() async def result( - self, job_id: AsyncJobId, job_filter: AsyncJobFilter + self, job_id: AsyncJobId, owner_metadata: OwnerMetadata ) -> AsyncJobResult: - assert job_filter + assert owner_metadata task = self._get_task(job_id) assert task.done() return AsyncJobResult( result={ "data": task.result(), "job_id": job_id, - "job_filter": job_filter, + "owner_metadata": owner_metadata, } ) - async def list_jobs(self, job_filter: AsyncJobFilter) -> list[AsyncJobGet]: - assert job_filter + async def list_jobs(self, owner_metadata: OwnerMetadata) -> list[AsyncJobGet]: + assert owner_metadata return [ AsyncJobGet( @@ -113,8 +123,8 @@ async def list_jobs(self, job_filter: AsyncJobFilter) -> list[AsyncJobGet]: for t in self.tasks ] - async def submit(self, job_filter: AsyncJobFilter) -> AsyncJobGet: - assert job_filter + async def submit(self, owner_metadata: OwnerMetadata) -> AsyncJobGet: + assert owner_metadata job_id = faker.uuid4(cast_to=None) self.tasks.append(asyncio.create_task(_slow_task(), name=f"{job_id}")) return AsyncJobGet(job_id=job_id, job_name="fake_job_name") @@ -144,7 +154,7 @@ async def test_async_jobs_methods( async_job_rpc_server: RabbitMQRPCClient, rpc_client: RabbitMQRPCClient, namespace: RPCNamespace, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, job_id: AsyncJobId, method: str, ): @@ -156,7 +166,7 @@ async def test_async_jobs_methods( rpc_client, rpc_namespace=namespace, job_id=job_id, - job_filter=job_filter, + owner_metadata=owner_metadata, ) @@ -165,12 +175,12 @@ async def test_list_jobs( rpc_client: RabbitMQRPCClient, namespace: RPCNamespace, method_name: RPCMethodName, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ): await list_jobs( rpc_client, rpc_namespace=namespace, - job_filter=job_filter, + owner_metadata=owner_metadata, ) @@ -179,13 +189,13 @@ async def test_submit( rpc_client: RabbitMQRPCClient, namespace: RPCNamespace, method_name: RPCMethodName, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ): await submit( rpc_client, rpc_namespace=namespace, method_name=method_name, - job_filter=job_filter, + owner_metadata=owner_metadata, ) @@ -193,14 +203,14 @@ async def test_submit_with_invalid_method_name( async_job_rpc_server: RabbitMQRPCClient, rpc_client: RabbitMQRPCClient, namespace: RPCNamespace, - job_filter: AsyncJobFilter, + owner_metadata: OwnerMetadata, ): with 
pytest.raises(RemoteMethodNotRegisteredError):
         await submit(
             rpc_client,
             rpc_namespace=namespace,
             method_name=RPCMethodName("invalid_method_name"),
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
         )
@@ -209,14 +219,14 @@ async def test_submit_and_wait_properly_timesout(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
 ):
     with pytest.raises(TimeoutError):  # noqa: PT012
         async for _job_composed_result in submit_and_wait(
             rpc_client,
             rpc_namespace=namespace,
             method_name=method_name,
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
             client_timeout=datetime.timedelta(seconds=0.1),
         ):
             pass
@@ -227,13 +237,13 @@ async def test_submit_and_wait(
     rpc_client: RabbitMQRPCClient,
     namespace: RPCNamespace,
     method_name: RPCMethodName,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
 ):
     async for job_composed_result in submit_and_wait(
         rpc_client,
         rpc_namespace=namespace,
         method_name=method_name,
-        job_filter=job_filter,
+        owner_metadata=owner_metadata,
         client_timeout=datetime.timedelta(seconds=10),
     ):
         if not job_composed_result.done:
@@ -246,6 +256,6 @@ async def test_submit_and_wait(
             result={
                 "data": None,
                 "job_id": job_composed_result.status.job_id,
-                "job_filter": job_filter,
+                "owner_metadata": owner_metadata,
             }
         )
diff --git a/packages/service-library/tests/test_celery.py b/packages/service-library/tests/test_celery.py
index d9ef42f0e2e5..670805d1a2ef 100644
--- a/packages/service-library/tests/test_celery.py
+++ b/packages/service-library/tests/test_celery.py
@@ -1,69 +1,108 @@
+from types import NoneType
+from typing import Annotated
+
 # pylint: disable=redefined-outer-name
 # pylint: disable=protected-access
 
 import pydantic
 import pytest
+from common_library.json_serialization import json_dumps
 from faker import Faker
-from pydantic import BaseModel
-from servicelib.celery.models import TaskFilter, TaskUUID
+from pydantic import StringConstraints
+from servicelib.celery.models import (
+    OwnerMetadata,
+    TaskUUID,
+    Wildcard,
+)
 
 _faker = Faker()
 
 
+class _TestOwnerMetadata(OwnerMetadata):
+    string_: str
+    int_: int
+    bool_: bool
+    none_: None
+    uuid_: str
+
+
 @pytest.fixture
-def task_filter_data() -> dict[str, str | int | bool | None | list[str]]:
-    return {
-        "string": _faker.word(),
-        "int": _faker.random_int(),
-        "bool": _faker.boolean(),
-        "none": None,
-        "uuid": _faker.uuid4(),
-        "list": [_faker.word() for _ in range(3)],
+def test_owner_metadata() -> dict[str, str | int | bool | None | list[str]]:
+    data = {
+        "string_": _faker.word(),
+        "int_": _faker.random_int(),
+        "bool_": _faker.boolean(),
+        "none_": None,
+        "uuid_": _faker.uuid4(),
+        "owner": _faker.word().lower(),
     }
+    _TestOwnerMetadata.model_validate(data)  # ensure it's valid
+    return data
 
 
 async def test_task_filter_serialization(
-    task_filter_data: dict[str, str | int | bool | None | list[str]],
+    test_owner_metadata: dict[str, str | int | bool | None | list[str]],
 ):
-    task_filter = TaskFilter.model_validate(task_filter_data)
-    assert task_filter.model_dump() == task_filter_data
-    assert task_filter.model_dump() == task_filter_data
+    task_filter = _TestOwnerMetadata.model_validate(test_owner_metadata)
+    assert task_filter.model_dump() == test_owner_metadata
 
 
 async def test_task_filter_sorting_key_not_serialized():
-    keys = ["a", "b"]
-    task_filter = TaskFilter.model_validate(
-        {
-            "a": _faker.random_int(),
-            "b": _faker.word(),
-        }
+    class _OwnerMetadata(OwnerMetadata):
+        a: int | Wildcard
+        b: str | Wildcard
+
+    owner_metadata = _OwnerMetadata.model_validate(
+        {"a": _faker.random_int(), "b": _faker.word(), "owner": _faker.word().lower()}
+    )
+    task_uuid = TaskUUID(_faker.uuid4())
+    copy_owner_metadata = owner_metadata.model_dump()
+    copy_owner_metadata.update({"task_uuid": f"{task_uuid}"})
+
+    expected_key = ":".join(
+        [f"{k}={json_dumps(v)}" for k, v in sorted(copy_owner_metadata.items())]
     )
-    expected_key = ":".join([f"{k}={getattr(task_filter, k)}" for k in sorted(keys)])
-    assert task_filter._build_task_id_prefix() == expected_key
+    assert owner_metadata.model_dump_task_id(task_uuid=task_uuid) == expected_key
 
 
 async def test_task_filter_task_uuid(
-    task_filter_data: dict[str, str | int | bool | None | list[str]],
+    test_owner_metadata: dict[str, str | int | bool | None | list[str]],
 ):
-    task_filter = TaskFilter.model_validate(task_filter_data)
+    task_filter = _TestOwnerMetadata.model_validate(test_owner_metadata)
     task_uuid = TaskUUID(_faker.uuid4())
-    task_id = task_filter.create_task_id(task_uuid)
-    assert TaskFilter.get_task_uuid(task_id=task_id) == task_uuid
-
-
-async def test_create_task_filter_from_task_id():
-
-    class MyModel(BaseModel):
-        _int: int
-        _bool: bool
-        _str: str
-        _list: list[str]
-
-    mymodel = MyModel(_int=1, _bool=True, _str="test", _list=["a", "b"])
-    task_filter = TaskFilter.model_validate(mymodel.model_dump())
+    task_id = task_filter.model_dump_task_id(task_uuid)
+    assert OwnerMetadata.get_task_uuid(task_id=task_id) == task_uuid
+
+
+async def test_owner_metadata_task_id_dump_and_validate():
+
+    class MyModel(OwnerMetadata):
+        int_: int
+        bool_: bool
+        str_: str
+        float_: float
+        none_: NoneType
+        list_s: list[str]
+        list_i: list[int]
+        list_f: list[float]
+        list_b: list[bool]
+
+    mymodel = MyModel(
+        int_=1,
+        none_=None,
+        bool_=True,
+        str_="test",
+        float_=1.0,
+        owner="myowner",
+        list_b=[True, False],
+        list_f=[1.0, 2.0],
+        list_i=[1, 2],
+        list_s=["a", "b"],
+    )
     task_uuid = TaskUUID(_faker.uuid4())
-    task_id = task_filter.create_task_id(task_uuid)
-    assert TaskFilter.recreate_as_model(task_id=task_id, schema=MyModel) == mymodel
+    task_id = mymodel.model_dump_task_id(task_uuid)
+    mymodel_recreated = MyModel.model_validate_task_id(task_id=task_id)
+    assert mymodel_recreated == mymodel
@@ -79,4 +118,33 @@ class MyModel(BaseModel):
 )
 def test_task_filter_validator_raises_on_forbidden_chars(bad_data):
     with pytest.raises(pydantic.ValidationError):
-        TaskFilter.model_validate(bad_data)
+        OwnerMetadata.model_validate(bad_data)
+
+
+async def test_task_owner():
+    class MyOwnerMetadata(OwnerMetadata):
+        extra_field: str
+
+    with pytest.raises(pydantic.ValidationError):
+        MyOwnerMetadata(owner="", extra_field="value")
+
+    with pytest.raises(pydantic.ValidationError):
+        MyOwnerMetadata(owner="UPPER_CASE", extra_field="value")
+
+    class MyNextFilter(OwnerMetadata):
+        owner: Annotated[
+            str, StringConstraints(strip_whitespace=True, pattern=r"^the_task_owner$")
+        ]
+
+    with pytest.raises(pydantic.ValidationError):
+        MyNextFilter(owner="wrong_owner")
+
+
+def test_owner_metadata_serialize_deserialize(test_owner_metadata):
+    test_owner_metadata = _TestOwnerMetadata.model_validate(test_owner_metadata)
+    data = test_owner_metadata.model_dump()
+    deserialized_data = OwnerMetadata.model_validate(data)
+    assert len(_TestOwnerMetadata.model_fields) > len(
+        OwnerMetadata.model_fields
+    )  # ensure extra data is available in _TestOwnerMetadata -> needed for RPC
+    assert deserialized_data.model_dump() == data
diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py
index fe05d663d959..7b35041aeb67 100644
--- a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py
+++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py
@@ -32,7 +32,7 @@
 from models_library.rest_pagination import PageMetaInfoLimitOffset, PageOffsetInt
 from models_library.rpc_pagination import PageLimitInt
 from models_library.users import UserID
-from servicelib.celery.models import TaskMetadata, TasksQueue, TaskUUID
+from servicelib.celery.models import ExecutionMetadata, TasksQueue, TaskUUID
 from servicelib.celery.task_manager import TaskManager
 from simcore_service_api_server.models.schemas.functions import (
     FunctionJobCreationTaskStatus,
@@ -47,7 +47,7 @@
     FunctionJobCacheNotFoundError,
 )
 from .models.api_resources import JobLinks
-from .models.domain.celery_models import ApiWorkerTaskFilter
+from .models.domain.celery_models import ApiServerOwnerMetadata
 from .models.schemas.functions import FunctionJobCreationTaskStatus
 from .models.schemas.jobs import JobInputs, JobPricingSpecification
 from .services_http.webserver import AuthSession
@@ -79,13 +79,13 @@ async def _celery_task_status(
 ) -> FunctionJobCreationTaskStatus:
     if job_creation_task_id is None:
         return FunctionJobCreationTaskStatus.NOT_YET_SCHEDULED
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_id,
         product_name=product_name,
     )
     try:
         task_status = await task_manager.get_task_status(
-            task_uuid=TaskUUID(job_creation_task_id), task_filter=task_filter
+            task_uuid=TaskUUID(job_creation_task_id), owner_metadata=owner_metadata
         )
         return FunctionJobCreationTaskStatus[task_status.task_state]
     except TaskNotFoundError as err:
@@ -96,7 +96,7 @@ async def _celery_task_status(
             error=err,
             error_context={
                 "task_uuid": TaskUUID(job_creation_task_id),
-                "task_filter": task_filter,
+                "owner_metadata": owner_metadata,
                 "user_id": user_id,
                 "product_name": product_name,
             },
@@ -379,17 +379,18 @@ async def create_function_job_creation_task(
         )
 
         # run function in celery task
-        task_filter = ApiWorkerTaskFilter(
+
+        owner_metadata = ApiServerOwnerMetadata(
             user_id=user_identity.user_id, product_name=user_identity.product_name
         )
 
         task_uuid = await self._celery_task_manager.submit_task(
-            TaskMetadata(
+            ExecutionMetadata(
                 name="run_function",
                 ephemeral=False,
                 queue=TasksQueue.API_WORKER_QUEUE,
             ),
-            task_filter=task_filter,
+            owner_metadata=owner_metadata,
             user_identity=user_identity,
             function=function,
             pre_registered_function_job_data=pre_registered_function_job_data,
diff --git a/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
index 86c37a0a7af6..5e6a05a48193 100644
--- a/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
+++ b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py
@@ -1,8 +1,9 @@
 from celery_library.task_manager import CeleryTaskManager
 from fastapi import FastAPI
+from servicelib.celery.task_manager import TaskManager
 
 
-def get_task_manager(app: FastAPI) -> CeleryTaskManager:
+def get_task_manager(app: FastAPI) -> TaskManager:
     assert hasattr(app.state, "task_manager")  # nosec
     task_manager = app.state.task_manager
     assert isinstance(task_manager, CeleryTaskManager)  # nosec
diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
index 42199f8d4308..3ba23a481b16 100644
--- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
+++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py
@@ -17,8 +17,10 @@
 from models_library.users import UserID
 from servicelib.celery.models import TaskState, TaskUUID
 from servicelib.fastapi.dependencies import get_app
-from simcore_service_api_server.models.domain.celery_models import ApiWorkerTaskFilter
 
+from ...models.domain.celery_models import (
+    ApiServerOwnerMetadata,
+)
 from ...models.schemas.base import ApiServerEnvelope
 from ...models.schemas.errors import ErrorGet
 from ..dependencies.authentication import get_current_user_id, get_product_name
@@ -58,12 +60,12 @@ async def list_tasks(
     product_name: Annotated[ProductName, Depends(get_product_name)],
 ):
     task_manager = get_task_manager(app)
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_id,
         product_name=product_name,
     )
     tasks = await task_manager.list_tasks(
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
     )
 
     app_router = app.router
@@ -103,12 +105,12 @@ async def get_task_status(
     product_name: Annotated[ProductName, Depends(get_product_name)],
 ):
     task_manager = get_task_manager(app)
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_id,
         product_name=product_name,
     )
     task_status = await task_manager.get_task_status(
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
         task_uuid=TaskUUID(f"{task_id}"),
     )
 
@@ -141,12 +143,12 @@ async def cancel_task(
     product_name: Annotated[ProductName, Depends(get_product_name)],
 ):
     task_manager = get_task_manager(app)
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_id,
         product_name=product_name,
     )
     await task_manager.cancel_task(
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
         task_uuid=TaskUUID(f"{task_id}"),
     )
 
@@ -176,13 +178,13 @@ async def get_task_result(
     product_name: Annotated[ProductName, Depends(get_product_name)],
 ):
     task_manager = get_task_manager(app)
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_id,
         product_name=product_name,
    )
 
     task_status = await task_manager.get_task_status(
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
         task_uuid=TaskUUID(f"{task_id}"),
     )
 
@@ -193,7 +195,7 @@
     )
 
     task_result = await task_manager.get_task_result(
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
         task_uuid=TaskUUID(f"{task_id}"),
     )
diff --git a/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py b/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py
index b5ed948b848f..520f97949ec0 100644
--- a/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py
+++ b/services/api-server/src/simcore_service_api_server/models/domain/celery_models.py
@@ -10,7 +10,7 @@
 from models_library.products import ProductName
 from models_library.users import UserID
 from pydantic import Field, StringConstraints
-from servicelib.celery.models import TaskFilter
+from servicelib.celery.models import OwnerMetadata
 
 from ..._meta import APP_NAME
 from ...api.dependencies.authentication import Identity
@@ -33,9 +33,9 @@
 )
 
 
-class ApiWorkerTaskFilter(TaskFilter):
+class ApiServerOwnerMetadata(OwnerMetadata):
     user_id: UserID
     product_name: ProductName
-    task_owner: Annotated[
+    owner: Annotated[
         str, StringConstraints(pattern=rf"^{APP_NAME}$"), Field(frozen=True)
     ] = APP_NAME
diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
index 129a83577992..0c1cb911e2fd 100644
--- a/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
+++ b/services/api-server/src/simcore_service_api_server/services_rpc/async_jobs.py
@@ -2,7 +2,6 @@
 from dataclasses import dataclass
 
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobGet,
     AsyncJobId,
     AsyncJobResult,
@@ -15,6 +14,7 @@
     JobSchedulerError,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
+from servicelib.celery.models import OwnerMetadata
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
 from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs
 
@@ -40,12 +40,14 @@ class AsyncJobClient:
             JobSchedulerError: TaskSchedulerError,
         }
     )
-    async def cancel(self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter) -> None:
+    async def cancel(
+        self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata
+    ) -> None:
         return await async_jobs.cancel(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
         )
 
     @_exception_mapper(
@@ -54,13 +56,13 @@ async def cancel(self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter) -> Non
         }
     )
     async def status(
-        self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter
+        self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata
     ) -> AsyncJobStatus:
         return await async_jobs.status(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
         )
 
     @_exception_mapper(
@@ -72,13 +74,13 @@ async def status(
         }
     )
     async def result(
-        self, *, job_id: AsyncJobId, job_filter: AsyncJobFilter
+        self, *, job_id: AsyncJobId, owner_metadata: OwnerMetadata
     ) -> AsyncJobResult:
         return await async_jobs.result(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             job_id=job_id,
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
        )
 
     @_exception_mapper(
@@ -86,9 +88,9 @@
             JobSchedulerError: TaskSchedulerError,
         }
     )
-    async def list_jobs(self, *, job_filter: AsyncJobFilter) -> list[AsyncJobGet]:
+    async def list_jobs(self, *, owner_metadata: OwnerMetadata) -> list[AsyncJobGet]:
         return await async_jobs.list_jobs(
             self._rabbitmq_rpc_client,
             rpc_namespace=STORAGE_RPC_NAMESPACE,
-            job_filter=job_filter,
+            owner_metadata=owner_metadata,
         )
diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py
index 38b7abaa3b4a..94a82fd7c369 100644
--- a/services/api-server/src/simcore_service_api_server/services_rpc/storage.py
+++ b/services/api-server/src/simcore_service_api_server/services_rpc/storage.py
@@ -2,27 +2,23 @@
 from functools import partial
 
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobGet,
 )
 from models_library.api_schemas_webserver.storage import PathToExport
 from models_library.products import ProductName
 from models_library.users import UserID
+from servicelib.celery.models import OwnerMetadata
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
 from servicelib.rabbitmq.rpc_interfaces.storage import simcore_s3 as storage_rpc
 
-from .._meta import APP_NAME
 from ..exceptions.service_errors_utils import service_exception_mapper
+from ..models.domain.celery_models import (
+    ApiServerOwnerMetadata,
+)
 
 _exception_mapper = partial(service_exception_mapper, service_name="Storage")
 
 
-def get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
-    return AsyncJobFilter(
-        user_id=user_id, product_name=product_name, client_name=APP_NAME
-    )
-
-
 @dataclass(frozen=True, kw_only=True)
 class StorageService:
     _rpc_client: RabbitMQRPCClient
@@ -38,9 +34,11 @@ async def start_data_export(
             self._rpc_client,
             paths_to_export=paths_to_export,
             export_as="download_link",
-            job_filter=get_job_filter(
-                user_id=self._user_id,
-                product_name=self._product_name,
+            owner_metadata=OwnerMetadata.model_validate(
+                ApiServerOwnerMetadata(
+                    user_id=self._user_id, product_name=self._product_name
+                ).model_dump()
             ),
+            user_id=self._user_id,
         )
         return async_job_get
diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py
index f8e782e0ac5f..2e105cd9c131 100644
--- a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py
+++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py
@@ -41,7 +41,7 @@
 from models_library.users import UserID
 from pytest_mock import MockerFixture, MockType
 from pytest_simcore.helpers.httpx_calls_capture_models import HttpApiCallCaptureModel
-from servicelib.celery.models import TaskID, TaskMetadata, TasksQueue
+from servicelib.celery.models import ExecutionMetadata, TaskID, TasksQueue
 from servicelib.common_headers import (
     X_SIMCORE_PARENT_NODE_ID,
     X_SIMCORE_PARENT_PROJECT_UUID,
@@ -56,7 +56,9 @@
 )
 from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError
 from simcore_service_api_server.models.api_resources import JobLinks
-from simcore_service_api_server.models.domain.celery_models import ApiWorkerTaskFilter
+from simcore_service_api_server.models.domain.celery_models import (
+    ApiServerOwnerMetadata,
+)
 from simcore_service_api_server.models.domain.functions import (
     PreRegisteredFunctionJobData,
 )
@@ -280,16 +282,16 @@ async def test_celery_error_propagation(
     with_api_server_celery_worker: TestWorkController,
 ):
-    task_filter = ApiWorkerTaskFilter(
+    owner_metadata = ApiServerOwnerMetadata(
         user_id=user_identity.user_id,
         product_name=user_identity.product_name,
     )
 
     task_manager = get_task_manager(app=app)
     task_uuid = await task_manager.submit_task(
-        task_metadata=TaskMetadata(
+        execution_metadata=ExecutionMetadata(
             name="exception_task", queue=TasksQueue.API_WORKER_QUEUE
         ),
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
     )
 
     with pytest.raises(HTTPStatusError) as exc_info:
diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py
index be67fd38fce8..5d25b8d8a407 100644
--- a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py
+++ b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py
@@ -33,7 +33,7 @@
 from models_library.users import UserID
 from models_library.utils.json_schema import GenerateResolvedJsonSchema
 from pytest_mock import MockerFixture, MockType
-from servicelib.celery.models import TaskFilter, TaskState, TaskStatus, TaskUUID
+from servicelib.celery.models import OwnerMetadata, TaskState, TaskStatus, TaskUUID
 from simcore_service_api_server._meta import API_VTAG
 from simcore_service_api_server._service_function_jobs_task_client import (
     FunctionJobTaskClientService,
@@ -294,7 +294,7 @@ async def test_get_function_job_status(
 
     def _mock_task_manager(*args, **kwargs) -> CeleryTaskManager:
         async def _get_task_status(
-            task_uuid: TaskUUID, task_filter: TaskFilter
+            task_uuid: TaskUUID, owner_metadata: OwnerMetadata
         ) -> TaskStatus:
             assert f"{task_uuid}" == job_creation_task_id
             return TaskStatus(
diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py
index 29a9da86568a..66de98ec2d5f 100644
--- a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py
+++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py
@@ -9,6 +9,7 @@
     FoldersBody,
     PresignedLink,
 )
+from servicelib.celery.models import OwnerMetadata
 from servicelib.logging_utils import log_context
 
 from ...models import FileMetaData
@@ -26,7 +27,11 @@
 def setup_worker_tasks(app: Celery) -> None:
     register_celery_types()
     register_pydantic_types(
-        FileUploadCompletionBody, FileMetaData, FoldersBody, PresignedLink
+        FileUploadCompletionBody,
+        FileMetaData,
+        FoldersBody,
+        PresignedLink,
+        OwnerMetadata,
     )
 
     with log_context(_logger, logging.INFO, msg="worker task registration"):
diff --git a/services/storage/src/simcore_service_storage/api/rest/_files.py b/services/storage/src/simcore_service_storage/api/rest/_files.py
index d7530f514366..8bb3e8ff7562 100644
--- a/services/storage/src/simcore_service_storage/api/rest/_files.py
+++ b/services/storage/src/simcore_service_storage/api/rest/_files.py
@@ -20,7 +20,7 @@
 from models_library.users import UserID
 from pydantic import AnyUrl, ByteSize, TypeAdapter
 from servicelib.aiohttp import status
-from servicelib.celery.models import TaskFilter, TaskMetadata, TaskUUID
+from servicelib.celery.models import ExecutionMetadata, OwnerMetadata, TaskUUID
 from servicelib.celery.task_manager import TaskManager
 from servicelib.logging_utils import log_context
 from yarl import URL
@@ -43,12 +43,12 @@
 from .dependencies.celery import get_task_manager
 
 
-def _get_task_filter(*, user_id: UserID) -> TaskFilter:
+def _get_owner_metadata(*, user_id: UserID) -> OwnerMetadata:
     _data = {
+        "owner": APP_NAME,
         "user_id": user_id,
-        "client_name": APP_NAME,
     }
-    return TaskFilter().model_validate(_data)
+    return OwnerMetadata.model_validate(_data)
 
 
 _logger = logging.getLogger(__name__)
@@ -294,12 +294,13 @@ async def complete_upload_file(
     # NOTE: completing a multipart upload on AWS can take up to several minutes
     # if it returns slow we return a 202 - Accepted, the client will have to check later
    # for completeness
-    task_filter = _get_task_filter(user_id=query_params.user_id)
+
+    owner_metadata = _get_owner_metadata(user_id=query_params.user_id)
     task_uuid = await task_manager.submit_task(
-        TaskMetadata(
+        ExecutionMetadata(
             name=remote_complete_upload_file.__name__,
         ),
-        task_filter=task_filter,
+        owner_metadata=owner_metadata,
         user_id=query_params.user_id,
         location_id=location_id,
         file_id=file_id,
@@ -347,15 +348,15 @@ async def is_completed_upload_file(
     # therefore we wait a bit to see if it completes fast and return a 204
     # if it returns slow we return a 202 - Accepted, the client will have to check later
     # for completeness
-    task_filter = _get_task_filter(user_id=query_params.user_id)
+    owner_metadata = _get_owner_metadata(user_id=query_params.user_id)
     task_status = await task_manager.get_task_status(
-        task_filter=task_filter, task_uuid=TaskUUID(future_id)
+        owner_metadata=owner_metadata, task_uuid=TaskUUID(future_id)
     )
     # first check if the task is in the app
     if task_status.is_done:
         task_result = TypeAdapter(FileMetaData).validate_python(
             await task_manager.get_task_result(
-                task_filter=task_filter,
+                owner_metadata=owner_metadata,
                 task_uuid=TaskUUID(future_id),
             )
         )
diff --git a/services/storage/src/simcore_service_storage/api/rpc/_paths.py b/services/storage/src/simcore_service_storage/api/rpc/_paths.py
index b34da2e7e7f8..31aaa1fdc67e 100644
--- a/services/storage/src/simcore_service_storage/api/rpc/_paths.py
+++ b/services/storage/src/simcore_service_storage/api/rpc/_paths.py
@@ -2,11 +2,11 @@
 from pathlib import Path
 
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobGet,
 )
 from models_library.projects_nodes_io import LocationID
-from servicelib.celery.models import TaskFilter, TaskMetadata
+from models_library.users import UserID
+from servicelib.celery.models import ExecutionMetadata, OwnerMetadata
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RPCRouter
 
@@ -20,18 +20,18 @@
 @router.expose(reraise_if_error_type=None)
 async def compute_path_size(
     task_manager: TaskManager,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
     location_id: LocationID,
     path: Path,
+    user_id: UserID,
 ) -> AsyncJobGet:
     task_name = remote_compute_path_size.__name__
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        task_metadata=TaskMetadata(
+        execution_metadata=ExecutionMetadata(
             name=task_name,
         ),
-        task_filter=task_filter,
-        user_id=job_filter.user_id,
+        owner_metadata=owner_metadata,
+        user_id=user_id,
         location_id=location_id,
         path=path,
     )
@@ -42,18 +42,18 @@ async def compute_path_size(
 @router.expose(reraise_if_error_type=None)
 async def delete_paths(
     task_manager: TaskManager,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
     location_id: LocationID,
     paths: set[Path],
+    user_id: UserID,
 ) -> AsyncJobGet:
     task_name = remote_delete_paths.__name__
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        task_metadata=TaskMetadata(
+        execution_metadata=ExecutionMetadata(
             name=task_name,
         ),
-        task_filter=task_filter,
-        user_id=job_filter.user_id,
+        owner_metadata=owner_metadata,
+        user_id=user_id,
         location_id=location_id,
         paths=paths,
     )
diff --git a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
index 314bad0e00b0..e4b2e4cee76c 100644
--- a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
+++ b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py
@@ -1,12 +1,16 @@
 from typing import Literal
 
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobGet,
 )
 from models_library.api_schemas_storage.storage_schemas import FoldersBody
 from models_library.api_schemas_webserver.storage import PathToExport
-from servicelib.celery.models import TaskFilter, TaskMetadata, TasksQueue
+from models_library.users import UserID
+from servicelib.celery.models import (
+    ExecutionMetadata,
+    OwnerMetadata,
+    TasksQueue,
+)
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RPCRouter
 
@@ -22,17 +26,17 @@
 @router.expose(reraise_if_error_type=None)
 async def copy_folders_from_project(
     task_manager: TaskManager,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
     body: FoldersBody,
+    user_id: UserID,
 ) -> AsyncJobGet:
     task_name = deep_copy_files_from_project.__name__
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        task_metadata=TaskMetadata(
+        execution_metadata=ExecutionMetadata(
             name=task_name,
         ),
-        task_filter=task_filter,
-        user_id=job_filter.user_id,
+        owner_metadata=owner_metadata,
+        user_id=user_id,
         body=body,
     )
 
@@ -42,9 +46,10 @@ async def copy_folders_from_project(
 @router.expose()
 async def start_export_data(
     task_manager: TaskManager,
-    job_filter: AsyncJobFilter,
+    owner_metadata: OwnerMetadata,
     paths_to_export: list[PathToExport],
     export_as: Literal["path", "download_link"],
+    user_id: UserID,
 ) -> AsyncJobGet:
     if export_as == "path":
         task_name = export_data.__name__
@@ -52,15 +57,14 @@ async def start_export_data(
         task_name = export_data_as_download_link.__name__
     else:
         raise ValueError(f"Invalid export_as value: {export_as}")
-    task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
-        task_metadata=TaskMetadata(
+        execution_metadata=ExecutionMetadata(
             name=task_name,
             ephemeral=False,
             queue=TasksQueue.CPU_BOUND,
         ),
-        task_filter=task_filter,
-        user_id=job_filter.user_id,
+        owner_metadata=owner_metadata,
+        user_id=user_id,
         paths_to_export=paths_to_export,
     )
     return AsyncJobGet(job_id=task_uuid, job_name=task_name)
diff --git a/services/storage/tests/unit/test_rpc_handlers_paths.py b/services/storage/tests/unit/test_rpc_handlers_paths.py
index c1acc0719f9b..8ea54e4c614d 100644
--- a/services/storage/tests/unit/test_rpc_handlers_paths.py
+++ b/services/storage/tests/unit/test_rpc_handlers_paths.py
@@ -17,7 +17,6 @@
 from faker import Faker
 from fastapi import FastAPI
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobResult,
 )
 from models_library.api_schemas_storage import STORAGE_RPC_NAMESPACE
@@ -27,6 +26,7 @@
 from models_library.users import UserID
 from pydantic import ByteSize, TypeAdapter
 from pytest_simcore.helpers.storage_utils import FileIDDict, ProjectWithFilesParams
+from servicelib.celery.models import OwnerMetadata, Wildcard
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
 from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
     wait_and_get_result,
@@ -43,6 +43,11 @@
 _IsFile: TypeAlias = bool
 
 
+class TestOwnerMetadata(OwnerMetadata):
+    user_id: int | Wildcard
+    product_name: str | Wildcard
+
+
 def _filter_and_group_paths_one_level_deeper(
     paths: list[Path], prefix: Path
 ) -> list[tuple[Path, _IsFile]]:
@@ -73,17 +78,18 @@ async def _assert_compute_path_size(
         storage_rpc_client,
         location_id=location_id,
         path=path,
-        job_filter=AsyncJobFilter(
-            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        owner_metadata=TestOwnerMetadata(
+            user_id=user_id, product_name=product_name, owner="pytest_client_name"
         ),
+        user_id=user_id,
     )
     async for job_composed_result in wait_and_get_result(
         storage_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=RPCMethodName(compute_path_size.__name__),
         job_id=async_job.job_id,
-        job_filter=AsyncJobFilter(
-            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        owner_metadata=TestOwnerMetadata(
+            user_id=user_id, product_name=product_name, owner="pytest_client_name"
        ),
         client_timeout=datetime.timedelta(seconds=120),
     ):
@@ -110,17 +116,18 @@ async def _assert_delete_paths(
         storage_rpc_client,
         location_id=location_id,
         paths=paths,
-        job_filter=AsyncJobFilter(
-            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        owner_metadata=TestOwnerMetadata(
+            user_id=user_id, product_name=product_name, owner="pytest_client_name"
        ),
+        user_id=user_id,
     )
     async for job_composed_result in wait_and_get_result(
         storage_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         method_name=RPCMethodName(compute_path_size.__name__),
         job_id=async_job.job_id,
-        job_filter=AsyncJobFilter(
-            user_id=user_id, product_name=product_name, client_name="PYTEST_CLIENT_NAME"
+        owner_metadata=TestOwnerMetadata(
+            user_id=user_id, product_name=product_name, owner="pytest_client_name"
         ),
         client_timeout=datetime.timedelta(seconds=120),
     ):
diff --git a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
index 59d5a6d5586f..1ca1c5d3c729 100644
--- a/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
+++ b/services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
@@ -25,7 +25,6 @@
 from fastapi import FastAPI
 from fastapi.encoders import jsonable_encoder
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
-    AsyncJobFilter,
     AsyncJobResult,
 )
 from models_library.api_schemas_rpc_async_jobs.exceptions import JobError
@@ -55,6 +54,7 @@
 )
 from pytest_simcore.helpers.storage_utils_project import clone_project_data
 from servicelib.aiohttp import status
+from servicelib.celery.models import OwnerMetadata
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
 from servicelib.rabbitmq._errors import RPCServerError
 from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import wait_and_get_result
@@ -71,6 +71,12 @@
 
 pytest_simcore_ops_services_selection = ["adminer"]
 
+
+class _TestOwnerMetadata(OwnerMetadata):
+    user_id: UserID
+    product_name: ProductName
+    owner: str = "pytest_client_name"
+
 
 async def _request_copy_folders(
     rpc_client: RabbitMQRPCClient,
     user_id: UserID,
@@ -85,16 +91,17 @@ async def _request_copy_folders(
         logging.INFO,
         f"Copying folders from {source_project['uuid']} to {dst_project['uuid']}",
     ) as ctx:
-        async_job_get, async_job_name = await copy_folders_from_project(
+        async_job_get, owner_metadata = await copy_folders_from_project(
             rpc_client,
             body=FoldersBody(
                 source=source_project, destination=dst_project, nodes_map=nodes_map
             ),
-            job_filter=AsyncJobFilter(
+            owner_metadata=_TestOwnerMetadata(
                 user_id=user_id,
                 product_name=product_name,
-                client_name="PYTEST_CLIENT_NAME",
+                owner="pytest_client_name",
             ),
+            user_id=user_id,
         )
 
         async for async_job_result in wait_and_get_result(
@@ -102,7 +109,7 @@
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             method_name=copy_folders_from_project.__name__,
             job_id=async_job_get.job_id,
-            job_filter=async_job_name,
+            owner_metadata=owner_metadata,
             client_timeout=client_timeout,
         ):
             ctx.logger.info("%s", f"<-- current state is {async_job_result=}")
@@ -530,15 +537,16 @@ async def _request_start_export_data(
         logging.INFO,
         f"Data export form {paths_to_export=}",
     ) as ctx:
-        async_job_get, async_job_name = await start_export_data(
+        async_job_get, owner_metadata = await start_export_data(
             rpc_client,
             paths_to_export=paths_to_export,
             export_as=export_as,
-            job_filter=AsyncJobFilter(
+            owner_metadata=_TestOwnerMetadata(
                 user_id=user_id,
                 product_name=product_name,
-                client_name="PYTEST_CLIENT_NAME",
+                owner="pytest_client_name",
             ),
+            user_id=user_id,
        )
 
         async for async_job_result in wait_and_get_result(
@@ -546,7 +554,7 @@
             rpc_namespace=STORAGE_RPC_NAMESPACE,
             method_name=start_export_data.__name__,
             job_id=async_job_get.job_id,
-            job_filter=async_job_name,
+            owner_metadata=owner_metadata,
             client_timeout=client_timeout,
         ):
             ctx.logger.info("%s", f"<-- current state is {async_job_result=}")
diff --git a/services/web/server/src/simcore_service_webserver/models.py b/services/web/server/src/simcore_service_webserver/models.py
index ebdfdad0deda..cf795a6e252e 100644
--- a/services/web/server/src/simcore_service_webserver/models.py
+++ b/services/web/server/src/simcore_service_webserver/models.py
@@ -6,8 +6,10 @@
 from pydantic import ConfigDict, Field, StringConstraints
 from pydantic_extra_types.phone_numbers import PhoneNumberValidator
 from servicelib.aiohttp.request_keys import RQT_USERID_KEY
+from servicelib.celery.models import OwnerMetadata
 from servicelib.rest_constants import X_CLIENT_SESSION_ID_HEADER
 
+from ._meta import APP_NAME
 from .constants import RQ_PRODUCT_KEY
 
 PhoneNumberStr: TypeAlias = Annotated[
@@ -55,3 +57,11 @@ class ClientSessionHeaderParams(RequestParameters):
     model_config = ConfigDict(
         validate_by_name=True,
     )
+
+
+class WebServerOwnerMetadata(OwnerMetadata):
+    user_id: UserID
+    product_name: ProductName
+    owner: Annotated[
+        str, StringConstraints(pattern=rf"^{APP_NAME}$"), Field(frozen=True)
+    ] = APP_NAME
diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py
index 8811ef81550f..e5cffbc94884 100644
--- a/services/web/server/src/simcore_service_webserver/storage/_rest.py
+++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py
@@ -40,6 +40,7 @@
     parse_request_query_parameters_as,
 )
 from servicelib.aiohttp.rest_responses import create_data_response
+from servicelib.celery.models import OwnerMetadata
 from servicelib.common_headers import X_FORWARDED_PROTO
 from servicelib.rabbitmq.rpc_interfaces.storage.paths import (
     compute_path_size as remote_compute_path_size,
@@ -53,11 +54,10 @@
 
 from .._meta import API_VTAG
 from ..login.decorators import login_required
-from ..models import AuthenticatedRequestContext
+from ..models import AuthenticatedRequestContext, WebServerOwnerMetadata
 from ..rabbitmq import get_rabbitmq_rpc_client
 from ..security.decorators import permission_required
 from ..tasks._exception_handlers import handle_export_data_exceptions
-from ..utils import get_job_filter
 from .schemas import StorageFileIDStr
 from .settings import StorageSettings, get_plugin_settings
@@ -211,10 +211,13 @@ async def compute_path_size(request: web.Request) -> web.Response:
         rabbitmq_rpc_client,
         location_id=path_params.location_id,
         path=path_params.path,
-        job_filter=get_job_filter(
-            user_id=req_ctx.user_id,
-            product_name=req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=req_ctx.user_id,
+                product_name=req_ctx.product_name,
+            ).model_dump()
         ),
+        user_id=req_ctx.user_id,
     )
 
     return _create_data_response_from_async_job(request, async_job)
@@ -236,10 +239,13 @@ async def batch_delete_paths(request: web.Request):
         rabbitmq_rpc_client,
         location_id=path_params.location_id,
         paths=body.paths,
-        job_filter=get_job_filter(
-            user_id=req_ctx.user_id,
-            product_name=req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=req_ctx.user_id,
+                product_name=req_ctx.product_name,
+            ).model_dump()
         ),
+        user_id=req_ctx.user_id,
     )
 
     return _create_data_response_from_async_job(request, async_job)
@@ -503,10 +509,13 @@ def allow_only_simcore(cls, v: int) -> int:
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         paths_to_export=export_data_post.paths,
         export_as="path",
-        job_filter=get_job_filter(
-            user_id=_req_ctx.user_id,
-            product_name=_req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=_req_ctx.user_id,
+                product_name=_req_ctx.product_name,
+            ).model_dump()
         ),
+        user_id=_req_ctx.user_id,
     )
     _job_id = f"{async_job_rpc_get.job_id}"
     return create_data_response(
diff --git a/services/web/server/src/simcore_service_webserver/storage/api.py b/services/web/server/src/simcore_service_webserver/storage/api.py
index d68f8bd94eea..34d86c1f4780 100644
--- a/services/web/server/src/simcore_service_webserver/storage/api.py
+++ b/services/web/server/src/simcore_service_webserver/storage/api.py
@@ -23,6 +23,7 @@
 from models_library.users import UserID
 from pydantic import ByteSize, HttpUrl, TypeAdapter
 from servicelib.aiohttp.client_session import get_client_session
+from servicelib.celery.models import OwnerMetadata
 from servicelib.logging_utils import log_context
 from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
     AsyncJobComposedResult,
@@ -30,10 +31,10 @@
 )
 from yarl import URL
 
+from ..models import WebServerOwnerMetadata
 from ..projects.models import ProjectDict
 from ..projects.utils import NodesMap
 from ..rabbitmq import get_rabbitmq_rpc_client
-from ..utils import get_job_filter
 from .settings import StorageSettings, get_plugin_settings
 
 _logger = logging.getLogger(__name__)
@@ -119,9 +120,11 @@ async def copy_data_folders_from_project(
         rabbitmq_client,
         method_name="copy_folders_from_project",
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        job_filter=get_job_filter(
-            user_id=user_id,
-            product_name=product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=user_id,
+                product_name=product_name,
+            ).model_dump()
         ),
         body=TypeAdapter(FoldersBody).validate_python(
             {
@@ -131,6 +134,7 @@ async def copy_data_folders_from_project(
             },
         ),
         client_timeout=datetime.timedelta(seconds=_TOTAL_TIMEOUT_TO_COPY_DATA_SECS),
+        user_id=user_id,
     ):
         yield job_composed_result
diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
index 0ba2b2d65ec5..3ab13730940e 100644
--- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py
+++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py
@@ -27,16 +27,16 @@
     parse_request_path_parameters_as,
 )
 from servicelib.aiohttp.rest_responses import create_data_response
+from servicelib.celery.models import OwnerMetadata
 from servicelib.long_running_tasks import lrt_api
 from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs
 
 from .._meta import API_VTAG
 from ..login.decorators import login_required
 from ..long_running_tasks.plugin import webserver_request_context_decorator
-from ..models import AuthenticatedRequestContext
+from ..models import AuthenticatedRequestContext, WebServerOwnerMetadata
 from ..rabbitmq import get_rabbitmq_rpc_client
 from ..security.decorators import permission_required
-from ..utils import get_job_filter
 from ._exception_handlers import handle_export_data_exceptions
 
 log = logging.getLogger(__name__)
@@ -70,9 +70,11 @@ async def get_async_jobs(request: web.Request) -> web.Response:
     user_async_jobs = await async_jobs.list_jobs(
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
-        job_filter=get_job_filter(
-            user_id=_req_ctx.user_id,
-            product_name=_req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=_req_ctx.user_id,
+                product_name=_req_ctx.product_name,
+            ).model_dump()
         ),
     )
     return create_data_response(
@@ -119,9 +121,11 @@ async def get_async_job_status(request: web.Request) -> web.Response:
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_filter=get_job_filter(
-            user_id=_req_ctx.user_id,
-            product_name=_req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=_req_ctx.user_id,
+                product_name=_req_ctx.product_name,
+            ).model_dump()
         ),
     )
     _task_id = f"{async_job_rpc_status.job_id}"
@@ -155,9 +159,11 @@ async def cancel_async_job(request: web.Request) -> web.Response:
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_filter=get_job_filter(
-            user_id=_req_ctx.user_id,
-            product_name=_req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=_req_ctx.user_id,
+                product_name=_req_ctx.product_name,
+            ).model_dump()
         ),
     )
 
@@ -183,9 +189,11 @@ class _PathParams(BaseModel):
         rabbitmq_rpc_client=rabbitmq_rpc_client,
         rpc_namespace=STORAGE_RPC_NAMESPACE,
         job_id=async_job_get.task_id,
-        job_filter=get_job_filter(
-            user_id=_req_ctx.user_id,
-            product_name=_req_ctx.product_name,
+        owner_metadata=OwnerMetadata.model_validate(
+            WebServerOwnerMetadata(
+                user_id=_req_ctx.user_id,
+                product_name=_req_ctx.product_name,
+            ).model_dump()
         ),
     )
diff --git a/services/web/server/src/simcore_service_webserver/utils.py b/services/web/server/src/simcore_service_webserver/utils.py
index fee35eff5054..8928deb21c60 100644
--- a/services/web/server/src/simcore_service_webserver/utils.py
+++ b/services/web/server/src/simcore_service_webserver/utils.py
@@ -9,15 +9,10 @@
 from datetime import datetime
 
 from common_library.error_codes import ErrorCodeStr
-from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
-from models_library.products import ProductName
-from models_library.users import UserID
 from typing_extensions import (  # https://docs.pydantic.dev/latest/api/standard_library_types/#typeddict
     TypedDict,
 )
 
-from ._meta import APP_NAME
-
 _logger = logging.getLogger(__name__)
 
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
@@ -125,9 +120,3 @@ def compose_support_error_msg(
     )
     return ". ".join(sentences)
-
-
-def get_job_filter(*, user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
-    return AsyncJobFilter(
-        user_id=user_id, product_name=product_name, client_name=APP_NAME
-    )
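
Note on the task-id contract introduced by this change: test_task_filter_sorting_key_not_serialized pins the format down to sorted key=json_dumps(value) pairs joined with ":", with the task_uuid folded in as one more pair, and WILDCARD standing in for the uuid when a Redis match pattern is needed (that is how list_tasks scans for all tasks of one owner). Below is a minimal, self-contained sketch of that contract; SketchOwnerMetadata, its helper bodies, and the WILDCARD value "*" are illustrative assumptions, not the servicelib.celery.models implementation.

import json
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict

WILDCARD = "*"  # assumption: the sentinel exported by servicelib.celery.models


class SketchOwnerMetadata(BaseModel):
    # assumption: extra fields are kept so subclass data survives a round-trip
    # through the base class (see test_owner_metadata_serialize_deserialize)
    model_config = ConfigDict(extra="allow", frozen=True)

    owner: str

    def model_dump_task_id(self, task_uuid: UUID | str) -> str:
        # sorted key=json_dumps(value) pairs, task_uuid included as one more pair
        data = {**self.model_dump(), "task_uuid": f"{task_uuid}"}
        return ":".join(f"{k}={json.dumps(v)}" for k, v in sorted(data.items()))

    @classmethod
    def get_task_uuid(cls, task_id: str) -> UUID:
        # naive parse; assumes no ":" inside the serialized values
        pairs = dict(part.split("=", 1) for part in task_id.split(":"))
        return UUID(json.loads(pairs["task_uuid"]))


meta = SketchOwnerMetadata(owner="api-server", user_id=42)  # user_id kept via extra="allow"
task_uuid = uuid4()
task_id = meta.model_dump_task_id(task_uuid)
assert SketchOwnerMetadata.get_task_uuid(task_id) == task_uuid
assert 'task_uuid="*"' in meta.model_dump_task_id(WILDCARD)  # listing pattern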
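
The per-service subclasses (ApiServerOwnerMetadata, WebServerOwnerMetadata) follow one pattern: add the service's identity fields and pin owner to the service's APP_NAME with a frozen, pattern-constrained field, so a task submitted by one service cannot be claimed under another owner. Continuing the sketch above, with a hypothetical service name standing in for a real _meta.APP_NAME:

from typing import Annotated

from pydantic import Field, StringConstraints

APP_NAME = "simcore-service-example"  # hypothetical; each service uses its own _meta.APP_NAME


class ExampleOwnerMetadata(SketchOwnerMetadata):
    user_id: int
    product_name: str
    owner: Annotated[
        str, StringConstraints(pattern=rf"^{APP_NAME}$"), Field(frozen=True)
    ] = APP_NAME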
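
At the RPC call sites (services_rpc/storage.py in the api-server, storage/_rest.py and tasks/_rest.py in the webserver) the subclass instance is not shipped as-is: it is dumped and re-validated as the base class via OwnerMetadata.model_validate(sub.model_dump()), so the wire type stays OwnerMetadata while the extra identity fields ride along, which is what test_owner_metadata_serialize_deserialize asserts. In the sketch's terms:

pinned = ExampleOwnerMetadata(user_id=42, product_name="osparc")  # illustrative values
as_base = SketchOwnerMetadata.model_validate(pinned.model_dump())
assert as_base.model_dump() == pinned.model_dump()  # extra fields survive (extra="allow")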