Skip to content
Merged
12 changes: 12 additions & 0 deletions api/specs/web-server/_computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from models_library.api_schemas_webserver.computations import (
ComputationGet,
ComputationPathParams,
ComputationRunPathParams,
ComputationRunRestGet,
ComputationRunWithFiltersListQueryParams,
ComputationStart,
ComputationStarted,
ComputationTaskRestGet,
Expand Down Expand Up @@ -68,7 +70,17 @@ async def stop_computation(_path: Annotated[ComputationPathParams, Depends()]):
response_model=Envelope[list[ComputationRunRestGet]],
)
async def list_computations_latest_iteration(
_query: Annotated[as_query(ComputationRunWithFiltersListQueryParams), Depends()],
): ...


@router.get(
"/computations/{project_id}/iterations",
response_model=Envelope[list[ComputationRunRestGet]],
)
async def list_computation_iterations(
_query: Annotated[as_query(ComputationRunListQueryParams), Depends()],
_path: Annotated[ComputationRunPathParams, Depends()],
): ...


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ class ComputationStarted(OutputSchemaWithoutCamelCase):
class ComputationRunListQueryParams(
PageQueryParameters,
ComputationRunListOrderParams, # type: ignore[misc, valid-type]
):
): ...


class ComputationRunWithFiltersListQueryParams(ComputationRunListQueryParams):
filter_only_running: bool = Field(
default=False,
description="If true, only running computations are returned",
Expand All @@ -100,6 +103,11 @@ class ComputationRunRestGet(OutputSchema):
project_custom_metadata: dict[str, Any]


class ComputationRunPathParams(BaseModel):
project_id: ProjectID
model_config = ConfigDict(populate_by_name=True, extra="forbid")


### Computation Task


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ async def list_computations_latest_iteration_page(
return result


@log_decorator(_logger, level=logging.DEBUG)
async def list_computations_iterations_page(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
product_name: ProductName,
user_id: UserID,
project_id: ProjectID,
# pagination
offset: int = 0,
limit: int = 20,
# ordering
order_by: OrderBy | None = None,
) -> ComputationRunRpcGetPage:
result = await rabbitmq_rpc_client.request(
DIRECTOR_V2_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("list_computations_iterations_page"),
product_name=product_name,
user_id=user_id,
project_id=project_id,
offset=offset,
limit=limit,
order_by=order_by,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, ComputationRunRpcGetPage) # nosec
return result


@log_decorator(_logger, level=logging.DEBUG)
async def list_computations_latest_iteration_tasks_page(
rabbitmq_rpc_client: RabbitMQRPCClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,36 @@ async def list_computations_latest_iteration_page(
)


@router.expose(reraise_if_error_type=())
async def list_computations_iterations_page(
app: FastAPI,
*,
product_name: ProductName,
user_id: UserID,
project_id: ProjectID,
# pagination
offset: int = 0,
limit: int = 20,
# ordering
order_by: OrderBy | None = None,
) -> ComputationRunRpcGetPage:
comp_runs_repo = CompRunsRepository.instance(db_engine=app.state.engine)
total, comp_runs_output = (
await comp_runs_repo.list_for_user_and_project_all_iterations(
product_name=product_name,
user_id=user_id,
project_id=project_id,
offset=offset,
limit=limit,
order_by=order_by,
)
)
return ComputationRunRpcGetPage(
items=comp_runs_output,
total=total,
)


async def _fetch_task_log(
user_id: UserID, project_id: ProjectID, task: ComputationTaskForRpcDBGet
) -> TaskLogFileGet | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ async def list_(
)
]

_COMPUTATION_RUNS_RPC_GET_COLUMNS = [ # noqa: RUF012
comp_runs.c.project_uuid,
comp_runs.c.iteration,
comp_runs.c.result.label("state"),
comp_runs.c.metadata.label("info"),
comp_runs.c.created.label("submitted_at"),
comp_runs.c.started.label("started_at"),
comp_runs.c.ended.label("ended_at"),
]

async def list_for_user__only_latest_iterations(
self,
*,
Expand All @@ -212,13 +222,7 @@ async def list_for_user__only_latest_iterations(
order_by = OrderBy(field=IDStr("run_id")) # default ordering

base_select_query = sa.select(
comp_runs.c.project_uuid,
comp_runs.c.iteration,
comp_runs.c.result.label("state"),
comp_runs.c.metadata.label("info"),
comp_runs.c.created.label("submitted_at"),
comp_runs.c.started.label("started_at"),
comp_runs.c.ended.label("ended_at"),
*self._COMPUTATION_RUNS_RPC_GET_COLUMNS
).select_from(
sa.select(
comp_runs.c.project_uuid,
Expand Down Expand Up @@ -286,6 +290,62 @@ async def list_for_user__only_latest_iterations(

return cast(int, total_count), items

async def list_for_user_and_project_all_iterations(
self,
*,
product_name: str,
user_id: UserID,
project_id: ProjectID,
# pagination
offset: int,
limit: int,
# ordering
order_by: OrderBy | None = None,
) -> tuple[int, list[ComputationRunRpcGet]]:
if order_by is None:
order_by = OrderBy(field=IDStr("run_id")) # default ordering

base_select_query = sa.select(
*self._COMPUTATION_RUNS_RPC_GET_COLUMNS,
).where(
(comp_runs.c.user_id == user_id)
& (comp_runs.c.project_uuid == f"{project_id}")
& (
comp_runs.c.metadata["product_name"].astext == product_name
) # <-- NOTE: We might create a separate column for this for fast retrieval
)

# Select total count from base_query
count_query = sa.select(sa.func.count()).select_from(
base_select_query.subquery()
)

# Ordering and pagination
if order_by.direction == OrderDirection.ASC:
list_query = base_select_query.order_by(
sa.asc(getattr(comp_runs.c, order_by.field)), comp_runs.c.run_id
)
else:
list_query = base_select_query.order_by(
desc(getattr(comp_runs.c, order_by.field)), comp_runs.c.run_id
)
list_query = list_query.offset(offset).limit(limit)

async with pass_or_acquire_connection(self.db_engine) as conn:
total_count = await conn.scalar(count_query)

items = [
ComputationRunRpcGet.model_validate(
{
**row,
"state": DB_TO_RUNNING_STATE[row["state"]],
}
)
async for row in await conn.stream(list_query)
]

return cast(int, total_count), items

async def create(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pylint: disable=too-many-positional-arguments

from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from datetime import UTC, datetime, timedelta
from typing import Any

from models_library.api_schemas_directorv2.comp_runs import (
Expand Down Expand Up @@ -66,7 +66,7 @@ async def test_rpc_list_computation_runs_and_tasks(
user=user,
project=proj,
result=RunningState.PENDING,
started=datetime.now(tz=timezone.utc),
started=datetime.now(tz=UTC),
iteration=2,
)
output = await rpc_computations.list_computations_latest_iteration_page(
Expand All @@ -82,8 +82,8 @@ async def test_rpc_list_computation_runs_and_tasks(
user=user,
project=proj,
result=RunningState.SUCCESS,
started=datetime.now(tz=timezone.utc),
ended=datetime.now(tz=timezone.utc),
started=datetime.now(tz=UTC),
ended=datetime.now(tz=UTC),
iteration=3,
)
output = await rpc_computations.list_computations_latest_iteration_page(
Expand All @@ -103,3 +103,105 @@ async def test_rpc_list_computation_runs_and_tasks(
assert output.total == 4
assert isinstance(output, ComputationTaskRpcGetPage)
assert len(output.items) == 4


async def test_rpc_list_computation_runs_with_filtering(
fake_workbench_without_outputs: dict[str, Any],
fake_workbench_adjacency: dict[str, Any],
registered_user: Callable[..., dict[str, Any]],
project: Callable[..., Awaitable[ProjectAtDB]],
create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]],
create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]],
create_comp_run: Callable[..., Awaitable[CompRunsAtDB]],
rpc_client: RabbitMQRPCClient,
):
user = registered_user()

proj_1 = await project(user, workbench=fake_workbench_without_outputs)
await create_pipeline(
project_id=f"{proj_1.uuid}",
dag_adjacency_list=fake_workbench_adjacency,
)
comp_tasks = await create_tasks(
user=user, project=proj_1, state=StateType.PUBLISHED, progress=None
)
comp_runs = await create_comp_run(
user=user, project=proj_1, result=RunningState.PUBLISHED
)

proj_2 = await project(user, workbench=fake_workbench_without_outputs)
await create_pipeline(
project_id=f"{proj_2.uuid}",
dag_adjacency_list=fake_workbench_adjacency,
)
comp_tasks = await create_tasks(
user=user, project=proj_2, state=StateType.SUCCESS, progress=None
)
comp_runs = await create_comp_run(
user=user, project=proj_2, result=RunningState.SUCCESS
)

# Test default behaviour `filter_only_running=False`
output = await rpc_computations.list_computations_latest_iteration_page(
rpc_client, product_name="osparc", user_id=user["id"]
)
assert output.total == 2

# Test filtering
output = await rpc_computations.list_computations_latest_iteration_page(
rpc_client, product_name="osparc", user_id=user["id"], filter_only_running=True
)
assert output.total == 1
assert output.items[0].project_uuid == proj_1.uuid


async def test_rpc_list_computation_runs_history(
fake_workbench_without_outputs: dict[str, Any],
fake_workbench_adjacency: dict[str, Any],
registered_user: Callable[..., dict[str, Any]],
project: Callable[..., Awaitable[ProjectAtDB]],
create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]],
create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]],
create_comp_run: Callable[..., Awaitable[CompRunsAtDB]],
rpc_client: RabbitMQRPCClient,
):
user = registered_user()

proj = await project(user, workbench=fake_workbench_without_outputs)
await create_pipeline(
project_id=f"{proj.uuid}",
dag_adjacency_list=fake_workbench_adjacency,
)
comp_tasks = await create_tasks(
user=user, project=proj, state=StateType.PUBLISHED, progress=None
)
comp_runs_1 = await create_comp_run(
user=user,
project=proj,
result=RunningState.SUCCESS,
started=datetime.now(tz=UTC) - timedelta(minutes=120),
ended=datetime.now(tz=UTC) - timedelta(minutes=100),
iteration=1,
)
comp_runs_2 = await create_comp_run(
user=user,
project=proj,
result=RunningState.SUCCESS,
started=datetime.now(tz=UTC) - timedelta(minutes=90),
ended=datetime.now(tz=UTC) - timedelta(minutes=60),
iteration=2,
)
comp_runs_3 = await create_comp_run(
user=user,
project=proj,
result=RunningState.FAILED,
started=datetime.now(tz=UTC) - timedelta(minutes=50),
ended=datetime.now(tz=UTC),
iteration=3,
)

output = await rpc_computations.list_computations_iterations_page(
rpc_client, product_name="osparc", user_id=user["id"], project_id=proj.uuid
)
assert output.total == 3
assert isinstance(output, ComputationRunRpcGetPage)
Loading
Loading