Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime
from typing import Annotated, Any, NamedTuple
from typing import Any, NamedTuple

from models_library.services_types import ServiceRunID
from pydantic import (
AnyUrl,
BaseModel,
BeforeValidator,
ConfigDict,
PositiveInt,
)
Expand Down Expand Up @@ -62,20 +63,16 @@ class ComputationRunRpcGetPage(NamedTuple):
total: PositiveInt


def _none_to_zero_float_pre_validator(value: Any):
if value is None:
return 0.0
return value


class ComputationTaskRpcGet(BaseModel):
project_uuid: ProjectID
node_id: NodeID
state: RunningState
progress: Annotated[float, BeforeValidator(_none_to_zero_float_pre_validator)]
progress: float
image: dict[str, Any]
started_at: datetime | None
ended_at: datetime | None
log_download_link: AnyUrl | None
service_run_id: ServiceRunID

model_config = ConfigDict(
json_schema_extra={
Expand All @@ -92,6 +89,8 @@ class ComputationTaskRpcGet(BaseModel):
},
"started_at": "2023-01-11 13:11:47.293595",
"ended_at": "2023-01-11 13:11:47.293595",
"log_download_link": "https://example.com/logs",
"service_run_id": "comp_1_12e0c8b2-bad6-40fb-9948-8dec4f65d4d9_1",
}
]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import datetime
from decimal import Decimal
from typing import Annotated, Any

from common_library.basic_types import DEFAULT_FACTORY
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
Expand Down Expand Up @@ -123,3 +125,6 @@ class ComputationTaskRestGet(OutputSchema):
image: dict[str, Any]
started_at: datetime | None
ended_at: datetime | None
log_download_link: AnyUrl | None
node_name: str
osparc_credits: Decimal | None
24 changes: 24 additions & 0 deletions packages/models-library/src/models_library/computations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from datetime import datetime
from decimal import Decimal
from typing import Any

from pydantic import AnyUrl, BaseModel

from .projects import ProjectID
from .projects_nodes_io import NodeID
from .projects_state import RunningState


class ComputationTaskWithAttributes(BaseModel):
project_uuid: ProjectID
node_id: NodeID
state: RunningState
progress: float
image: dict[str, Any]
started_at: datetime | None
ended_at: datetime | None
log_download_link: AnyUrl | None

# Attributes added by the webserver
node_name: str
osparc_credits: Decimal | None
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from decimal import Decimal
from typing import Final

from models_library.api_schemas_resource_usage_tracker import (
Expand All @@ -12,6 +13,7 @@
from models_library.projects import ProjectID
from models_library.rabbitmq_basic_types import RPCMethodName
from models_library.resource_tracker import CreditTransactionStatus
from models_library.services_types import ServiceRunID
from models_library.wallets import WalletID
from pydantic import NonNegativeInt, TypeAdapter

Expand Down Expand Up @@ -82,3 +84,21 @@ async def pay_project_debt(
new_wallet_transaction=new_wallet_transaction,
timeout_s=_DEFAULT_TIMEOUT_S,
)


@log_decorator(_logger, level=logging.DEBUG)
async def get_transaction_current_credits_by_service_run_id(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
service_run_id: ServiceRunID,
) -> Decimal:
result = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python(
"get_transaction_current_credits_by_service_run_id"
),
service_run_id=service_run_id,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, Decimal) # nosec
return result
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from common_library.errors_classes import OsparcErrorMixin


class LicensesBaseError(OsparcErrorMixin, Exception):
...
class LicensesBaseError(OsparcErrorMixin, Exception): ...


class NotEnoughAvailableSeatsError(LicensesBaseError):
Expand Down Expand Up @@ -36,11 +35,13 @@ class WalletTransactionError(OsparcErrorMixin, Exception):
msg_template = "{msg}"


class CreditTransactionNotFoundError(OsparcErrorMixin, Exception): ...


