@@ -1,3 +1,5 @@
from typing import Literal

from models_library.api_schemas_rpc_async_jobs.async_jobs import (
AsyncJobFilter,
AsyncJobGet,
@@ -41,6 +43,7 @@ async def start_export_data(
user_id: UserID,
product_name: ProductName,
paths_to_export: list[PathToExport],
export_as: Literal["path", "download_link"],
) -> tuple[AsyncJobGet, AsyncJobFilter]:
job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
async_job_rpc_get = await submit(
@@ -49,5 +52,6 @@ async def start_export_data(
method_name=TypeAdapter(RPCMethodName).validate_python("start_export_data"),
job_filter=job_filter,
paths_to_export=paths_to_export,
export_as=export_as,
)
return async_job_rpc_get, job_filter
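
For orientation, here is a minimal usage sketch of the extended client helper. The first positional argument (the RabbitMQ RPC client) and the import location are assumptions inferred from the tests further down, since they sit outside the visible hunk:

```python
# Hedged sketch, not part of the diff: assumes `rpc_client` is an initialized
# RabbitMQRPCClient and that `start_export_data` is the client helper shown above.
async_job_get, job_filter = await start_export_data(
    rpc_client,
    user_id=user_id,
    product_name=product_name,
    paths_to_export=paths_to_export,
    export_as="download_link",  # "path" preserves the previous behaviour
)
# The returned AsyncJobGet/AsyncJobFilter pair can then be fed to the async-jobs
# helpers (e.g. wait_and_get_result) to await the export result.
```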
@@ -5,7 +5,11 @@
from aws_library.s3._models import S3ObjectKey
from celery import Task # type: ignore[import-untyped]
from celery_library.utils import get_app_server
from models_library.api_schemas_storage.storage_schemas import FoldersBody
from models_library.api_schemas_storage.storage_schemas import (
FoldersBody,
LinkType,
PresignedLink,
)
from models_library.api_schemas_webserver.storage import PathToExport
from models_library.progress_bar import ProgressReport
from models_library.projects_nodes_io import StorageFileID
@@ -100,3 +104,27 @@ async def _progress_cb(report: ProgressReport) -> None:
return await dsm.create_s3_export(
user_id, object_keys, progress_bar=progress_bar
)


async def export_data_as_download_link(
task: Task,
task_id: TaskID,
*,
user_id: UserID,
paths_to_export: list[PathToExport],
) -> PresignedLink:
"""
AccessRightError: raised if the user cannot access the project
"""
s3_object = await export_data(
task=task, task_id=task_id, user_id=user_id, paths_to_export=paths_to_export
)

dsm = get_dsm_provider(get_app_server(task.app).app).get(
SimcoreS3DataManager.get_location_id()
)

download_link = await dsm.create_file_download_link(
user_id=user_id, file_id=s3_object, link_type=LinkType.PRESIGNED
)
return PresignedLink(link=download_link)
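
The download-link variant reuses export_data for the heavy lifting and only adds a presigned-URL step on top, so progress reporting and the AccessRightError behaviour stay identical to the plain "path" export. Below is a hedged sketch of how a consumer of the finished async job might recover the URL, mirroring the assertion in test_start_export_data further down; `result` is an assumed name for the deserialized job result:

```python
# Assumption: `result` holds the completed async-job payload of a
# "download_link" export, serializable into the PresignedLink schema.
presigned = PresignedLink.model_validate(result)
download_url = presigned.link  # time-limited, signed URL to the exported .zip
```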
@@ -7,23 +7,33 @@
from models_library.api_schemas_storage.storage_schemas import (
FileUploadCompletionBody,
FoldersBody,
PresignedLink,
)
from servicelib.logging_utils import log_context

from ...models import FileMetaData
from ._files import complete_upload_file
from ._paths import compute_path_size, delete_paths
from ._simcore_s3 import deep_copy_files_from_project, export_data
from ._simcore_s3 import (
deep_copy_files_from_project,
export_data,
export_data_as_download_link,
)

_logger = logging.getLogger(__name__)


def setup_worker_tasks(app: Celery) -> None:
register_celery_types()
register_pydantic_types(FileUploadCompletionBody, FileMetaData, FoldersBody)
register_pydantic_types(
FileUploadCompletionBody, FileMetaData, FoldersBody, PresignedLink
)

with log_context(_logger, logging.INFO, msg="worker task registration"):
register_task(app, export_data, dont_autoretry_for=(AccessRightError,))
register_task(
app, export_data_as_download_link, dont_autoretry_for=(AccessRightError,)
)
register_task(app, compute_path_size)
register_task(app, complete_upload_file)
register_task(app, delete_paths)
@@ -1,3 +1,5 @@
from typing import Literal

from models_library.api_schemas_rpc_async_jobs.async_jobs import (
AsyncJobFilter,
AsyncJobGet,
@@ -8,7 +10,11 @@
from servicelib.celery.task_manager import TaskManager
from servicelib.rabbitmq import RPCRouter

from .._worker_tasks._simcore_s3 import deep_copy_files_from_project, export_data
from .._worker_tasks._simcore_s3 import (
deep_copy_files_from_project,
export_data,
export_data_as_download_link,
)

router = RPCRouter()

@@ -38,8 +44,14 @@ async def start_export_data(
task_manager: TaskManager,
job_filter: AsyncJobFilter,
paths_to_export: list[PathToExport],
export_as: Literal["path", "download_link"],
) -> AsyncJobGet:
task_name = export_data.__name__
if export_as == "path":
task_name = export_data.__name__
elif export_as == "download_link":
task_name = export_data_as_download_link.__name__
else:
raise ValueError(f"Invalid export_as value: {export_as}")
task_filter = TaskFilter.model_validate(job_filter.model_dump())
task_uuid = await task_manager.submit_task(
task_metadata=TaskMetadata(
62 changes: 56 additions & 6 deletions services/storage/tests/unit/test_rpc_handlers_simcore_s3.py
@@ -13,7 +13,7 @@
from collections.abc import Awaitable, Callable
from copy import deepcopy
from pathlib import Path
from typing import Any
from typing import Any, Literal
from unittest.mock import Mock

import httpx
@@ -30,6 +30,7 @@
from models_library.api_schemas_storage.storage_schemas import (
FileMetaDataGet,
FoldersBody,
PresignedLink,
)
from models_library.api_schemas_webserver.storage import PathToExport
from models_library.basic_types import SHA256Str
@@ -52,6 +53,7 @@
from pytest_simcore.helpers.storage_utils_project import clone_project_data
from servicelib.aiohttp import status
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
from servicelib.rabbitmq._errors import RPCServerError
from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import wait_and_get_result
from servicelib.rabbitmq.rpc_interfaces.storage.simcore_s3 import (
copy_folders_from_project,
@@ -514,9 +516,10 @@ async def _request_start_export_data(
user_id: UserID,
product_name: ProductName,
paths_to_export: list[PathToExport],
export_as: Literal["path", "download_link"],
*,
client_timeout: datetime.timedelta = datetime.timedelta(seconds=60),
) -> dict[str, Any]:
) -> str:
with log_context(
logging.INFO,
f"Data export from {paths_to_export=}",
@@ -526,6 +529,7 @@ async def _request_start_export_data(
user_id=user_id,
product_name=product_name,
paths_to_export=paths_to_export,
export_as=export_as,
)

async for async_job_result in wait_and_get_result(
@@ -572,6 +576,10 @@ def task_progress_spy(mocker: MockerFixture) -> Mock:
],
ids=str,
)
@pytest.mark.parametrize(
"export_as",
["path", "download_link"],
)
async def test_start_export_data(
initialized_app: FastAPI,
short_dsm_cleaner_interval: int,
@@ -589,6 +597,7 @@ async def test_start_export_data(
],
project_params: ProjectWithFilesParams,
task_progress_spy: Mock,
export_as: Literal["path", "download_link"],
):
_, src_projects_list = await random_project_with_files(project_params)

@@ -606,18 +615,32 @@
user_id,
product_name,
paths_to_export=list(nodes_in_project_to_export),
export_as=export_as,
)

assert re.fullmatch(
rf"^exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip$",
result,
)
if export_as == "path":
assert re.fullmatch(
rf"^exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip$",
result,
)
elif export_as == "download_link":
link = PresignedLink.model_validate(result).link
assert re.search(
rf"exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip",
f"{link}",
)
else:
pytest.fail(f"Unexpected export_as value: {export_as}")

progress_updates = [x[0][2].actual_value for x in task_progress_spy.call_args_list]
assert progress_updates[0] == 0
assert progress_updates[-1] == 1


@pytest.mark.parametrize(
"export_as",
["path", "download_link"],
)
async def test_start_export_data_access_error(
initialized_app: FastAPI,
short_dsm_cleaner_interval: int,
@@ -626,6 +649,7 @@ async def test_start_export_data_access_error(
user_id: UserID,
product_name: ProductName,
faker: Faker,
export_as: Literal["path", "download_link"],
):
path_to_export = TypeAdapter(PathToExport).validate_python(
f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}"
@@ -637,9 +661,35 @@
product_name,
paths_to_export=[path_to_export],
client_timeout=datetime.timedelta(seconds=60),
export_as=export_as,
)

assert isinstance(exc.value, JobError)
assert exc.value.exc_type == "AccessRightError"
assert f" {user_id} " in f"{exc.value}"
assert f" {path_to_export} " in f"{exc.value}"


async def test_start_export_invalid_export_format(
initialized_app: FastAPI,
short_dsm_cleaner_interval: int,
with_storage_celery_worker: TestWorkController,
storage_rabbitmq_rpc_client: RabbitMQRPCClient,
user_id: UserID,
product_name: ProductName,
faker: Faker,
):
path_to_export = TypeAdapter(PathToExport).validate_python(
f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}"
)
with pytest.raises(RPCServerError) as exc:
await _request_start_export_data(
storage_rabbitmq_rpc_client,
user_id,
product_name,
paths_to_export=[path_to_export],
client_timeout=datetime.timedelta(seconds=60),
export_as="invalid_format", # type: ignore
)

assert exc.value.exc_type == "builtins.ValueError"
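
Because the RPC handler validates export_as before any Celery task is submitted (see the if/elif/else in the server-side start_export_data above), an unsupported value is expected to surface to the caller as an RPCServerError wrapping builtins.ValueError rather than as a JobError from a failed task.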
@@ -498,6 +498,7 @@ def allow_only_simcore(cls, v: int) -> int:
user_id=_req_ctx.user_id,
product_name=_req_ctx.product_name,
paths_to_export=export_data_post.paths,
export_as="path",
)
_job_id = f"{async_job_rpc_get.job_id}"
return create_data_response(
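
Note that the REST-facing webserver handler keeps export_as="path" hard-coded, so the existing data-export endpoint behaviour is unchanged; within this changeset the "download_link" variant appears to be reachable only through the RPC interface exercised in the tests above.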