@@ -1,3 +1,5 @@
+from typing import Literal
+
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
     AsyncJobFilter,
     AsyncJobGet,
@@ -41,6 +43,7 @@ async def start_export_data(
     user_id: UserID,
     product_name: ProductName,
     paths_to_export: list[PathToExport],
+    export_as: Literal["path", "download_link"],
 ) -> tuple[AsyncJobGet, AsyncJobFilter]:
     job_filter = get_async_job_filter(user_id=user_id, product_name=product_name)
     async_job_rpc_get = await submit(
@@ -49,5 +52,6 @@ async def start_export_data(
         method_name=TypeAdapter(RPCMethodName).validate_python("start_export_data"),
         job_filter=job_filter,
         paths_to_export=paths_to_export,
+        export_as=export_as,
     )
     return async_job_rpc_get, job_filter
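Usage note (not part of the diff): the RPC client helper above now forwards the new export_as flag to storage. A minimal sketch of a caller, assuming the helper lives in servicelib.rabbitmq.rpc_interfaces.storage.simcore_s3 (as the test imports further down suggest), that the connected RabbitMQRPCClient is passed as the first positional argument, and that the example path value is purely illustrative:

from models_library.api_schemas_webserver.storage import PathToExport
from pydantic import TypeAdapter
from servicelib.rabbitmq.rpc_interfaces.storage.simcore_s3 import start_export_data


async def request_download_link_export(rpc_client, user_id, product_name):
    # illustrative path; real callers pass paths taken from the project's file tree
    paths = [TypeAdapter(PathToExport).validate_python("project-uuid/node-uuid/file.txt")]
    async_job_get, job_filter = await start_export_data(
        rpc_client,
        user_id=user_id,
        product_name=product_name,
        paths_to_export=paths,
        export_as="download_link",  # or "path" to receive the S3 object key instead
    )
    # the returned job_id can then be polled via the async-jobs RPC interface
    return async_job_get.job_id, job_filter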
@@ -5,7 +5,7 @@
 from aws_library.s3._models import S3ObjectKey
 from celery import Task  # type: ignore[import-untyped]
 from celery_library.utils import get_app_server
-from models_library.api_schemas_storage.storage_schemas import FoldersBody
+from models_library.api_schemas_storage.storage_schemas import FoldersBody, LinkType
 from models_library.api_schemas_webserver.storage import PathToExport
 from models_library.progress_bar import ProgressReport
 from models_library.projects_nodes_io import StorageFileID
@@ -100,3 +100,27 @@ async def _progress_cb(report: ProgressReport) -> None:
         return await dsm.create_s3_export(
             user_id, object_keys, progress_bar=progress_bar
         )
+
+
+async def export_data_as_download_link(
+    task: Task,
+    task_id: TaskID,
+    *,
+    user_id: UserID,
+    paths_to_export: list[PathToExport],
+) -> str:
+    """
+    AccessRightError: in case user can't access project
+    """
+    s3_object = await export_data(
+        task=task, task_id=task_id, user_id=user_id, paths_to_export=paths_to_export
+    )
+
+    dsm = get_dsm_provider(get_app_server(task.app).app).get(
+        SimcoreS3DataManager.get_location_id()
+    )
+
+    download_link = await dsm.create_file_download_link(
+        user_id=user_id, file_id=s3_object, link_type=LinkType.PRESIGNED
+    )
+    return f"{download_link}"
@@ -13,7 +13,11 @@
 from ...models import FileMetaData
 from ._files import complete_upload_file
 from ._paths import compute_path_size, delete_paths
-from ._simcore_s3 import deep_copy_files_from_project, export_data
+from ._simcore_s3 import (
+    deep_copy_files_from_project,
+    export_data,
+    export_data_as_download_link,
+)
 
 _logger = logging.getLogger(__name__)
 
@@ -24,6 +28,9 @@ def setup_worker_tasks(app: Celery) -> None:
 
     with log_context(_logger, logging.INFO, msg="worker task registration"):
         register_task(app, export_data, dont_autoretry_for=(AccessRightError,))
+        register_task(
+            app, export_data_as_download_link, dont_autoretry_for=(AccessRightError,)
+        )
         register_task(app, compute_path_size)
         register_task(app, complete_upload_file)
         register_task(app, delete_paths)
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from models_library.api_schemas_rpc_async_jobs.async_jobs import (
     AsyncJobFilter,
     AsyncJobGet,
@@ -8,7 +10,11 @@
 from servicelib.celery.task_manager import TaskManager
 from servicelib.rabbitmq import RPCRouter
 
-from .._worker_tasks._simcore_s3 import deep_copy_files_from_project, export_data
+from .._worker_tasks._simcore_s3 import (
+    deep_copy_files_from_project,
+    export_data,
+    export_data_as_download_link,
+)
 
 router = RPCRouter()
 
@@ -38,8 +44,14 @@ async def start_export_data(
     task_manager: TaskManager,
     job_filter: AsyncJobFilter,
     paths_to_export: list[PathToExport],
+    export_as: Literal["path", "download_link"],
 ) -> AsyncJobGet:
-    task_name = export_data.__name__
+    if export_as == "path":
+        task_name = export_data.__name__
+    elif export_as == "download_link":
+        task_name = export_data_as_download_link.__name__
+    else:
+        raise ValueError(f"Invalid export_as value: {export_as}")
     task_filter = TaskFilter.model_validate(job_filter.model_dump())
     task_uuid = await task_manager.submit_task(
         task_metadata=TaskMetadata(
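A design note on the branch above: the ValueError raised for an unknown export_as value propagates back to the RPC caller (the new test below expects an RPCServerError wrapping builtins.ValueError), so the same dispatch could equivalently be written as a lookup table. A sketch under that assumption; _EXPORT_TASKS and _resolve_task_name are hypothetical names, not part of the diff:

# hypothetical helper equivalent to the if/elif dispatch in start_export_data
_EXPORT_TASKS = {
    "path": export_data.__name__,
    "download_link": export_data_as_download_link.__name__,
}


def _resolve_task_name(export_as: str) -> str:
    try:
        return _EXPORT_TASKS[export_as]
    except KeyError as exc:
        raise ValueError(f"Invalid export_as value: {export_as}") from exc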
services/storage/tests/unit/test_rpc_handlers_simcore_s3.py (56 additions, 7 deletions)
@@ -13,7 +13,7 @@
 from collections.abc import Awaitable, Callable
 from copy import deepcopy
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 from unittest.mock import Mock
 
 import httpx
@@ -36,7 +36,7 @@
 from models_library.products import ProductName
 from models_library.projects_nodes_io import NodeID, NodeIDStr, SimcoreS3FileID
 from models_library.users import UserID
-from pydantic import ByteSize, TypeAdapter
+from pydantic import ByteSize, HttpUrl, TypeAdapter
 from pytest_mock import MockerFixture
 from pytest_simcore.helpers.fastapi import url_from_operation_id
 from pytest_simcore.helpers.httpx_assert_checks import assert_status
@@ -52,6 +52,7 @@
 from pytest_simcore.helpers.storage_utils_project import clone_project_data
 from servicelib.aiohttp import status
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
+from servicelib.rabbitmq._errors import RPCServerError
 from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import wait_and_get_result
 from servicelib.rabbitmq.rpc_interfaces.storage.simcore_s3 import (
     copy_folders_from_project,
@@ -514,9 +515,10 @@ async def _request_start_export_data(
     user_id: UserID,
     product_name: ProductName,
     paths_to_export: list[PathToExport],
+    export_as: Literal["path", "download_link"],
     *,
     client_timeout: datetime.timedelta = datetime.timedelta(seconds=60),
-) -> dict[str, Any]:
+) -> str:
     with log_context(
         logging.INFO,
         f"Data export form {paths_to_export=}",
@@ -526,6 +528,7 @@ async def _request_start_export_data(
             user_id=user_id,
             product_name=product_name,
             paths_to_export=paths_to_export,
+            export_as=export_as,
         )
 
         async for async_job_result in wait_and_get_result(
@@ -572,6 +575,10 @@ def task_progress_spy(mocker: MockerFixture) -> Mock:
     ],
     ids=str,
 )
+@pytest.mark.parametrize(
+    "export_as",
+    ["path", "download_link"],
+)
 async def test_start_export_data(
     initialized_app: FastAPI,
     short_dsm_cleaner_interval: int,
@@ -589,6 +596,7 @@ async def test_start_export_data(
     ],
     project_params: ProjectWithFilesParams,
     task_progress_spy: Mock,
+    export_as: Literal["path", "download_link"],
 ):
     _, src_projects_list = await random_project_with_files(project_params)
 
@@ -606,18 +614,32 @@
         user_id,
         product_name,
         paths_to_export=list(nodes_in_project_to_export),
+        export_as=export_as,
     )
 
-    assert re.fullmatch(
-        rf"^exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip$",
-        result,
-    )
+    if export_as == "path":
+        assert re.fullmatch(
+            rf"^exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip$",
+            result,
+        )
+    elif export_as == "download_link":
+        _ = HttpUrl(result)
+        assert re.search(
+            rf"exports/{user_id}/[0-9a-fA-F]{{8}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{4}}-[0-9a-fA-F]{{12}}\.zip",
+            result,
+        )
+    else:
+        pytest.fail(f"Unexpected export_as value: {export_as}")
 
     progress_updates = [x[0][2].actual_value for x in task_progress_spy.call_args_list]
     assert progress_updates[0] == 0
     assert progress_updates[-1] == 1
 
 
+@pytest.mark.parametrize(
+    "export_as",
+    ["path", "download_link"],
+)
 async def test_start_export_data_access_error(
     initialized_app: FastAPI,
     short_dsm_cleaner_interval: int,
@@ -626,6 +648,7 @@ async def test_start_export_data_access_error(
     user_id: UserID,
     product_name: ProductName,
     faker: Faker,
+    export_as: Literal["path", "download_link"],
 ):
     path_to_export = TypeAdapter(PathToExport).validate_python(
         f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}"
@@ -637,9 +660,35 @@
             product_name,
             paths_to_export=[path_to_export],
             client_timeout=datetime.timedelta(seconds=60),
+            export_as=export_as,
         )
 
     assert isinstance(exc.value, JobError)
     assert exc.value.exc_type == "AccessRightError"
     assert f" {user_id} " in f"{exc.value}"
     assert f" {path_to_export} " in f"{exc.value}"
+
+
+async def test_start_export_invalid_export_format(
+    initialized_app: FastAPI,
+    short_dsm_cleaner_interval: int,
+    with_storage_celery_worker: TestWorkController,
+    storage_rabbitmq_rpc_client: RabbitMQRPCClient,
+    user_id: UserID,
+    product_name: ProductName,
+    faker: Faker,
+):
+    path_to_export = TypeAdapter(PathToExport).validate_python(
+        f"{faker.uuid4()}/{faker.uuid4()}/{faker.file_name()}"
+    )
+    with pytest.raises(RPCServerError) as exc:
+        await _request_start_export_data(
+            storage_rabbitmq_rpc_client,
+            user_id,
+            product_name,
+            paths_to_export=[path_to_export],
+            client_timeout=datetime.timedelta(seconds=60),
+            export_as="invalid_format",  # type: ignore
+        )
+
+    assert exc.value.exc_type == "builtins.ValueError"
@@ -498,6 +498,7 @@ def allow_only_simcore(cls, v: int) -> int:
         user_id=_req_ctx.user_id,
         product_name=_req_ctx.product_name,
         paths_to_export=export_data_post.paths,
+        export_as="path",
     )
     _job_id = f"{async_job_rpc_get.job_id}"
     return create_data_response(