diff --git a/api/specs/web-server/_long_running_tasks.py b/api/specs/web-server/_long_running_tasks.py index 6f43395da3f9..18785cb3dfc7 100644 --- a/api/specs/web-server/_long_running_tasks.py +++ b/api/specs/web-server/_long_running_tasks.py @@ -15,6 +15,10 @@ from simcore_service_webserver.tasks._controller._rest_exceptions import ( _TO_HTTP_ERROR_MAP, ) +from simcore_service_webserver.tasks._controller._rest_schemas import ( + TaskStreamQueryParams, + TaskStreamResponse, +) router = APIRouter( prefix=f"/{API_VTAG}", @@ -63,3 +67,14 @@ def get_async_job_result( _path_params: Annotated[_PathParam, Depends()], ): """Retrieves the result of a task""" + + +@router.get( + "/tasks/{task_id}/stream", + response_model=Envelope[TaskStreamResponse], +) +def get_async_job_stream( + _path_params: Annotated[_PathParam, Depends()], + _query_params: Annotated[TaskStreamQueryParams, Depends()], +): + """Retrieves the stream of a task""" diff --git a/api/specs/web-server/_storage.py b/api/specs/web-server/_storage.py index 1252e9420a81..37d1c8f3e00a 100644 --- a/api/specs/web-server/_storage.py +++ b/api/specs/web-server/_storage.py @@ -4,12 +4,10 @@ # pylint: disable=too-many-arguments -from typing import Annotated, TypeAlias +from typing import Annotated, Any, Final, TypeAlias from fastapi import APIRouter, Depends, Query, status -from models_library.api_schemas_long_running_tasks.tasks import ( - TaskGet, -) +from models_library.api_schemas_long_running_tasks.tasks import TaskGet from models_library.api_schemas_storage.storage_schemas import ( FileLocation, FileMetaDataGet, @@ -25,6 +23,7 @@ BatchDeletePathsBodyParams, DataExportPost, ListPathsQueryParams, + SearchBodyParams, StorageLocationPathParams, StoragePathComputeSizeParams, ) @@ -220,14 +219,33 @@ async def is_completed_upload_file( """Returns state of upload completion""" +_RESPONSES: Final[dict[int | str, dict[str, Any]]] = { + i.status_code: {"model": EnvelopedError} for i in _TO_HTTP_ERROR_MAP.values() +} + + @router.post( - "/storage/locations/{location_id}/export-data", + "/storage/locations/{location_id}:export-data", + status_code=status.HTTP_202_ACCEPTED, response_model=Envelope[TaskGet], name="export_data", description="Export data", - responses={ - i.status_code: {"model": EnvelopedError} for i in _TO_HTTP_ERROR_MAP.values() - }, + responses=_RESPONSES, ) async def export_data(export_data: DataExportPost, location_id: LocationID): """Trigger data export. Returns async job id for getting status and results""" + + +@router.post( + "/storage/locations/{location_id}:search", + status_code=status.HTTP_202_ACCEPTED, + response_model=Envelope[TaskGet], + name="search", + description="Starts a files/folders search", + responses=_RESPONSES, +) +async def search( + _path: Annotated[StorageLocationPathParams, Depends()], + _body: SearchBodyParams, +): + """Trigger search. 
Returns async job id for getting status and results""" diff --git a/packages/celery-library/src/celery_library/backends/redis.py b/packages/celery-library/src/celery_library/backends/redis.py index e7a99b85a057..6da1c1d432df 100644 --- a/packages/celery-library/src/celery_library/backends/redis.py +++ b/packages/celery-library/src/celery_library/backends/redis.py @@ -1,7 +1,7 @@ import contextlib import logging from dataclasses import dataclass -from datetime import timedelta +from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING, Final from models_library.progress_bar import ProgressReport @@ -13,21 +13,38 @@ Task, TaskKey, TaskStore, + TaskStreamItem, ) from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types +_CELERY_TASK_DELIMTATOR: Final[str] = ":" + _CELERY_TASK_PREFIX: Final[str] = "celery-task-" _CELERY_TASK_ID_KEY_ENCODING = "utf-8" _CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000 -_CELERY_TASK_METADATA_KEY: Final[str] = "metadata" +_CELERY_TASK_EXEC_METADATA_KEY: Final[str] = "exec-meta" _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress" +### Redis list to store streamed results +_CELERY_TASK_STREAM_PREFIX: Final[str] = "celery-task-stream-" +_CELERY_TASK_STREAM_EXPIRY: Final[timedelta] = timedelta(minutes=3) +_CELERY_TASK_STREAM_METADATA: Final[str] = "meta" +_CELERY_TASK_STREAM_DONE_KEY: Final[str] = "done" +_CELERY_TASK_STREAM_LAST_UPDATE_KEY: Final[str] = "last_update" _logger = logging.getLogger(__name__) def _build_redis_task_key(task_key: TaskKey) -> str: - return _CELERY_TASK_PREFIX + task_key + return f"{_CELERY_TASK_PREFIX}{task_key}" + + +def _build_redis_stream_key(task_key: TaskKey) -> str: + return f"{_CELERY_TASK_STREAM_PREFIX}{task_key}" + + +def _build_redis_stream_meta_key(task_key: TaskKey) -> str: + return f"{_build_redis_stream_key(task_key)}{_CELERY_TASK_DELIMTATOR}{_CELERY_TASK_STREAM_METADATA}" @dataclass(frozen=True) @@ -44,7 +61,7 @@ async def create_task( await handle_redis_returns_union_types( self._redis_client_sdk.redis.hset( name=redis_key, - key=_CELERY_TASK_METADATA_KEY, + key=_CELERY_TASK_EXEC_METADATA_KEY, value=execution_metadata.model_dump_json(), ) ) @@ -57,7 +74,7 @@ async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None raw_result = await handle_redis_returns_union_types( self._redis_client_sdk.redis.hget( _build_redis_task_key(task_key), - _CELERY_TASK_METADATA_KEY, + _CELERY_TASK_EXEC_METADATA_KEY, ) ) if not raw_result: @@ -99,7 +116,7 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: ) keys: list[str] = [] - pipeline = self._redis_client_sdk.redis.pipeline() + pipe = self._redis_client_sdk.redis.pipeline() async for key in self._redis_client_sdk.redis.scan_iter( match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH ): @@ -110,9 +127,9 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: else key ) keys.append(_key) - pipeline.hget(_key, _CELERY_TASK_METADATA_KEY) + pipe.hget(_key, _CELERY_TASK_EXEC_METADATA_KEY) - results = await pipeline.execute() + results = await pipe.execute() tasks = [] for key, raw_metadata in zip(keys, results, strict=True): @@ -153,6 +170,62 @@ async def task_exists(self, task_key: TaskKey) -> bool: assert isinstance(n, int) # nosec return n > 0 + async def push_task_stream_items( + self, task_key: TaskKey, *result: TaskStreamItem + ) -> None: + stream_key = _build_redis_stream_key(task_key) + stream_meta_key = _build_redis_stream_meta_key(task_key) + + pipe = 
self._redis_client_sdk.redis.pipeline() + pipe.rpush(stream_key, *(r.model_dump_json(by_alias=True) for r in result)) + pipe.hset( + stream_meta_key, mapping={"last_update": datetime.now(UTC).isoformat()} + ) + pipe.expire(stream_key, _CELERY_TASK_STREAM_EXPIRY) + pipe.expire(stream_meta_key, _CELERY_TASK_STREAM_EXPIRY) + await pipe.execute() + + async def set_task_stream_done(self, task_key: TaskKey) -> None: + stream_meta_key = _build_redis_stream_meta_key(task_key) + await handle_redis_returns_union_types( + self._redis_client_sdk.redis.hset( + name=stream_meta_key, + key=_CELERY_TASK_STREAM_DONE_KEY, + value="1", + ) + ) + + async def pull_task_stream_items( + self, task_key: TaskKey, limit: int = 20 + ) -> tuple[list[TaskStreamItem], bool, datetime | None]: + stream_key = _build_redis_stream_key(task_key) + meta_key = _build_redis_stream_meta_key(task_key) + + async with self._redis_client_sdk.redis.pipeline(transaction=True) as pipe: + pipe.lpop(stream_key, limit) + pipe.hget(meta_key, _CELERY_TASK_STREAM_DONE_KEY) + pipe.hget(meta_key, _CELERY_TASK_STREAM_LAST_UPDATE_KEY) + raw_items, done, last_update = await pipe.execute() + + stream_items = ( + [TaskStreamItem.model_validate_json(item) for item in raw_items] + if raw_items + else [] + ) + + empty = ( + await handle_redis_returns_union_types( + self._redis_client_sdk.redis.llen(stream_key) + ) + == 0 + ) + + return ( + stream_items, + done == "1" and empty, + datetime.fromisoformat(last_update) if last_update else None, + ) + if TYPE_CHECKING: _: type[TaskStore] = RedisTaskStore diff --git a/packages/celery-library/src/celery_library/rpc/_async_jobs.py b/packages/celery-library/src/celery_library/rpc/_async_jobs.py index 9af35a588d2c..ff3d735d5b68 100644 --- a/packages/celery-library/src/celery_library/rpc/_async_jobs.py +++ b/packages/celery-library/src/celery_library/rpc/_async_jobs.py @@ -2,7 +2,6 @@ import logging -from celery.exceptions import CeleryError # type: ignore[import-untyped] from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, AsyncJobId, @@ -22,6 +21,7 @@ from servicelib.rabbitmq import RPCRouter from ..errors import ( + TaskManagerError, TaskNotFoundError, TransferrableCeleryError, decode_celery_transferrable_error, @@ -44,7 +44,7 @@ async def cancel( ) except TaskNotFoundError as exc: raise JobMissingError(job_id=job_id) from exc - except CeleryError as exc: + except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc @@ -62,7 +62,7 @@ async def status( ) except TaskNotFoundError as exc: raise JobMissingError(job_id=job_id) from exc - except CeleryError as exc: + except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc return AsyncJobStatus( @@ -101,7 +101,7 @@ async def result( ) except TaskNotFoundError as exc: raise JobMissingError(job_id=job_id) from exc - except CeleryError as exc: + except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc if _status.task_state == TaskState.FAILURE: @@ -136,7 +136,7 @@ async def list_jobs( tasks = await task_manager.list_tasks( owner_metadata=owner_metadata, ) - except CeleryError as exc: + except TaskManagerError as exc: raise JobSchedulerError(exc=f"{exc}") from exc return [ diff --git a/packages/celery-library/src/celery_library/task_manager.py b/packages/celery-library/src/celery_library/task_manager.py index ad6e15844cee..63b638a461c3 100644 --- a/packages/celery-library/src/celery_library/task_manager.py +++ b/packages/celery-library/src/celery_library/task_manager.py @@ -1,5 
+1,6 @@ import logging from dataclasses import dataclass +from datetime import datetime from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -16,6 +17,7 @@ TaskState, TaskStatus, TaskStore, + TaskStreamItem, TaskUUID, ) from servicelib.celery.task_manager import TaskManager @@ -35,7 +37,7 @@ class CeleryTaskManager: _celery_app: Celery _celery_settings: CelerySettings - _task_info_store: TaskStore + _task_store: TaskStore @handle_celery_errors async def submit_task( @@ -60,7 +62,7 @@ async def submit_task( ) try: - await self._task_info_store.create_task( + await self._task_store.create_task( task_key, execution_metadata, expiry=expiry ) self._celery_app.send_task( @@ -71,7 +73,7 @@ async def submit_task( ) except CeleryError as exc: try: - await self._task_info_store.remove_task(task_key) + await self._task_store.remove_task(task_key) except CeleryError: _logger.warning( "Unable to cleanup task '%s' during error handling", @@ -101,11 +103,11 @@ async def cancel_task( task_uuid=task_uuid, owner_metadata=owner_metadata ) - await self._task_info_store.remove_task(task_key) + await self._task_store.remove_task(task_key) await self._forget_task(task_key) async def task_exists(self, task_key: TaskKey) -> bool: - return await self._task_info_store.task_exists(task_key) + return await self._task_store.task_exists(task_key) @make_async() def _forget_task(self, task_key: TaskKey) -> None: @@ -129,9 +131,9 @@ async def get_task_result( async_result = self._celery_app.AsyncResult(task_key) result = async_result.result if async_result.ready(): - task_metadata = await self._task_info_store.get_task_metadata(task_key) + task_metadata = await self._task_store.get_task_metadata(task_key) if task_metadata is not None and task_metadata.ephemeral: - await self._task_info_store.remove_task(task_key) + await self._task_store.remove_task(task_key) await self._forget_task(task_key) return result @@ -139,7 +141,7 @@ async def _get_task_progress_report( self, task_key: TaskKey, task_state: TaskState ) -> ProgressReport: if task_state in (TaskState.STARTED, TaskState.RETRY): - progress = await self._task_info_store.get_task_progress(task_key) + progress = await self._task_store.get_task_progress(task_key) if progress is not None: return progress @@ -188,17 +190,62 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]: logging.DEBUG, msg=f"Listing tasks: {owner_metadata=}", ): - return await self._task_info_store.list_tasks(owner_metadata) + return await self._task_store.list_tasks(owner_metadata) @handle_celery_errors async def set_task_progress( self, task_key: TaskKey, report: ProgressReport ) -> None: - await self._task_info_store.set_task_progress( + await self._task_store.set_task_progress( task_key=task_key, report=report, ) + @handle_celery_errors + async def set_task_stream_done(self, task_key: TaskKey) -> None: + with log_context( + _logger, + logging.DEBUG, + msg=f"Set task stream done: {task_key=}", + ): + if not await self.task_exists(task_key): + raise TaskNotFoundError(task_key=task_key) + + await self._task_store.set_task_stream_done(task_key) + + @handle_celery_errors + async def push_task_stream_items( + self, task_key: TaskKey, *items: TaskStreamItem + ) -> None: + with log_context( + _logger, + logging.DEBUG, + msg=f"Push task stream items: {task_key=} {items=}", + ): + if not await self.task_exists(task_key): + raise TaskNotFoundError(task_key=task_key) + + await self._task_store.push_task_stream_items(task_key, *items) + + @handle_celery_errors + async def 
pull_task_stream_items( + self, + owner_metadata: OwnerMetadata, + task_uuid: TaskUUID, + offset: int = 0, + limit: int = 50, + ) -> tuple[list[TaskStreamItem], bool, datetime | None]: + with log_context( + _logger, + logging.DEBUG, + msg=f"Pull task results: {owner_metadata=} {task_uuid=} {offset=} {limit=}", + ): + task_key = owner_metadata.model_dump_task_key(task_uuid=task_uuid) + if not await self.task_exists(task_key): + raise TaskNotFoundError(task_key=task_key) + + return await self._task_store.pull_task_stream_items(task_key, limit) + if TYPE_CHECKING: _: type[TaskManager] = CeleryTaskManager diff --git a/packages/celery-library/tests/unit/test_task_manager.py b/packages/celery-library/tests/unit/test_task_manager.py index 44a4de374e7b..040c0541ed53 100644 --- a/packages/celery-library/tests/unit/test_task_manager.py +++ b/packages/celery-library/tests/unit/test_task_manager.py @@ -25,12 +25,18 @@ OwnerMetadata, TaskKey, TaskState, + TaskStreamItem, TaskUUID, Wildcard, ) from servicelib.celery.task_manager import TaskManager from servicelib.logging_utils import log_context -from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_delay, + wait_fixed, +) _faker = Faker() @@ -39,6 +45,13 @@ pytest_simcore_core_services_selection = ["redis"] pytest_simcore_ops_services_selection = [] +_TENACITY_RETRY_PARAMS = { + "reraise": True, + "retry": retry_if_exception_type(AssertionError), + "stop": stop_after_delay(30), + "wait": wait_fixed(0.1), +} + class MyOwnerMetadata(OwnerMetadata): user_id: int @@ -90,12 +103,40 @@ async def dreamer_task(task: Task, task_key: TaskKey) -> list[int]: return numbers +def streaming_results_task(task: Task, task_key: TaskKey, num_results: int = 5) -> str: + assert task_key + assert task.name + + async def _stream_results(sleep_interval: float) -> None: + app_server = get_app_server(task.app) + for i in range(num_results): + result_data = f"result-{i}" + result_item = TaskStreamItem(data=result_data) + await app_server.task_manager.push_task_stream_items( + task_key, + result_item, + ) + _logger.info("Pushed result %d: %s", i, result_data) + await asyncio.sleep(sleep_interval) + + # Mark the stream as done + await app_server.task_manager.set_task_stream_done(task_key) + + # Run the streaming in the event loop + asyncio.run_coroutine_threadsafe( + _stream_results(0.5), get_app_server(task.app).event_loop + ).result() + + return f"completed-{num_results}-results" + + @pytest.fixture def register_celery_tasks() -> Callable[[Celery], None]: def _(celery_app: Celery) -> None: register_task(celery_app, fake_file_processor) register_task(celery_app, failure_task) register_task(celery_app, dreamer_task) + register_task(celery_app, streaming_results_task) return _ @@ -115,11 +156,7 @@ async def test_submitting_task_calling_async_function_results_with_success_state files=[f"file{n}" for n in range(5)], ) - for attempt in Retrying( - retry=retry_if_exception_type(AssertionError), - wait=wait_fixed(1), - stop=stop_after_delay(30), - ): + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: status = await task_manager.get_task_status(owner_metadata, task_uuid) assert status.task_state == TaskState.SUCCESS @@ -146,12 +183,7 @@ async def test_submitting_task_with_failure_results_with_error( owner_metadata=owner_metadata, ) - for attempt in Retrying( - retry=retry_if_exception_type(AssertionError), - wait=wait_fixed(1), - 
stop=stop_after_delay(30), - ): - + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: raw_result = await task_manager.get_task_result(owner_metadata, task_uuid) assert isinstance(raw_result, TransferrableCeleryError) @@ -181,6 +213,8 @@ async def test_cancelling_a_running_task_aborts_and_deletes( with pytest.raises(TaskNotFoundError): await task_manager.get_task_status(owner_metadata, task_uuid) + tasks = await task_manager.list_tasks(owner_metadata) + assert task_uuid not in [task.uuid for task in tasks] assert task_uuid not in await task_manager.list_tasks(owner_metadata) @@ -198,11 +232,7 @@ async def test_listing_task_uuids_contains_submitted_task( owner_metadata=owner_metadata, ) - for attempt in Retrying( - retry=retry_if_exception_type(AssertionError), - wait=wait_fixed(0.1), - stop=stop_after_delay(10), - ): + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): with attempt: tasks = await task_manager.list_tasks(owner_metadata) assert any(task.uuid == task_uuid for task in tasks) @@ -263,3 +293,111 @@ class MyOwnerMetadata(OwnerMetadata): # clean up all tasks. this should ideally be done in the fixture for task_uuid, owner_metadata in all_tasks: await task_manager.cancel_task(owner_metadata, task_uuid) + + +async def test_push_task_result_streams_data_during_execution( + task_manager: CeleryTaskManager, + with_celery_worker: WorkController, +): + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") + + num_results = 3 + + task_uuid = await task_manager.submit_task( + ExecutionMetadata( + name=streaming_results_task.__name__, + ephemeral=False, # Keep task available after completion for result pulling + ), + owner_metadata=owner_metadata, + num_results=num_results, + ) + + # Pull results while task is running, retry until is_done is True + results = [] + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): + with attempt: + result, is_done, _ = await task_manager.pull_task_stream_items( + owner_metadata, task_uuid, limit=10 + ) + results.extend(result) + assert is_done + + # Should have at least some results streamed + assert results == [TaskStreamItem(data=f"result-{i}") for i in range(num_results)] + + # Wait for task completion + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): + with attempt: + status = await task_manager.get_task_status(owner_metadata, task_uuid) + assert status.task_state == TaskState.SUCCESS + + # Final task result should be available + final_result = await task_manager.get_task_result(owner_metadata, task_uuid) + assert final_result == f"completed-{num_results}-results" + + # After task completion, try to pull any remaining results + remaining_results, is_done, _ = await task_manager.pull_task_stream_items( + owner_metadata, task_uuid, limit=10 + ) + assert remaining_results == [] + assert is_done + + +async def test_pull_task_stream_items_with_limit( + task_manager: CeleryTaskManager, + with_celery_worker: WorkController, +): + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") + + # Submit task with fewer results to make it more predictable + task_uuid = await task_manager.submit_task( + ExecutionMetadata( + name=streaming_results_task.__name__, + ephemeral=False, # Keep task available after completion for result pulling + ), + owner_metadata=owner_metadata, + num_results=5, + ) + + # Wait for task to complete + async for attempt in AsyncRetrying(**_TENACITY_RETRY_PARAMS): + with attempt: + status = await task_manager.get_task_status(owner_metadata, task_uuid) + assert 
status.task_state == TaskState.SUCCESS + + # Pull all results in one go to avoid consumption issues + all_results, is_done_final, _last_update_final = ( + await task_manager.pull_task_stream_items( + owner_metadata, task_uuid, limit=20 # High limit to get all items + ) + ) + + assert all_results is not None + + assert len(all_results) == 5 # Can't have more than what was created + assert is_done_final + + # Verify result format for any results we got + for result in all_results: + assert result.data.startswith("result-") + + +async def test_pull_task_stream_items_from_nonexistent_task_raises_error( + task_manager: CeleryTaskManager, +): + owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner") + fake_task_uuid = TaskUUID(_faker.uuid4()) + + with pytest.raises(TaskNotFoundError): + await task_manager.pull_task_stream_items(owner_metadata, fake_task_uuid) + + +async def test_push_task_stream_items_to_nonexistent_task_raises_error( + task_manager: CeleryTaskManager, +): + not_existing_task_id = "not_existing" + + with pytest.raises(TaskNotFoundError): + await task_manager.push_task_stream_items( + not_existing_task_id, TaskStreamItem(data="some-result") + ) diff --git a/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py b/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py index 8f27127c71a0..b95816d04265 100644 --- a/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py +++ b/packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py @@ -1,6 +1,6 @@ import urllib.parse from datetime import datetime -from typing import Any +from typing import Any, Self from common_library.exclude import Unset from pydantic import BaseModel, ConfigDict, model_validator @@ -46,5 +46,13 @@ def try_populate_task_name_from_task_id(self) -> "TaskBase": class TaskGet(TaskBase): status_href: str - result_href: str abort_href: str + result_href: str | None = None + stream_href: str | None = None + + @model_validator(mode="after") + def _validate_result_hrefs(self) -> Self: + if self.result_href and self.stream_href: + msg = "Either result_href or stream_href must be set, or none of them" + raise ValueError(msg) + return self diff --git a/packages/models-library/src/models_library/api_schemas_storage/search_async_jobs.py b/packages/models-library/src/models_library/api_schemas_storage/search_async_jobs.py new file mode 100644 index 000000000000..5de7cf9d1c0e --- /dev/null +++ b/packages/models-library/src/models_library/api_schemas_storage/search_async_jobs.py @@ -0,0 +1,24 @@ +import datetime +from typing import Final, Literal + +from models_library.projects import ProjectID +from pydantic import BaseModel, ByteSize, ConfigDict +from pydantic.alias_generators import to_camel + +SEARCH_TASK_NAME: Final[str] = "files.search" + + +class SearchResultItem(BaseModel): + name: str + created_at: datetime.datetime + modified_at: datetime.datetime + size: ByteSize | Literal[-1] + path: str + is_directory: bool + project_id: ProjectID | None + + model_config = ConfigDict( + frozen=True, + alias_generator=to_camel, + validate_by_name=True, + ) diff --git a/packages/models-library/src/models_library/api_schemas_webserver/storage.py b/packages/models-library/src/models_library/api_schemas_webserver/storage.py index 8460493348cc..92a436724f6f 100644 --- a/packages/models-library/src/models_library/api_schemas_webserver/storage.py +++ b/packages/models-library/src/models_library/api_schemas_webserver/storage.py 
@@ -1,7 +1,14 @@ +import datetime from pathlib import Path -from typing import Annotated +from typing import Annotated, Self -from pydantic import BaseModel, Field +from models_library.projects import ProjectID +from models_library.utils.common_validators import ( + MIN_NON_WILDCARD_CHARS, + WILDCARD_CHARS, + ensure_pattern_has_enough_characters_before, +) +from pydantic import BaseModel, Field, model_validator from ..api_schemas_storage.storage_schemas import ( DEFAULT_NUMBER_OF_PATHS_PER_PAGE, @@ -42,3 +49,57 @@ class BatchDeletePathsBodyParams(InputSchema): class DataExportPost(InputSchema): paths: list[PathToExport] + + +class SearchTimerangeFilter(InputSchema): + from_: Annotated[ + datetime.datetime | None, + Field( + alias="from", + description="Filter results before this date", + ), + ] = None + until: Annotated[ + datetime.datetime | None, + Field( + description="Filter results after this date", + ), + ] = None + + @model_validator(mode="after") + def _validate_date_range(self) -> Self: + if ( + self.from_ is not None + and self.until is not None + and self.from_ > self.until + ): + msg = f"Invalid date range: '{self.from_}' must be before '{self.until}'" + raise ValueError(msg) + return self + + +class SearchFilters(InputSchema): + name_pattern: Annotated[ + str, + ensure_pattern_has_enough_characters_before(), + Field( + description=f"Name pattern with wildcard support ({', '.join(WILDCARD_CHARS)}). " + f"Minimum of {MIN_NON_WILDCARD_CHARS} non-wildcard characters required.", + ), + ] + modified_at: Annotated[ + SearchTimerangeFilter | None, + Field( + description="Filter results based on modification date range", + ), + ] = None + project_id: Annotated[ + ProjectID | None, + Field( + description="If provided, only files within this project are searched", + ), + ] = None + + +class SearchBodyParams(InputSchema): + filters: SearchFilters diff --git a/packages/models-library/src/models_library/rest_pagination_utils.py b/packages/models-library/src/models_library/rest_pagination_utils.py index 1bd952cfd127..d0956d29e9fc 100644 --- a/packages/models-library/src/models_library/rest_pagination_utils.py +++ b/packages/models-library/src/models_library/rest_pagination_utils.py @@ -19,16 +19,14 @@ @runtime_checkable class _YarlURL(Protocol): - def update_query(self, query) -> "_YarlURL": - ... + def update_query(self, query) -> "_YarlURL": ... class _StarletteURL(Protocol): # SEE starlette.data_structures.URL # in https://github.com/encode/starlette/blob/master/starlette/datastructures.py#L130 - def replace_query_params(self, **kwargs: Any) -> "_StarletteURL": - ... + def replace_query_params(self, **kwargs: Any) -> "_StarletteURL": ... 
_URLType = _YarlURL | _StarletteURL @@ -82,20 +80,53 @@ def paginate_data( _links=PageLinks( self=_replace_query(request_url, {"offset": offset, "limit": limit}), first=_replace_query(request_url, {"offset": 0, "limit": limit}), - prev=_replace_query( - request_url, {"offset": max(offset - limit, 0), "limit": limit} - ) - if offset > 0 - else None, - next=_replace_query( - request_url, - {"offset": min(offset + limit, last_page * limit), "limit": limit}, - ) - if offset < (last_page * limit) - else None, + prev=( + _replace_query( + request_url, {"offset": max(offset - limit, 0), "limit": limit} + ) + if offset > 0 + else None + ), + next=( + _replace_query( + request_url, + {"offset": min(offset + limit, last_page * limit), "limit": limit}, + ) + if offset < (last_page * limit) + else None + ), last=_replace_query( request_url, {"offset": last_page * limit, "limit": limit} ), ), data=data, ) + + +def paginate_stream_chunk( + chunk: list[Any], + *, + request_url: _URLType, + cursor: int, + has_more: bool, +) -> PageDict: + data = [ + item.model_dump() if hasattr(item, "model_dump") else item for item in chunk + ] + + return PageDict( + _meta={ + "cursor": cursor, + "count": len(chunk), + "has_more": has_more, + }, + _links={ + "self": _replace_query(request_url, {"cursor": cursor}), + "next": ( + _replace_query(request_url, {"cursor": cursor + len(chunk)}) + if has_more + else None + ), + }, + data=data, + ) diff --git a/packages/models-library/src/models_library/utils/common_validators.py b/packages/models-library/src/models_library/utils/common_validators.py index c55db09c5f52..422d30e97831 100644 --- a/packages/models-library/src/models_library/utils/common_validators.py +++ b/packages/models-library/src/models_library/utils/common_validators.py @@ -18,13 +18,16 @@ class MyModel(BaseModel): import enum import functools import operator -from typing import Any +from typing import Any, Final from common_library.json_serialization import json_loads from orjson import JSONDecodeError from pydantic import BaseModel, BeforeValidator from pydantic.alias_generators import to_camel +WILDCARD_CHARS: Final[list[str]] = ["*", "?"] +MIN_NON_WILDCARD_CHARS: Final[int] = 3 + def trim_string_before(max_length: int) -> BeforeValidator: def _trim(value: str): @@ -145,3 +148,20 @@ def to_camel_recursive(data: dict[str, Any]) -> dict[str, Any]: else: new_dict[new_key] = value return new_dict + + +def ensure_pattern_has_enough_characters_before( # pylint: disable=dangerous-default-value + min_non_wildcard_chars: int = MIN_NON_WILDCARD_CHARS, + wildcard_chars: list[str] | None = WILDCARD_CHARS, +) -> BeforeValidator: + def _validator(value): + assert wildcard_chars # nosec + + non_wildcard_chars = len([c for c in value if c not in wildcard_chars]) + + if non_wildcard_chars < min_non_wildcard_chars: + msg = f"Pattern '{value}' must contain at least {min_non_wildcard_chars} non-wildcard characters, got {non_wildcard_chars}" + raise ValueError(msg) + return value + + return BeforeValidator(_validator) diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py index 68cdfc2e5811..570bd92aa422 100644 --- a/packages/service-library/src/servicelib/celery/models.py +++ b/packages/service-library/src/servicelib/celery/models.py @@ -1,6 +1,6 @@ -import datetime +from datetime import datetime, timedelta from enum import StrEnum -from typing import Annotated, Final, Literal, Protocol, Self, TypeAlias, TypeVar +from typing import Annotated, Any, Final, 
Literal, Protocol, Self, TypeAlias, TypeVar from uuid import UUID import orjson @@ -141,6 +141,10 @@ class ExecutionMetadata(BaseModel): queue: TasksQueue = TasksQueue.DEFAULT +class TaskStreamItem(BaseModel): + data: Any + + class Task(BaseModel): uuid: TaskUUID metadata: ExecutionMetadata @@ -186,7 +190,7 @@ async def create_task( self, task_key: TaskKey, execution_metadata: ExecutionMetadata, - expiry: datetime.timedelta, + expiry: timedelta, ) -> None: ... async def task_exists(self, task_key: TaskKey) -> bool: ... @@ -207,6 +211,16 @@ async def set_task_progress( report: ProgressReport, ) -> None: ... + async def set_task_stream_done(self, task_key: TaskKey) -> None: ... + + async def push_task_stream_items( + self, task_key: TaskKey, *item: TaskStreamItem + ) -> None: ... + + async def pull_task_stream_items( + self, task_key: TaskKey, limit: int + ) -> tuple[list[TaskStreamItem], bool, datetime | None]: ... + class TaskStatus(BaseModel): task_uuid: TaskUUID diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 5135f3168405..ebcfbb0522f6 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Protocol, runtime_checkable from models_library.progress_bar import ProgressReport @@ -8,6 +9,7 @@ Task, TaskKey, TaskStatus, + TaskStreamItem, TaskUUID, ) @@ -40,4 +42,18 @@ async def set_task_progress( self, task_key: TaskKey, report: ProgressReport ) -> None: ... + async def push_task_stream_items( + self, task_key: TaskKey, *items: TaskStreamItem + ) -> None: ... + + async def pull_task_stream_items( + self, + owner_metadata: OwnerMetadata, + task_uuid: TaskUUID, + offset: int = 0, + limit: int = 20, + ) -> tuple[list[TaskStreamItem], bool, datetime | None]: ... + + async def set_task_stream_done(self, task_key: TaskKey) -> None: ... + async def task_exists(self, task_key: TaskKey) -> bool: ... 
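The streaming additions to the TaskManager protocol above read as a producer/consumer pair. A minimal, hypothetical sketch of that flow — assuming a concrete task_manager implementing this protocol plus an existing task_key / owner_metadata / task_uuid; none of this code is part of the diff:

import asyncio

from servicelib.celery.models import OwnerMetadata, TaskKey, TaskStreamItem, TaskUUID
from servicelib.celery.task_manager import TaskManager


async def producer_side(task_manager: TaskManager, task_key: TaskKey) -> None:
    # hypothetical worker-side usage: push partial results as they become available ...
    for i in range(3):
        await task_manager.push_task_stream_items(
            task_key, TaskStreamItem(data=f"result-{i}")
        )
    # ... and close the stream once everything has been pushed
    await task_manager.set_task_stream_done(task_key)


async def consumer_side(
    task_manager: TaskManager, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
) -> list:
    # hypothetical client-side usage: drain the stream page by page until the producer
    # has marked it done AND the backing Redis list is empty (that is what `done` encodes)
    items: list[TaskStreamItem] = []
    done = False
    while not done:
        chunk, done, _last_update = await task_manager.pull_task_stream_items(
            owner_metadata, task_uuid, limit=20
        )
        items.extend(chunk)
        if not done:
            # avoid busy-polling while the worker is still producing
            await asyncio.sleep(0.5)
    return [item.data for item in items]
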
diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py index 31ca1d11440c..35f5ee9118d2 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/storage/simcore_s3.py @@ -20,7 +20,7 @@ async def copy_folders_from_project( *, body: FoldersBody, owner_metadata: OwnerMetadata, - user_id: UserID + user_id: UserID, ) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client=client, @@ -41,7 +41,7 @@ async def start_export_data( paths_to_export: list[PathToExport], export_as: Literal["path", "download_link"], owner_metadata: OwnerMetadata, - user_id: UserID + user_id: UserID, ) -> tuple[AsyncJobGet, OwnerMetadata]: async_job_rpc_get = await submit( rabbitmq_rpc_client, diff --git a/services/api-server/openapi.json b/services/api-server/openapi.json index 2bf3a2d9b70d..c6aa0416b5b1 100644 --- a/services/api-server/openapi.json +++ b/services/api-server/openapi.json @@ -12307,20 +12307,37 @@ "type": "string", "title": "Status Href" }, - "result_href": { - "type": "string", - "title": "Result Href" - }, "abort_href": { "type": "string", "title": "Abort Href" + }, + "result_href": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Result Href" + }, + "stream_href": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Stream Href" } }, "type": "object", "required": [ "task_id", "status_href", - "result_href", "abort_href" ], "title": "TaskGet" diff --git a/services/docker-compose.yml b/services/docker-compose.yml index 3fc59041e970..63c0ad5b48c6 100644 --- a/services/docker-compose.yml +++ b/services/docker-compose.yml @@ -860,7 +860,7 @@ services: PathRegexp(`^/v0/projects/[0-9a-fA-F-]+/nodes/[0-9a-fA-F-]+:close`) || PathRegexp(`^/v0/storage/locations/[0-9]+/paths/.+:size`) || PathRegexp(`^/v0/storage/locations/[0-9]+/-/paths:batchDelete`) || - PathRegexp(`^/v0/storage/locations/[0-9]+/export-data`) || + PathRegexp(`^/v0/storage/locations/[0-9]+:export-data`) || PathRegexp(`^/v0/tasks-legacy/.+`) # NOTE: the sticky router must have a higher priority than the webserver router but below dy-proxies - traefik.http.routers.${SWARM_STACK_NAME}_webserver_sticky.priority=8 diff --git a/services/static-webserver/client/source/class/osparc/data/Resources.js b/services/static-webserver/client/source/class/osparc/data/Resources.js index 60e1415d3f7d..aa2d1de06535 100644 --- a/services/static-webserver/client/source/class/osparc/data/Resources.js +++ b/services/static-webserver/client/source/class/osparc/data/Resources.js @@ -1375,7 +1375,7 @@ qx.Class.define("osparc.data.Resources", { }, multiDownload: { method: "POST", - url: statics.API + "/storage/locations/{locationId}/export-data" + url: statics.API + "/storage/locations/{locationId}:export-data" }, batchDelete: { method: "POST", diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py index 8094d748efee..d40a90b084c8 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py @@ -1,3 +1,4 @@ +import datetime import functools import logging from typing import Any @@ -5,6 +6,7 @@ from 
aws_library.s3._models import S3ObjectKey from celery import Task # type: ignore[import-untyped] from celery_library.utils import get_app_server +from models_library.api_schemas_storage.search_async_jobs import SearchResultItem from models_library.api_schemas_storage.storage_schemas import ( FoldersBody, LinkType, @@ -12,10 +14,14 @@ ) from models_library.api_schemas_webserver.storage import PathToExport from models_library.progress_bar import ProgressReport +from models_library.projects import ProjectID from models_library.projects_nodes_io import StorageFileID from models_library.users import UserID from pydantic import TypeAdapter -from servicelib.celery.models import TaskKey +from servicelib.celery.models import ( + TaskKey, + TaskStreamItem, +) from servicelib.logging_utils import log_context from servicelib.progress_bar import ProgressBarData @@ -128,3 +134,53 @@ async def export_data_as_download_link( user_id=user_id, file_id=s3_object, link_type=LinkType.PRESIGNED ) return PresignedLink(link=download_link) + + +async def search( + task: Task, + task_key: TaskKey, + *, + user_id: UserID, + project_id: ProjectID | None, + name_pattern: str, + modified_at: tuple[datetime.datetime | None, datetime.datetime | None] | None, +) -> None: + with log_context( + _logger, + logging.INFO, + f"'{task_key}' search file {name_pattern=}", + ): + app_server = get_app_server(task.app) + dsm = get_dsm_provider(app_server.app).get( + SimcoreS3DataManager.get_location_id() + ) + + assert isinstance(dsm, SimcoreS3DataManager) # nosec + + async for items in dsm.search( + user_id=user_id, + project_id=project_id, + name_pattern=name_pattern, + modified_at=modified_at, + ): + data = [ + TaskStreamItem( + data=SearchResultItem( + name=item.file_name, + project_id=item.project_id, + created_at=item.created_at, + modified_at=item.last_modified, + is_directory=item.is_directory, + size=item.file_size, + path=item.object_name, + ) + ) + for item in items + ] + + await app_server.task_manager.push_task_stream_items( + task_key, + *data, + ) + + await app_server.task_manager.set_task_stream_done(task_key) diff --git a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py index 66de98ec2d5f..5475b8eed8a4 100644 --- a/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py +++ b/services/storage/src/simcore_service_storage/api/_worker_tasks/tasks.py @@ -4,12 +4,12 @@ from celery_library.task import register_task from celery_library.types import register_celery_types, register_pydantic_types from models_library.api_schemas_storage.export_data_async_jobs import AccessRightError +from models_library.api_schemas_storage.search_async_jobs import SEARCH_TASK_NAME from models_library.api_schemas_storage.storage_schemas import ( FileUploadCompletionBody, FoldersBody, PresignedLink, ) -from servicelib.celery.models import OwnerMetadata from servicelib.logging_utils import log_context from ...models import FileMetaData @@ -19,6 +19,7 @@ deep_copy_files_from_project, export_data, export_data_as_download_link, + search, ) _logger = logging.getLogger(__name__) @@ -31,10 +32,8 @@ def setup_worker_tasks(app: Celery) -> None: FileMetaData, FoldersBody, PresignedLink, - OwnerMetadata, ) - - with log_context(_logger, logging.INFO, msg="worker task registration"): + with log_context(_logger, logging.INFO, msg="worker tasks registration"): register_task(app, export_data, dont_autoretry_for=(AccessRightError,)) register_task( app, 
export_data_as_download_link, dont_autoretry_for=(AccessRightError,) @@ -43,3 +42,4 @@ def setup_worker_tasks(app: Celery) -> None: register_task(app, complete_upload_file) register_task(app, delete_paths) register_task(app, deep_copy_files_from_project) + register_task(app, search, task_name=SEARCH_TASK_NAME) diff --git a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py index e4b2e4cee76c..885fce5cbd81 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py @@ -47,9 +47,9 @@ async def copy_folders_from_project( async def start_export_data( task_manager: TaskManager, owner_metadata: OwnerMetadata, + user_id: UserID, paths_to_export: list[PathToExport], export_as: Literal["path", "download_link"], - user_id: UserID, ) -> AsyncJobGet: if export_as == "path": task_name = export_data.__name__ diff --git a/services/storage/src/simcore_service_storage/modules/celery/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/__init__.py index 3f292337c73e..48e30e60a4f3 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/__init__.py +++ b/services/storage/src/simcore_service_storage/modules/celery/__init__.py @@ -14,8 +14,6 @@ from settings_library.celery import CelerySettings from settings_library.redis import RedisDatabase -from ...models import FileMetaData - _logger = logging.getLogger(__name__) @@ -38,7 +36,7 @@ async def on_startup() -> None: ) register_celery_types() - register_pydantic_types(FileUploadCompletionBody, FileMetaData, FoldersBody) + register_pydantic_types(FileUploadCompletionBody, FoldersBody) async def on_shutdown() -> None: with log_context(_logger, logging.INFO, "Shutting down Celery"): diff --git a/services/storage/src/simcore_service_storage/simcore_s3_dsm.py b/services/storage/src/simcore_service_storage/simcore_s3_dsm.py index 7e2dbdc5baf3..ff80405e1860 100644 --- a/services/storage/src/simcore_service_storage/simcore_s3_dsm.py +++ b/services/storage/src/simcore_service_storage/simcore_s3_dsm.py @@ -1,9 +1,10 @@ import contextlib import datetime +import fnmatch import logging import tempfile import urllib.parse -from collections.abc import Coroutine +from collections.abc import AsyncGenerator, Coroutine from contextlib import suppress from dataclasses import dataclass from pathlib import Path @@ -959,6 +960,155 @@ async def search_owned_files( resolved_fmds.append(convert_db_to_model(updated_fmd)) return resolved_fmds + async def _process_s3_page_results( + self, + current_page_results: list[FileMetaData], + ) -> list[FileMetaData]: + current_page_results.sort( + key=lambda x: x.last_modified + or datetime.datetime.min.replace(tzinfo=datetime.UTC), + reverse=True, + ) + + result_project_ids = list( + { + result.project_id + for result in current_page_results + if result.project_id is not None + } + ) + + if result_project_ids: + current_page_results = await _add_frontend_needed_data( + get_db_engine(self.app), + project_ids=result_project_ids, + data=current_page_results, + ) + + return current_page_results + + async def _search_project_s3_files( + self, + user_id: UserID, + proj_id: ProjectID, + name_pattern: str, + modified_at: ( + tuple[datetime.datetime | None, datetime.datetime | None] | None + ) = None, + ) -> AsyncGenerator[FileMetaData, None]: + """Search S3 files in a specific project and yield individual results.""" + s3_client = 
get_s3_client(self.app) + min_parts_for_valid_s3_object = 2 + + try: + async for s3_objects in s3_client.list_objects_paginated( + bucket=self.simcore_bucket_name, + prefix=f"{proj_id}/", + items_per_page=500, # fetch larger batches for efficiency + ): + for s3_obj in s3_objects: + filename = Path(s3_obj.object_key).name + + if not ( + fnmatch.fnmatch(filename, name_pattern) + and len(s3_obj.object_key.split("/")) + >= min_parts_for_valid_s3_object + ): + continue + + last_modified_from, last_modified_until = modified_at or ( + None, + None, + ) + if ( + last_modified_from + and s3_obj.last_modified + and s3_obj.last_modified < last_modified_from + ): + continue + + if ( + last_modified_until + and s3_obj.last_modified + and s3_obj.last_modified > last_modified_until + ): + continue + + file_meta = FileMetaData.from_simcore_node( + user_id=user_id, + file_id=TypeAdapter(SimcoreS3FileID).validate_python( + s3_obj.object_key + ), + bucket=self.simcore_bucket_name, + location_id=self.get_location_id(), + location_name=self.get_location_name(), + sha256_checksum=None, + file_size=s3_obj.size, + last_modified=s3_obj.last_modified, + entity_tag=s3_obj.e_tag, + ) + yield file_meta + + except S3KeyNotFoundError as exc: + _logger.debug("No files found in S3 for project %s: %s", proj_id, exc) + return + + async def search( + self, + user_id: UserID, + *, + name_pattern: str, + project_id: ProjectID | None = None, + modified_at: ( + tuple[datetime.datetime | None, datetime.datetime | None] | None + ) = None, + limit: NonNegativeInt = 100, + ) -> AsyncGenerator[list[FileMetaData], None]: + """ + Search for files in S3 using a wildcard pattern for filenames. + Returns results as an async generator that yields pages of results. + + Args: + user_id: The user requesting the search + name_pattern: Wildcard pattern for filename matching (e.g., "*.txt", "test_*.json") + project_id: Optional project ID to limit search to specific project + modified_before: Optional datetime filter - only include files modified before this datetime + modified_after: Optional datetime filter - only include files modified after this datetime + limit: Number of items to return per page + + Yields: + List of FileMetaData objects for each page, with exactly limit items + (except the last page which may have fewer) + """ + # Validate access rights + accessible_projects_ids = await get_accessible_project_ids( + get_db_engine(self.app), user_id=user_id, project_id=project_id + ) + + # Collect all results across projects + current_page_results: list[FileMetaData] = [] + + for proj_id in accessible_projects_ids: + async for file_result in self._search_project_s3_files( + user_id, proj_id, name_pattern, modified_at + ): + current_page_results.append(file_result) + + if len(current_page_results) >= limit: + page_batch = current_page_results[:limit] + remaining_results = current_page_results[limit:] + + processed_page = await self._process_s3_page_results(page_batch) + yield processed_page + + # NOTE: keep the remaining results for next page + current_page_results = remaining_results + + # Handle any remaining results (the last page) + if current_page_results: + processed_page = await self._process_s3_page_results(current_page_results) + yield processed_page + async def create_soft_link( self, user_id: int, target_file_id: StorageFileID, link_file_id: StorageFileID ) -> FileMetaData: diff --git a/services/storage/tests/unit/test_simcore_s3_dsm.py b/services/storage/tests/unit/test_simcore_s3_dsm.py index fdde44a86636..5fbe57f4aa31 100644 
--- a/services/storage/tests/unit/test_simcore_s3_dsm.py +++ b/services/storage/tests/unit/test_simcore_s3_dsm.py @@ -168,6 +168,153 @@ async def test_upload_and_search( assert file.file_name in {"file1", "file2"} +async def _search_files_by_pattern( + simcore_s3_dsm: SimcoreS3DataManager, + user_id: UserID, + name_pattern: str, + project_id: ProjectID | None = None, + items_per_page: int = 10, +) -> list[FileMetaData]: + """Helper function to search files and collect all results.""" + results = [] + async for page in simcore_s3_dsm.search( + user_id=user_id, + name_pattern=name_pattern, + project_id=project_id, + limit=items_per_page, + ): + results.extend(page) + return results + + +@pytest.mark.parametrize( + "location_id", + [SimcoreS3DataManager.get_location_id()], + ids=[SimcoreS3DataManager.get_location_name()], + indirect=True, +) +async def test_search_files( + simcore_s3_dsm: SimcoreS3DataManager, + upload_file: Callable[..., Awaitable[tuple[Path, SimcoreS3FileID]]], + file_size: ByteSize, + user_id: UserID, + project_id: ProjectID, + faker: Faker, +): + # Upload files with different patterns + test_files = [ + ("test_file1.txt", "*.txt"), + ("test_file2.txt", "*.txt"), + ("document.pdf", "*.pdf"), + ("data_file.csv", "data_*.csv"), + ("backup_file.bak", "backup_*"), + ("config.json", "*.json"), + ("temp_data.tmp", "temp_*"), + ("file_a.log", "file_?.log"), + ("file_b.log", "file_?.log"), + ("file_10.log", "file_??.log"), + ("report1.txt", "report?.txt"), + ("report2.txt", "report?.txt"), + ] + + uploaded_files = [] + for file_name, _ in test_files: + checksum: SHA256Str = TypeAdapter(SHA256Str).validate_python(faker.sha256()) + _, file_id = await upload_file(file_size, file_name, sha256_checksum=checksum) + uploaded_files.append((file_name, file_id, checksum)) + + # Test 1: Search for all .txt files + txt_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "*.txt", project_id + ) + assert ( + len(txt_results) == 4 + ) # test_file1.txt, test_file2.txt, report1.txt, report2.txt + txt_names = {file.file_name for file in txt_results} + assert txt_names == { + "test_file1.txt", + "test_file2.txt", + "report1.txt", + "report2.txt", + } + + # Test 2: Search with specific prefix pattern + data_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "data_*", project_id + ) + assert len(data_results) == 1 + assert data_results[0].file_name == "data_file.csv" + + # Test 3: Search with pattern that matches multiple extensions + temp_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "temp_*", project_id + ) + assert len(temp_results) == 1 + assert temp_results[0].file_name == "temp_data.tmp" + + # Test 4: Search with pattern that doesn't match anything + no_match_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "nonexistent_*", project_id + ) + assert len(no_match_results) == 0 + + # Test 5: Search without project_id restriction (all accessible projects) + all_results = await _search_files_by_pattern(simcore_s3_dsm, user_id, "*") + assert len(all_results) >= len(test_files) + + # Verify that each result has expected FileMetaData structure + for file_meta in all_results: + assert isinstance(file_meta, FileMetaData) + assert file_meta.file_name is not None + assert file_meta.file_id is not None + assert file_meta.user_id == user_id + assert file_meta.project_id is not None + + # Test 6: Test ? 
wildcard - single character match + single_char_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "file_?.log", project_id + ) + # Should find 2 files: file_a.log and file_b.log (but not file_10.log) + assert len(single_char_results) == 2 + single_char_names = {file.file_name for file in single_char_results} + assert single_char_names == {"file_a.log", "file_b.log"} + + # Test 7: Test ?? wildcard - two character match + double_char_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "file_??.log", project_id + ) + # Should find 1 file: file_10.log + assert len(double_char_results) == 1 + assert double_char_results[0].file_name == "file_10.log" + + # Test 8: Test ? wildcard with specific prefix and suffix + report_results = await _search_files_by_pattern( + simcore_s3_dsm, user_id, "report?.txt", project_id + ) + # Should find 2 files: report1.txt and report2.txt + assert len(report_results) == 2 + report_names = {file.file_name for file in report_results} + assert report_names == {"report1.txt", "report2.txt"} + + # Test 9: Test pagination with small page size + paginated_results = [] + page_count = 0 + async for page in simcore_s3_dsm.search( + user_id=user_id, + name_pattern="*", + project_id=project_id, + limit=2, # Small page size to test pagination + ): + paginated_results.extend(page) + page_count += 1 + # Each page should have at most 2 items + assert len(page) <= 2 + + # Should have multiple pages and all our files + assert page_count >= 6 # At least 12 files / 2 per page = 6 pages + assert len(paginated_results) == len(test_files) + + @pytest.fixture async def paths_for_export( random_project_with_files: Callable[ diff --git a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml index 81fb46b3b4b1..ac535b341ab3 100644 --- a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml +++ b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml @@ -4046,6 +4046,60 @@ paths: schema: $ref: '#/components/schemas/EnvelopedError' description: Internal Server Error + /v0/tasks/{task_id}/stream: + get: + tags: + - long-running-tasks + summary: Get Async Job Stream + description: Retrieves the stream of a task + operationId: get_async_job_stream + parameters: + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: limit + in: query + required: false + schema: + type: integer + maximum: 50 + minimum: 1 + default: 20 + title: Limit + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Envelope_TaskStreamResponse_' + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Not Found + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Forbidden + '410': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Gone + '500': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Internal Server Error /v0/tasks-legacy: get: tags: @@ -7918,7 +7972,7 @@ paths: application/json: schema: $ref: '#/components/schemas/Envelope_FileUploadCompleteFutureResponse_' - /v0/storage/locations/{location_id}/export-data: + /v0/storage/locations/{location_id}:export-data: post: tags: - storage @@ -7939,7 +7993,58 @@ paths: schema: $ref: 
'#/components/schemas/DataExportPost' responses: - '200': + '202': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Envelope_TaskGet_' + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Not Found + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Forbidden + '410': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Gone + '500': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvelopedError' + description: Internal Server Error + /v0/storage/locations/{location_id}:search: + post: + tags: + - storage + summary: Search + description: Starts a files/folders search + operationId: search + parameters: + - name: location_id + in: path + required: true + schema: + type: integer + title: Location Id + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SearchBodyParams' + responses: + '202': description: Successful Response content: application/json: @@ -11501,6 +11606,19 @@ components: title: Error type: object title: Envelope[TaskStatus] + Envelope_TaskStreamResponse_: + properties: + data: + anyOf: + - $ref: '#/components/schemas/TaskStreamResponse' + - type: 'null' + error: + anyOf: + - {} + - type: 'null' + title: Error + type: object + title: Envelope[TaskStreamResponse] Envelope_Union_EmailTestFailed__EmailTestPassed__: properties: data: @@ -17091,6 +17209,55 @@ components: \ - A worker has picked up the task and is executing it\n- SUCCESS - Task\ \ finished successfully\n- FAILED - Task finished with an error\n- ABORTED\ \ - Task was aborted before completion" + SearchBodyParams: + properties: + filters: + $ref: '#/components/schemas/SearchFilters' + type: object + required: + - filters + title: SearchBodyParams + SearchFilters: + properties: + namePattern: + type: string + title: Namepattern + description: Name pattern with wildcard support (*, ?). Minimum of 3 non-wildcard + characters required. 
+ modifiedAt: + anyOf: + - $ref: '#/components/schemas/SearchTimerangeFilter' + - type: 'null' + description: Filter results based on modification date range + projectId: + anyOf: + - type: string + format: uuid + - type: 'null' + title: Projectid + description: If provided, only files within this project are searched + type: object + required: + - namePattern + title: SearchFilters + SearchTimerangeFilter: + properties: + from: + anyOf: + - type: string + format: date-time + - type: 'null' + title: From + description: Filter results before this date + until: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Until + description: Filter results after this date + type: object + title: SearchTimerangeFilter SelectBox: properties: structure: @@ -18053,17 +18220,23 @@ components: status_href: type: string title: Status Href - result_href: - type: string - title: Result Href abort_href: type: string title: Abort Href + result_href: + anyOf: + - type: string + - type: 'null' + title: Result Href + stream_href: + anyOf: + - type: string + - type: 'null' + title: Stream Href type: object required: - task_id - status_href - - result_href - abort_href title: TaskGet TaskInfoDict: @@ -18141,6 +18314,23 @@ components: - done - started title: TaskStatus + TaskStreamResponse: + properties: + items: + items: + additionalProperties: true + type: object + type: array + title: Items + end: + type: boolean + title: End + additionalProperties: false + type: object + required: + - items + - end + title: TaskStreamResponse TestEmail: properties: from_: diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py index 116ed1f5484a..86a5262bd323 100644 --- a/services/web/server/src/simcore_service_webserver/storage/_rest.py +++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py @@ -5,7 +5,7 @@ import logging import urllib.parse -from typing import Any, Final, NamedTuple +from typing import Annotated, Any, Final, NamedTuple from urllib.parse import quote, unquote from aiohttp import ClientTimeout, web @@ -15,6 +15,7 @@ from models_library.api_schemas_rpc_async_jobs.async_jobs import ( AsyncJobGet, ) +from models_library.api_schemas_storage.search_async_jobs import SEARCH_TASK_NAME from models_library.api_schemas_storage.storage_schemas import ( FileUploadCompleteResponse, FileUploadCompletionBody, @@ -24,13 +25,20 @@ from models_library.api_schemas_webserver.storage import ( BatchDeletePathsBodyParams, DataExportPost, + SearchBodyParams, StorageLocationPathParams, StoragePathComputeSizeParams, ) from models_library.projects_nodes_io import LocationID from models_library.utils.change_case import camel_to_snake from models_library.utils.fastapi_encoders import jsonable_encoder -from pydantic import AnyUrl, BaseModel, ByteSize, TypeAdapter, field_validator +from pydantic import ( + AfterValidator, + AnyUrl, + BaseModel, + ByteSize, + TypeAdapter, +) from servicelib.aiohttp import status from servicelib.aiohttp.client_session import get_client_session from servicelib.aiohttp.request_keys import RQT_USERID_KEY @@ -40,7 +48,7 @@ parse_request_query_parameters_as, ) from servicelib.aiohttp.rest_responses import create_data_response -from servicelib.celery.models import OwnerMetadata +from servicelib.celery.models import ExecutionMetadata, OwnerMetadata from servicelib.common_headers import X_FORWARDED_PROTO from servicelib.rabbitmq.rpc_interfaces.storage.paths import ( compute_path_size as 
remote_compute_path_size, @@ -53,6 +61,7 @@ from yarl import URL from .._meta import API_VTAG +from ..celery import get_task_manager from ..login.decorators import login_required from ..models import AuthenticatedRequestContext, WebServerOwnerMetadata from ..rabbitmq import get_rabbitmq_rpc_client @@ -482,23 +491,24 @@ class _PathParams(BaseModel): return create_data_response(payload, status=resp_status) +def _allow_only_simcore(v: int) -> int: + if v != 0: + msg = ( + f"Only simcore (location_id='0'), provided location_id='{v}' is not allowed" + ) + raise ValueError(msg) + return v + + @routes.post( - _storage_locations_prefix + "/{location_id}/export-data", name="export_data" + _storage_locations_prefix + "/{location_id}:export-data", name="export_data" ) @login_required @permission_required("storage.files.*") @handle_rest_requests_exceptions async def export_data(request: web.Request) -> web.Response: class _PathParams(BaseModel): - location_id: LocationID - - @field_validator("location_id") - @classmethod - def allow_only_simcore(cls, v: int) -> int: - if v != 0: - msg = f"Only simcore (location_id='0'), provided location_id='{v}' is not allowed" - raise ValueError(msg) - return v + location_id: Annotated[LocationID, AfterValidator(_allow_only_simcore)] rabbitmq_rpc_client = get_rabbitmq_rpc_client(request.app) _req_ctx = AuthenticatedRequestContext.model_validate(request) @@ -522,9 +532,60 @@ def allow_only_simcore(cls, v: int) -> int: return create_data_response( TaskGet( task_id=_job_id, + task_name=async_job_rpc_get.job_name, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=_job_id)))}", abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=_job_id)))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=_job_id)))}", ), status=status.HTTP_202_ACCEPTED, ) + + +@routes.post(_storage_locations_prefix + "/{location_id}:search", name="search") +@login_required +@permission_required("storage.files.*") +@handle_rest_requests_exceptions +async def search(request: web.Request) -> web.Response: + class _PathParams(BaseModel): + location_id: Annotated[LocationID, AfterValidator(_allow_only_simcore)] + + _req_ctx = AuthenticatedRequestContext.model_validate(request) + parse_request_path_parameters_as(_PathParams, request) + search_body = await parse_request_body_as( + model_schema_cls=SearchBodyParams, request=request + ) + + task_uuid = await get_task_manager(request.app).submit_task( + ExecutionMetadata( + name=SEARCH_TASK_NAME, + ), + owner_metadata=OwnerMetadata.model_validate( + WebServerOwnerMetadata( + user_id=_req_ctx.user_id, + product_name=_req_ctx.product_name, + ).model_dump() + ), + user_id=_req_ctx.user_id, + name_pattern=search_body.filters.name_pattern, + modified_at=( + ( + search_body.filters.modified_at.from_, + search_body.filters.modified_at.until, + ) + if search_body.filters.modified_at + else None + ), + project_id=search_body.filters.project_id, + ) + + _task_id = f"{task_uuid}" + return create_data_response( + TaskGet( + task_id=_task_id, + task_name=SEARCH_TASK_NAME, + status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=_task_id)))}", + abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=_task_id)))}", + stream_href=f"{request.url.with_path(str(request.app.router['get_async_job_stream'].url_for(task_id=_task_id)))}", + ), + 
status=status.HTTP_202_ACCEPTED, + ) diff --git a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py index e1d6f4010f7b..54d539a26339 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest.py @@ -14,6 +14,7 @@ ) from servicelib.aiohttp.requests_validation import ( parse_request_path_parameters_as, + parse_request_query_parameters_as, ) from servicelib.aiohttp.rest_responses import ( create_data_response, @@ -28,7 +29,7 @@ from ...models import AuthenticatedRequestContext, WebServerOwnerMetadata from .. import _tasks_service from ._rest_exceptions import handle_rest_requests_exceptions -from ._rest_schemas import TaskPathParams +from ._rest_schemas import TaskPathParams, TaskStreamQueryParams, TaskStreamResponse log = logging.getLogger(__name__) @@ -96,7 +97,6 @@ async def get_async_jobs(request: web.Request) -> web.Response: @login_required @handle_rest_requests_exceptions async def get_async_job_status(request: web.Request) -> web.Response: - _req_ctx = AuthenticatedRequestContext.model_validate(request) _path_params = parse_request_path_parameters_as(TaskPathParams, request) @@ -131,7 +131,6 @@ async def get_async_job_status(request: web.Request) -> web.Response: @login_required @handle_rest_requests_exceptions async def cancel_async_job(request: web.Request) -> web.Response: - _req_ctx = AuthenticatedRequestContext.model_validate(request) _path_params = parse_request_path_parameters_as(TaskPathParams, request) @@ -175,3 +174,35 @@ async def get_async_job_result(request: web.Request) -> web.Response: TaskResult(result=task_result.result, error=None), status=status.HTTP_200_OK, ) + + +@routes.get( + _task_prefix + "/{task_id}/stream", + name="get_async_job_stream", +) +@login_required +@handle_rest_requests_exceptions +async def get_async_job_stream(request: web.Request) -> web.Response: + + _req_ctx = AuthenticatedRequestContext.model_validate(request) + _path_params = parse_request_path_parameters_as(TaskPathParams, request) + _query_params: TaskStreamQueryParams = parse_request_query_parameters_as( + TaskStreamQueryParams, request + ) + + task_result, end = await _tasks_service.pull_task_stream_items( + get_task_manager(request.app), + owner_metadata=OwnerMetadata.model_validate( + WebServerOwnerMetadata( + user_id=_req_ctx.user_id, + product_name=_req_ctx.product_name, + ).model_dump() + ), + task_uuid=_path_params.task_id, + limit=_query_params.limit, + ) + + return create_data_response( + TaskStreamResponse(items=[r.data for r in task_result], end=end), + status=status.HTTP_200_OK, + ) diff --git a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_schemas.py b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_schemas.py index 7f25cfca51da..2bd8c8a323b6 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_schemas.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_schemas.py @@ -1,8 +1,22 @@ from uuid import UUID +from models_library.rest_pagination import ( + DEFAULT_NUMBER_OF_ITEMS_PER_PAGE, + PageLimitInt, +) from pydantic import BaseModel, ConfigDict class TaskPathParams(BaseModel): task_id: UUID model_config = ConfigDict(extra="forbid", frozen=True) + + +class TaskStreamQueryParams(BaseModel): + limit: PageLimitInt = DEFAULT_NUMBER_OF_ITEMS_PER_PAGE + 
+ +class TaskStreamResponse(BaseModel): + items: list[dict] + end: bool + model_config = ConfigDict(extra="forbid", frozen=True) diff --git a/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_utils.py b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_utils.py new file mode 100644 index 000000000000..36ab91a85872 --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/tasks/_controller/_rest_utils.py @@ -0,0 +1,24 @@ +from aiohttp import web +from models_library.rest_pagination import Page +from models_library.rest_pagination_utils import _URLType, paginate_stream_chunk +from servicelib.celery.models import TaskStreamItem +from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON +from servicelib.rest_constants import RESPONSE_MODEL_POLICY + + +def create_page_response( + items: list[TaskStreamItem], + request_url: _URLType, +) -> web.Response: + page = Page[TaskStreamItem].model_validate( + paginate_stream_chunk( + chunk=items, + request_url=request_url, + cursor=0, + has_more=True, + ) + ) + return web.Response( + text=page.model_dump_json(**RESPONSE_MODEL_POLICY), + content_type=MIMETYPE_APPLICATION_JSON, + ) diff --git a/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py b/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py index 36c5c382b255..cf28305b5e82 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_tasks_service.py @@ -1,4 +1,6 @@ import logging +from datetime import UTC, datetime, timedelta +from typing import Final from celery_library.errors import ( TaskManagerError, @@ -17,9 +19,11 @@ JobNotDoneError, JobSchedulerError, ) +from pydantic import NonNegativeFloat from servicelib.celery.models import ( OwnerMetadata, TaskState, + TaskStreamItem, TaskUUID, ) from servicelib.celery.task_manager import TaskManager @@ -28,6 +32,9 @@ _logger = logging.getLogger(__name__) +_STREAM_STALL_THRESHOLD: Final[NonNegativeFloat] = timedelta(minutes=1).total_seconds() + + async def cancel_task( task_manager: TaskManager, *, @@ -111,6 +118,34 @@ async def get_task_status( ) +async def pull_task_stream_items( + task_manager: TaskManager, + *, + owner_metadata: OwnerMetadata, + task_uuid: TaskUUID, + limit: int = 50, +) -> tuple[list[TaskStreamItem], bool]: + try: + results, end, last_update = await task_manager.pull_task_stream_items( + owner_metadata=owner_metadata, + task_uuid=task_uuid, + limit=limit, + ) + except TaskNotFoundError as exc: + raise JobMissingError(job_id=task_uuid) from exc + except TaskManagerError as exc: + raise JobSchedulerError(exc=f"{exc}") from exc + + if not end and last_update: + delta = datetime.now(UTC) - last_update + if delta.total_seconds() > _STREAM_STALL_THRESHOLD: + raise JobSchedulerError( + exc=f"Task seems stalled since {delta.total_seconds()} seconds" + ) + + return results, end + + async def list_tasks( task_manager: TaskManager, *, diff --git a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py index ecbad97651e1..bb155e47c8cf 100644 --- a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py +++ b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py @@ -472,7 +472,7 @@ async def test_export_data( paths=[Path(f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}")] ) response = await client.post( - 
f"/{API_VERSION}/storage/locations/0/export-data", data=_body.model_dump_json() + f"/{API_VERSION}/storage/locations/0:export-data", data=_body.model_dump_json() ) assert response.status == expected_status if response.status == status.HTTP_202_ACCEPTED: @@ -614,7 +614,7 @@ async def test_get_async_job_links( paths=[PathToExport(f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}")] ) response = await client.post( - f"/{API_VERSION}/storage/locations/0/export-data", data=_body.model_dump_json() + f"/{API_VERSION}/storage/locations/0:export-data", data=_body.model_dump_json() ) assert response.status == status.HTTP_202_ACCEPTED response_body_data = Envelope[TaskGet].model_validate(await response.json()).data