### Pricing Plans Error


class PricingPlanBaseError(OsparcErrorMixin, Exception):
...
class PricingPlanBaseError(OsparcErrorMixin, Exception): ...


class PricingUnitDuplicationError(PricingPlanBaseError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@
from models_library.projects_nodes_io import NodeID
from models_library.users import UserID
from servicelib.utils import logged_gather
from simcore_sdk.node_ports_common.exceptions import NodeportsException
from simcore_sdk.node_ports_v2 import FileLinkType
from starlette import status

from ...models.comp_pipelines import CompPipelineAtDB
from ...models.comp_tasks import CompTaskAtDB
from ...modules.db.repositories.comp_pipelines import CompPipelinesRepository
from ...modules.db.repositories.comp_tasks import CompTasksRepository
from ...utils.dask import get_service_log_file_download_link
from ...utils import dask as dask_utils
from ..dependencies.database import get_repository

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,31 +79,6 @@ async def analyze_pipeline(
return PipelineInfo(pipeline_dag, all_tasks, filtered_tasks)


async def _get_task_log_file(
user_id: UserID, project_id: ProjectID, node_id: NodeID
) -> TaskLogFileGet:
try:
log_file_url = await get_service_log_file_download_link(
user_id, project_id, node_id, file_link_type=FileLinkType.PRESIGNED
)

except NodeportsException as err:
# Unexpected error: Cannot determine the cause of failure
# to get donwload link and cannot handle it automatically.
# Will treat it as "not available" and log a warning
log_file_url = None
log.warning(
"Failed to get log-file of %s: %s.",
f"{user_id=}/{project_id=}/{node_id=}",
err,
)

return TaskLogFileGet(
task_id=node_id,
download_link=log_file_url,
)


# ROUTES HANDLERS --------------------------------------------------------------


Expand Down Expand Up @@ -133,7 +106,7 @@ async def get_all_tasks_log_files(

tasks_logs_files: list[TaskLogFileGet] = await logged_gather(
*[
_get_task_log_file(user_id, project_id, node_id)
dask_utils.get_task_log_file(user_id, project_id, node_id)
for node_id in iter_task_ids
],
reraise=True,
Expand Down Expand Up @@ -165,7 +138,7 @@ async def get_task_log_file(
detail=[f"No task_id={node_uuid} found under computation {project_id}"],
)

return await _get_task_log_file(user_id, project_id, node_uuid)
return await dask_utils.get_task_log_file(user_id, project_id, node_uuid)


@router.post(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import asyncio

# pylint: disable=too-many-arguments
from fastapi import FastAPI
from models_library.api_schemas_directorv2.comp_runs import (
ComputationRunRpcGetPage,
ComputationTaskRpcGet,
ComputationTaskRpcGetPage,
)
from models_library.api_schemas_directorv2.computations import TaskLogFileGet
from models_library.products import ProductName
from models_library.projects import ProjectID
from models_library.rest_ordering import OrderBy
from models_library.services_types import ServiceRunID
from models_library.users import UserID
from servicelib.rabbitmq import RPCRouter
from simcore_postgres_database.models import comp_runs
from simcore_service_director_v2.models.comp_tasks import ComputationTaskForRpcDBGet

from ...modules.db.repositories.comp_runs import CompRunsRepository
from ...modules.db.repositories.comp_tasks import CompTasksRepository
from ...utils import dask as dask_utils

router = RPCRouter()

Expand Down Expand Up @@ -42,6 +50,18 @@ async def list_computations_latest_iteration_page(
)


async def _fetch_task_log(
user_id: UserID, project_id: ProjectID, task: ComputationTaskForRpcDBGet
) -> TaskLogFileGet | None:
if not task.state.is_running():
return await dask_utils.get_task_log_file(
user_id=user_id,
project_id=project_id,
node_id=task.node_id,
)
return None


@router.expose(reraise_if_error_type=())
async def list_computations_latest_iteration_tasks_page(
app: FastAPI,
Expand All @@ -59,13 +79,42 @@ async def list_computations_latest_iteration_tasks_page(
assert user_id # nosec NOTE: Whether user_id has access to the project was checked in the webserver

comp_tasks_repo = CompTasksRepository.instance(db_engine=app.state.engine)
total, comp_runs = await comp_tasks_repo.list_computational_tasks_rpc_domain(
comp_runs_repo = CompRunsRepository.instance(db_engine=app.state.engine)

comp_latest_run = await comp_runs_repo.get(
user_id=user_id, project_id=project_id, iteration=None # Returns last iteration
)

total, comp_tasks = await comp_tasks_repo.list_computational_tasks_rpc_domain(
project_id=project_id,
offset=offset,
limit=limit,
order_by=order_by,
)

# Run all log fetches concurrently
log_files = await asyncio.gather(
*(_fetch_task_log(user_id, project_id, task) for task in comp_tasks)
)

comp_tasks_output = [
ComputationTaskRpcGet(
project_uuid=task.project_uuid,
node_id=task.node_id,
state=task.state,
progress=task.progress,
image=task.image,
started_at=task.started_at,
ended_at=task.ended_at,
log_download_link=log_file.download_link if log_file else None,
service_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
user_id, project_id, task.node_id, comp_latest_run.iteration
),
)
for task, log_file in zip(comp_tasks, log_files, strict=True)
]

return ComputationTaskRpcGetPage(
items=comp_runs,
items=comp_tasks_output,
total=total,
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from models_library.services_resources import BootMode
from pydantic import (
BaseModel,
BeforeValidator,
ByteSize,
ConfigDict,
Field,
Expand Down Expand Up @@ -257,3 +258,19 @@ def to_db_model(self, **exclusion_rules) -> dict[str, Any]:
]
},
)


def _none_to_zero_float_pre_validator(value: Any):
if value is None:
return 0.0
return value


class ComputationTaskForRpcDBGet(BaseModel):
project_uuid: ProjectID
node_id: NodeID
state: RunningState
progress: Annotated[float, BeforeValidator(_none_to_zero_float_pre_validator)]
image: dict[str, Any]
started_at: dt.datetime | None
ended_at: dt.datetime | None
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import arrow
import sqlalchemy as sa
from models_library.api_schemas_directorv2.comp_runs import ComputationTaskRpcGet
from models_library.basic_types import IDStr
from models_library.errors import ErrorDict
from models_library.projects import ProjectAtDB, ProjectID
Expand All @@ -20,7 +19,7 @@
from sqlalchemy.dialects.postgresql import insert

from .....core.errors import ComputationalTaskNotFoundError
from .....models.comp_tasks import CompTaskAtDB
from .....models.comp_tasks import CompTaskAtDB, ComputationTaskForRpcDBGet
from .....modules.resource_usage_tracker_client import ResourceUsageTrackerClient
from .....utils.computations import to_node_class
from .....utils.db import DB_TO_RUNNING_STATE, RUNNING_STATE_TO_DB
Expand Down Expand Up @@ -85,7 +84,7 @@ async def list_computational_tasks_rpc_domain(
limit: int = 20,
# ordering
order_by: OrderBy | None = None,
) -> tuple[int, list[ComputationTaskRpcGet]]:
) -> tuple[int, list[ComputationTaskForRpcDBGet]]:
if order_by is None:
order_by = OrderBy(field=IDStr("task_id")) # default ordering

Expand Down Expand Up @@ -126,7 +125,7 @@ async def list_computational_tasks_rpc_domain(
total_count = await conn.scalar(count_query)

items = [
ComputationTaskRpcGet.model_validate(
ComputationTaskForRpcDBGet.model_validate(
{
**row,
"state": DB_TO_RUNNING_STATE[row["state"]], # Convert the state
Expand Down
Loading
Loading