diff --git a/.github/badges/coverage.json b/.github/badges/coverage.json index e585df0cd..a834ae554 100644 --- a/.github/badges/coverage.json +++ b/.github/badges/coverage.json @@ -1 +1 @@ -{"schemaVersion":1,"label":"coverage","message":"51.03%","color":"red"} +{"schemaVersion":1,"label":"coverage","message":"50.10%","color":"red"} diff --git a/api/dependencies.py b/api/dependencies.py index 9b1f6c935..91c33aa49 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -8,7 +8,7 @@ from api.infrastructure.model import ModelProviderGateway from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository from api.schemas.core.context import RequestContext -from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase +from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase, GetOneProviderUseCase, GetProvidersUseCase from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase from api.use_cases.models import GetModelsUseCase from api.utils.configuration import configuration @@ -98,3 +98,17 @@ def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(ge provider_repository=PostgresProviderRepository(postgres_session=postgres_session), user_info_repository=_user_info_repository(postgres_session), ) + + +def get_one_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneProviderUseCase: + return GetOneProviderUseCase( + provider_repository=PostgresProviderRepository(postgres_session=postgres_session), + user_info_repository=_user_info_repository(postgres_session), + ) + + +def get_providers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetProvidersUseCase: + return GetProvidersUseCase( + provider_repository=PostgresProviderRepository(postgres_session=postgres_session), + user_info_repository=_user_info_repository(postgres_session), + ) diff --git a/api/domain/__init__.py b/api/domain/__init__.py index e69de29bb..07527bfaa 100644 --- a/api/domain/__init__.py +++ b/api/domain/__init__.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from enum import StrEnum +from typing import Generic, TypeVar + + +class SortField(StrEnum): + ID = "id" + NAME = "name" + CREATED = "created" + + +class SortOrder(StrEnum): + ASC = "asc" + DESC = "desc" + + +T = TypeVar("T") + + +@dataclass +class EntitiesPage(Generic[T]): + total: int + data: list[T] diff --git a/api/domain/provider/_providerrepository.py b/api/domain/provider/_providerrepository.py index 44167ece3..7dbcfdd34 100644 --- a/api/domain/provider/_providerrepository.py +++ b/api/domain/provider/_providerrepository.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from api.domain import SortOrder from api.domain.model.entities import Metric -from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderPage, ProviderSortField, ProviderType from api.domain.provider.errors import ProviderAlreadyExistsError @@ -32,3 +33,21 @@ async def delete_provider( provider_id: int, ) -> Provider | None: pass + + @abstractmethod + async def get_one_provider( + self, + provider_id: int, + ) -> Provider | None: + pass + + @abstractmethod + async def get_providers_page( + self, + router_id: int | None, + limit: int, + offset: int, + sort_by: ProviderSortField = ProviderSortField.ID, + sort_order: SortOrder = SortOrder.ASC, + ) -> ProviderPage: + pass diff --git a/api/domain/provider/entities.py b/api/domain/provider/entities.py index 9ffcbe3c5..93e04e4a5 100644 --- a/api/domain/provider/entities.py +++ b/api/domain/provider/entities.py @@ -4,6 +4,7 @@ import pycountry from pydantic import Field, constr +from api.domain import EntitiesPage from api.domain.model.entities import ModelType from api.schemas import BaseModel from api.schemas.core.models import Metric @@ -57,6 +58,15 @@ class ProviderType(str, Enum): } +class ProviderSortField(str, Enum): + ID = "id" + MODEL_NAME = "model_name" + CREATED = "created" + + +ProviderPage = EntitiesPage["Provider"] + + class Provider(BaseModel): object: Literal["provider"] = "provider" id: int = Field(..., description="Provider ID.") # fmt: off diff --git a/api/domain/router/_routerrepository.py b/api/domain/router/_routerrepository.py index efcd71b84..95342c5ce 100644 --- a/api/domain/router/_routerrepository.py +++ b/api/domain/router/_routerrepository.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from api.domain import SortField, SortOrder from api.domain.model import ModelType as RouterType -from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage, RouterSortField, SortOrder +from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError @@ -19,7 +20,7 @@ async def get_routers_page( self, limit: int, offset: int, - sort_by: RouterSortField = RouterSortField.ID, + sort_by: SortField = SortField.ID, sort_order: SortOrder = SortOrder.ASC, ) -> RouterPage: pass diff --git a/api/domain/router/entities.py b/api/domain/router/entities.py index 125ad1504..ca6cf14bc 100644 --- a/api/domain/router/entities.py +++ b/api/domain/router/entities.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass from enum import StrEnum from pydantic import BaseModel, Field +from api.domain import EntitiesPage from api.domain.model import ModelType as RouterType @@ -11,21 +11,7 @@ class RouterLoadBalancingStrategy(StrEnum): LEAST_BUSY = "least_busy" -class RouterSortField(StrEnum): - ID = "id" - NAME = "name" - CREATED = "created" - - -class SortOrder(StrEnum): - ASC = "asc" - DESC = "desc" - - -@dataclass -class RouterPage: - total: int - data: list["Router"] +RouterPage = EntitiesPage["Router"] class Router(BaseModel): diff --git a/api/endpoints/admin/providers.py b/api/endpoints/admin/providers.py index 17876e111..a53d117cb 100644 --- a/api/endpoints/admin/providers.py +++ b/api/endpoints/admin/providers.py @@ -1,13 +1,11 @@ -from typing import Literal - -from fastapi import Body, Depends, Path, Query, Request, Security -from fastapi.responses import JSONResponse, Response +from fastapi import Body, Depends, Path, Request, Security +from fastapi.responses import Response from sqlalchemy.ext.asyncio import AsyncSession from api.endpoints.admin import router from api.helpers._accesscontroller import AccessController from api.helpers.models import ModelRegistry -from api.schemas.admin.providers import Provider, Providers, UpdateProvider +from api.schemas.admin.providers import UpdateProvider from api.schemas.admin.roles import PermissionType from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute @@ -41,56 +39,3 @@ async def update_provider( ) return Response(status_code=204) - - -@router.get( - path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", - dependencies=[Security(dependency=AccessController(permissions=[PermissionType.PROVIDE_MODELS]))], - status_code=200, - response_model=Provider, -) -async def get_provider( - request: Request, - provider: int = Path(description="The ID of the provider to get."), - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> JSONResponse: - """ - Get a model provider by router and provider IDs. - """ - providers = await model_registry.get_providers(router_id=router, provider_id=provider, postgres_session=postgres_session) - provider = providers[0] - - return JSONResponse(status_code=200, content=provider.model_dump()) - - -@router.get( - path=EndpointRoute.ADMIN_PROVIDERS, - dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))], - status_code=200, - response_model=Providers, -) -async def get_providers( - request: Request, - router: int | None = Query(default=None, description="Filter providers by router ID."), - offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."), - limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."), - order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."), - order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."), - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> JSONResponse: - """ - Get all model providers for a router. - """ - providers = await model_registry.get_providers( - router_id=router, - provider_id=None, - postgres_session=postgres_session, - offset=offset, - limit=limit, - order_by=order_by, - order_direction=order_direction, - ) - - return JSONResponse(status_code=200, content=Providers(data=providers).model_dump()) diff --git a/api/infrastructure/fastapi/endpoints/admin/providers.py b/api/infrastructure/fastapi/endpoints/admin/providers.py index 67eb209e8..0878052a4 100644 --- a/api/infrastructure/fastapi/endpoints/admin/providers.py +++ b/api/infrastructure/fastapi/endpoints/admin/providers.py @@ -1,14 +1,21 @@ from contextvars import ContextVar import logging -from typing import Literal from fastapi import Body, Depends, Path, Query, Request, Security -from fastapi.responses import JSONResponse, Response +from fastapi.responses import Response from sqlalchemy.ext.asyncio import AsyncSession -from api.dependencies import create_provider_use_case_factory, delete_provider_use_case_factory, get_request_context +from api.dependencies import ( + create_provider_use_case_factory, + delete_provider_use_case_factory, + get_one_provider_use_case_factory, + get_providers_use_case_factory, + get_request_context, +) +from api.domain import SortOrder from api.domain.model import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError from api.domain.provider import InvalidProviderTypeError, ProviderNotReachableError +from api.domain.provider.entities import ProviderSortField from api.domain.provider.errors import ProviderAlreadyExistsError, ProviderNotFoundError from api.domain.router.errors import RouterNotFoundError from api.domain.userinfo.errors import UserIsNotAdminError @@ -36,6 +43,12 @@ DeleteProviderCommand, DeleteProviderUseCase, DeleteProviderUseCaseSuccess, + GetOneProviderCommand, + GetOneProviderUseCase, + GetOneProviderUseCaseSuccess, + GetProvidersCommand, + GetProvidersUseCase, + GetProvidersUseCaseSuccess, ) from api.utils.dependencies import get_model_registry, get_postgres_session from api.utils.variables import EndpointRoute @@ -60,7 +73,6 @@ ), ) async def create_provider( - request: Request, body: CreateProvider, create_provider_use_case: CreateProviderUseCase = Depends(create_provider_use_case_factory), request_context: ContextVar[RequestContext] = Depends(get_request_context), @@ -98,10 +110,18 @@ async def create_provider( case CreateProviderUseCaseSuccess(created_provider): return CreateProviderResponse.model_validate(created_provider, from_attributes=True) - case InconsistentModelMaxContextLengthError(expected_max_context_length=expected_max_context_length, actual_max_context_length=actual_max_context_length, router_name=router_name): # fmt: off - raise InconsistentModelMaxContextLengthHTTPException(input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name) # fmt: off - case InconsistentModelVectorSizeError(expected_vector_size=expected_vector_size, actual_vector_size=actual_vector_size, router_name=router_name): # fmt: off - raise InconsistentModelVectorSizeHTTPException(input_vector_size=actual_vector_size, model_vector_size=expected_vector_size, model_name=router_name) # fmt: off + case InconsistentModelMaxContextLengthError(expected_max_context_length=expected_max_context_length, + actual_max_context_length=actual_max_context_length, + router_name=router_name): # fmt: off + raise InconsistentModelMaxContextLengthHTTPException(input_max_context_length=actual_max_context_length, + model_max_context_length=expected_max_context_length, + model_name=router_name) # fmt: off + case InconsistentModelVectorSizeError(expected_vector_size=expected_vector_size, + actual_vector_size=actual_vector_size, + router_name=router_name): # fmt: off + raise InconsistentModelVectorSizeHTTPException(input_vector_size=actual_vector_size, + model_vector_size=expected_vector_size, + model_name=router_name) # fmt: off case InvalidProviderTypeError(provider_type=provider_type, router_type=router_type): raise InvalidProviderTypeHTTPException(incorrect_provider_type=provider_type, router_type=router_type) case ProviderNotReachableError(model_name=name): @@ -118,6 +138,7 @@ async def create_provider( path=EndpointRoute.ADMIN_PROVIDERS + "/{provider_id}", dependencies=[Security(dependency=get_current_key)], status_code=200, + responses=get_documentation_responses([NotAdminUserHTTPException, ProviderNotFoundHTTPException]), ) async def delete_provider( provider_id: int = Path(description="The ID of the provider to delete."), @@ -178,21 +199,36 @@ async def update_provider( @router.get( - path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}", + path=EndpointRoute.ADMIN_PROVIDERS + "/{provider_id}", dependencies=[Security(dependency=get_current_key)], status_code=200, - response_model=Provider, + responses=get_documentation_responses([NotAdminUserHTTPException, ProviderNotFoundHTTPException]), ) async def get_provider( - request: Request, - provider: int = Path(description="The ID of the provider to get."), - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> JSONResponse: - providers = await model_registry.get_providers(router_id=None, provider_id=provider, postgres_session=postgres_session) - provider = providers[0] - - return JSONResponse(status_code=200, content=provider.model_dump()) + provider_id: int = Path(description="The ID of the provider to get."), + get_one_provider_use_case: GetOneProviderUseCase = Depends(get_one_provider_use_case_factory), + request_context: ContextVar[RequestContext] = Depends(get_request_context), +) -> Provider: + command = GetOneProviderCommand(user_id=request_context.get().user_id, provider_id=provider_id) + try: + result = await get_one_provider_use_case.execute(command) + except Exception as e: + logger.exception( + "Unexpected error while executing get_one_provider use case", + extra={ + "user_id": command.user_id, + "provider_id": command.provider_id, + "error_type": type(e).__name__, + }, + ) + raise InternalServerHTTPException() + match result: + case GetOneProviderUseCaseSuccess(provider): + return Provider.model_validate(provider, from_attributes=True) + case ProviderNotFoundError(provider_id=not_found_id): + raise ProviderNotFoundHTTPException(not_found_id) + case UserIsNotAdminError(): + raise NotAdminUserHTTPException() @router.get( @@ -200,25 +236,48 @@ async def get_provider( dependencies=[Security(dependency=get_current_key)], status_code=200, response_model=Providers, + responses=get_documentation_responses([NotAdminUserHTTPException]), ) async def get_providers( - request: Request, - router: int | None = Query(default=None, description="Filter providers by router ID."), - offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."), - limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."), - order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."), - order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."), - postgres_session: AsyncSession = Depends(get_postgres_session), - model_registry: ModelRegistry = Depends(get_model_registry), -) -> JSONResponse: - providers = await model_registry.get_providers( - router_id=router, - provider_id=None, - postgres_session=postgres_session, + router_id: int | None = Query(default=None, description="Filter providers by router ID."), + offset: int = Query(default=0, ge=0, description="Number of providers to skip."), + limit: int = Query(default=10, ge=1, le=100, description="Maximum number of providers to return."), + sort_by: ProviderSortField = Query(default=ProviderSortField.ID, description="Field to sort by."), + sort_order: SortOrder = Query(default=SortOrder.ASC, description="Sort order."), + request_context: ContextVar[RequestContext] = Depends(get_request_context), + get_providers_use_case: GetProvidersUseCase = Depends(get_providers_use_case_factory), +) -> Providers: + command = GetProvidersCommand( + router_id=router_id, + user_id=request_context.get().user_id, offset=offset, limit=limit, - order_by=order_by, - order_direction=order_direction, + sort_by=sort_by, + sort_order=sort_order, ) - - return JSONResponse(status_code=200, content=Providers(data=providers).model_dump()) + try: + result = await get_providers_use_case.execute(command) + except Exception as e: + logger.exception( + "Unexpected error while executing get_providers use case", + extra={ + "user_id": command.user_id, + "router_id": router_id, + "offset": command.offset, + "limit": command.limit, + "sort_by": command.sort_by, + "sort_order": command.sort_order, + "error_type": type(e).__name__, + }, + ) + raise InternalServerHTTPException() + match result: + case GetProvidersUseCaseSuccess(page=providers_page): + return Providers( + total=providers_page.total, + offset=offset, + limit=limit, + data=[Provider.model_validate(provider, from_attributes=True) for provider in providers_page.data], + ) + case UserIsNotAdminError(): + raise NotAdminUserHTTPException() diff --git a/api/infrastructure/fastapi/endpoints/admin/routers.py b/api/infrastructure/fastapi/endpoints/admin/routers.py index 8e47a5b4b..3f3fc17fd 100644 --- a/api/infrastructure/fastapi/endpoints/admin/routers.py +++ b/api/infrastructure/fastapi/endpoints/admin/routers.py @@ -11,7 +11,7 @@ get_routers_use_case_factory, update_router_use_case_factory, ) -from api.domain.router.entities import RouterSortField, SortOrder +from api.domain import SortField, SortOrder from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError, RouterNotFoundError from api.domain.userinfo.errors import UserIsNotAdminError from api.infrastructure.fastapi.access import get_current_key @@ -52,11 +52,13 @@ path=EndpointRoute.ADMIN_ROUTERS, dependencies=[Security(dependency=get_current_key)], status_code=201, - responses=get_documentation_responses([ - RouterAliasAlreadyExistsHTTPException, - RouterAlreadyExistsHTTPException, - NotAdminUserHTTPException, - ]), + responses=get_documentation_responses( + [ + RouterAliasAlreadyExistsHTTPException, + RouterAlreadyExistsHTTPException, + NotAdminUserHTTPException, + ] + ), ) async def create_router( body: CreateRouterBody = Body(description="The router creation request."), @@ -141,7 +143,7 @@ async def get_router( async def get_routers( offset: int = Query(default=0, ge=0, description="Number of routers to skip."), limit: int = Query(default=10, ge=1, le=100, description="Maximum number of routers to return."), - sort_by: RouterSortField = Query(default=RouterSortField.ID, description="Field to sort by."), + sort_by: SortField = Query(default=SortField.ID, description="Field to sort by."), sort_order: SortOrder = Query(default=SortOrder.ASC, description="Sort order."), get_routers_use_case: GetRoutersUseCase = Depends(get_routers_use_case_factory), request_context: ContextVar[RequestContext] = Depends(get_request_context), @@ -165,12 +167,12 @@ async def get_routers( ) raise InternalServerHTTPException() match result: - case GetRoutersUseCaseSuccess(routers, total): + case GetRoutersUseCaseSuccess(router_page=router_page): return Routers( - total=total, + total=router_page.total, offset=offset, limit=limit, - data=[RouterResponse.model_validate(r, from_attributes=True) for r in routers], + data=[RouterResponse.model_validate(r, from_attributes=True) for r in router_page.data], ) case UserIsNotAdminError(): raise NotAdminUserHTTPException() @@ -216,12 +218,14 @@ async def delete_router( @router.patch( path=EndpointRoute.ADMIN_ROUTERS + "/{router_id}", dependencies=[Security(dependency=get_current_key)], - responses=get_documentation_responses([ - RouterNotFoundHTTPException, - NotAdminUserHTTPException, - RouterAliasAlreadyExistsHTTPException, - RouterAlreadyExistsHTTPException, - ]), + responses=get_documentation_responses( + [ + RouterNotFoundHTTPException, + NotAdminUserHTTPException, + RouterAliasAlreadyExistsHTTPException, + RouterAlreadyExistsHTTPException, + ] + ), status_code=200, ) async def update_router( diff --git a/api/infrastructure/fastapi/schemas/providers.py b/api/infrastructure/fastapi/schemas/providers.py index e92d22405..59532d6f5 100644 --- a/api/infrastructure/fastapi/schemas/providers.py +++ b/api/infrastructure/fastapi/schemas/providers.py @@ -112,4 +112,7 @@ class Provider(BaseModel): class Providers(BaseModel): object: Literal["list"] = "list" + total: int + offset: int + limit: int data: list[Provider] diff --git a/api/infrastructure/postgres/_postgresproviderrepository.py b/api/infrastructure/postgres/_postgresproviderrepository.py index d9f25a9e1..872d712d9 100644 --- a/api/infrastructure/postgres/_postgresproviderrepository.py +++ b/api/infrastructure/postgres/_postgresproviderrepository.py @@ -1,10 +1,12 @@ -from sqlalchemy import delete, insert, select +from sqlalchemy import asc, delete, desc, func, insert, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from api.domain import SortOrder +from api.domain.key.entities import MASTER_USER_ID from api.domain.model.entities import Metric from api.domain.provider import ProviderRepository -from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderPage, ProviderSortField, ProviderType from api.domain.provider.errors import ProviderAlreadyExistsError from api.sql.models import Provider as ProviderTable @@ -13,11 +15,40 @@ class PostgresProviderRepository(ProviderRepository): def __init__(self, postgres_session: AsyncSession): self.postgres_session = postgres_session + async def get_providers_page( + self, router_id: int | None, limit: int, offset: int, sort_by: ProviderSortField = ProviderSortField.ID, sort_order: SortOrder = SortOrder.ASC + ) -> ProviderPage: + select_query = select(ProviderTable) + count_query = select(func.count()).select_from(ProviderTable) + + if router_id is not None: + select_query = select_query.where(ProviderTable.router_id == router_id) + count_query = count_query.where(ProviderTable.router_id == router_id) + + total = (await self.postgres_session.execute(count_query)).scalar_one() + sort_column = getattr(ProviderTable, sort_by.value) + sort_order_clause = asc(sort_column) if sort_order == SortOrder.ASC else desc(sort_column) + + providers_query = select_query.order_by(sort_order_clause).limit(limit).offset(offset) + + rows = (await self.postgres_session.execute(providers_query)).scalars().all() + + return ProviderPage(total=total, data=[self._row_to_provider(row) for row in rows]) + + async def get_one_provider(self, provider_id: int) -> Provider | None: + select_query = select(ProviderTable).where(ProviderTable.id == provider_id) + + result = await self.postgres_session.execute(select_query) + row = result.scalar_one_or_none() + if row is None: + return None + return self._row_to_provider(row) + @staticmethod def _row_to_provider(row) -> Provider: return Provider( router_id=row.router_id, - user_id=row.user_id, + user_id=MASTER_USER_ID if row.user_id is None else row.user_id, type=row.type, url=row.url, key=row.key, @@ -31,6 +62,8 @@ def _row_to_provider(row) -> Provider: max_context_length=row.max_context_length, vector_size=row.vector_size, id=row.id, + created=int(row.created.timestamp()), + updated=int(row.updated.timestamp()), ) async def create_provider( diff --git a/api/infrastructure/postgres/_postgresrouterrepository.py b/api/infrastructure/postgres/_postgresrouterrepository.py index 8dfeadd1d..6945a191b 100644 --- a/api/infrastructure/postgres/_postgresrouterrepository.py +++ b/api/infrastructure/postgres/_postgresrouterrepository.py @@ -2,10 +2,11 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from api.domain import SortField, SortOrder from api.domain.key.entities import MASTER_USER_ID from api.domain.model import ModelType as RouterType from api.domain.router import RouterRepository -from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage, RouterSortField, SortOrder +from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.sql.models import Organization as OrganizationTable from api.sql.models import Provider as ProviderTable @@ -90,11 +91,14 @@ async def get_routers_page( self, limit: int, offset: int, - sort_by: RouterSortField = RouterSortField.ID, + sort_by: SortField = SortField.ID, sort_order: SortOrder = SortOrder.ASC, ) -> RouterPage: distinct_routers = (self._select_routers_with_provider_stats().distinct(RouterTable.id).order_by(RouterTable.id, ProviderTable.id)).subquery() + count_query = select(func.count()).select_from(distinct_routers) + total = (await self.postgres_session.execute(count_query)).scalar_one() + sort_column = distinct_routers.c[sort_by.value] sort_order_clause = asc(sort_column) if sort_order == SortOrder.ASC else desc(sort_column) @@ -102,7 +106,6 @@ async def get_routers_page( result = await self.postgres_session.execute(routers_query) rows = result.all() - total = rows[0].total if rows else 0 router_ids = [row.id for row in rows] aliases_by_router = await self.get_aliases_grouped_by_router(router_ids) routers = [self._row_to_router_with_aliases(row, aliases_by_router.get(row.id, [])) for row in rows] @@ -133,35 +136,34 @@ async def create_router( aliases = aliases or [] try: - async with self.postgres_session.begin_nested(): - insert_router_query = ( - insert(RouterTable) - .values( - user_id=db_user_id, - name=name, - type=router_type.value, - load_balancing_strategy=load_balancing_strategy.value, - cost_prompt_tokens=cost_prompt_tokens, - cost_completion_tokens=cost_completion_tokens, - ) - .returning( - RouterTable.id, - RouterTable.name, - RouterTable.user_id, - RouterTable.type, - RouterTable.load_balancing_strategy, - RouterTable.cost_prompt_tokens, - RouterTable.cost_completion_tokens, - cast(func.extract("epoch", RouterTable.created), Integer).label("created"), - cast(func.extract("epoch", RouterTable.updated), Integer).label("updated"), - ) + insert_router_query = ( + insert(RouterTable) + .values( + user_id=db_user_id, + name=name, + type=router_type.value, + load_balancing_strategy=load_balancing_strategy.value, + cost_prompt_tokens=cost_prompt_tokens, + cost_completion_tokens=cost_completion_tokens, + ) + .returning( + RouterTable.id, + RouterTable.name, + RouterTable.user_id, + RouterTable.type, + RouterTable.load_balancing_strategy, + RouterTable.cost_prompt_tokens, + RouterTable.cost_completion_tokens, + cast(func.extract("epoch", RouterTable.created), Integer).label("created"), + cast(func.extract("epoch", RouterTable.updated), Integer).label("updated"), ) - result = await self.postgres_session.execute(insert_router_query) - row = result.one() + ) + result = await self.postgres_session.execute(insert_router_query) + row = result.one() - if aliases: - aliases_to_insert = [{"value": alias, "router_id": row.id} for alias in aliases] - await self.postgres_session.execute(insert(RouterAliasTable), aliases_to_insert) + if aliases: + aliases_to_insert = [{"value": alias, "router_id": row.id} for alias in aliases] + await self.postgres_session.execute(insert(RouterAliasTable), aliases_to_insert) except IntegrityError as e: if "router_name_key" in str(e.orig): diff --git a/api/schemas/admin/providers.py b/api/schemas/admin/providers.py index 6756d5f8c..069e8bd35 100644 --- a/api/schemas/admin/providers.py +++ b/api/schemas/admin/providers.py @@ -77,4 +77,25 @@ def validate_model(self): return self -from api.infrastructure.fastapi.schemas.providers import Provider, Providers +class Provider(BaseModel): + object: Literal["provider"] = "provider" + id: int = Field(..., description="Provider ID.") # fmt: off + router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off + user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off + type: ProviderType = Field(..., description="Provider type.") # fmt: off + url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off + key: str | None = Field(description="Provider API key.") # fmt: off + timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off + model_name: str = Field(..., description="Model name from the model provider.") # fmt: off + model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off + model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off + qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off + qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off + created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off + updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off + + +class Providers(BaseModel): + object: Literal["list"] = "list" + data: list[Provider] diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 41a77f56f..8defd61d3 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -138,7 +138,7 @@ def restart_savepoint(sess, trans): await transaction.rollback() -@pytest_asyncio.fixture(scope="session") +@pytest.fixture(scope="session") def model_registry(): return ModelRegistry( app_title="test", diff --git a/api/tests/integration/endpoints/admin/provider/test_create_provider.py b/api/tests/integration/endpoints/admin/provider/test_create_provider.py index 4afd5e333..7e00ba94d 100644 --- a/api/tests/integration/endpoints/admin/provider/test_create_provider.py +++ b/api/tests/integration/endpoints/admin/provider/test_create_provider.py @@ -62,11 +62,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(create_provider_use_case_factory, None) - @respx.mock async def test_happy_path(self, client: AsyncClient, db_session): router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) diff --git a/api/tests/integration/endpoints/admin/provider/test_delete_provider.py b/api/tests/integration/endpoints/admin/provider/test_delete_provider.py index 5cfa2901f..1e8258632 100644 --- a/api/tests/integration/endpoints/admin/provider/test_delete_provider.py +++ b/api/tests/integration/endpoints/admin/provider/test_delete_provider.py @@ -21,11 +21,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(delete_provider_use_case_factory, None) - async def test_happy_path(self, client: AsyncClient, db_session): router = RouterSQLFactory(user=self.admin_user) provider = ProviderSQLFactory(router=router, user=self.admin_user) diff --git a/api/tests/integration/endpoints/admin/provider/test_get_one_provider.py b/api/tests/integration/endpoints/admin/provider/test_get_one_provider.py new file mode 100644 index 000000000..a5bba5d08 --- /dev/null +++ b/api/tests/integration/endpoints/admin/provider/test_get_one_provider.py @@ -0,0 +1,78 @@ +from unittest.mock import AsyncMock + +from httpx import AsyncClient +import pytest +import pytest_asyncio + +from api.dependencies import get_one_provider_use_case_factory +from api.domain.provider.errors import ProviderNotFoundError +from api.domain.userinfo.errors import UserIsNotAdminError +from api.tests.helpers import create_token +from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory +from api.utils.variables import EndpointRoute + +URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" + + +@pytest.mark.asyncio(loop_scope="session") +class TestGetProvider: + @pytest_asyncio.fixture(autouse=True) + async def setup(self, db_session): + self.admin_user = UserSQLFactory(admin_user=True) + self.token = await create_token(db_session, name="admin_token", user=self.admin_user) + + async def test_happy_path(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user) + provider = ProviderSQLFactory(router=router, user=self.admin_user) + await db_session.flush() + + response = await client.get( + url=f"{URL}/{provider.id}", + headers={"Authorization": f"Bearer {self.token.token}"}, + ) + + assert response.status_code == 200, response.text + data = response.json() + assert data["id"] == provider.id + assert data["object"] == "provider" + + @pytest.mark.parametrize( + "use_case_result,expected_status,expected_detail", + [ + ( + ProviderNotFoundError(provider_id=1), + 404, + "Model provider 1 not found.", + ), + ( + UserIsNotAdminError(), + 403, + "User has no admin rights.", + ), + ], + ) + async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail): + mock_use_case = AsyncMock() + mock_use_case.execute.return_value = use_case_result + app.dependency_overrides[get_one_provider_use_case_factory] = lambda: mock_use_case + + response = await client.get( + url=f"{URL}/1", + headers={"Authorization": f"Bearer {self.token.token}"}, + ) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail + + @pytest.mark.parametrize( + "headers,expected_status,expected_detail", + [ + ({}, 401, "Not authenticated"), + ({"Authorization": "Bearer invalid-token"}, 403, "Invalid API key."), + ], + ) + async def test_auth(self, client: AsyncClient, headers, expected_status, expected_detail): + response = await client.get(url=f"{URL}/1", headers=headers) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail diff --git a/api/tests/integration/endpoints/admin/provider/test_get_providers.py b/api/tests/integration/endpoints/admin/provider/test_get_providers.py new file mode 100644 index 000000000..6f26a38dd --- /dev/null +++ b/api/tests/integration/endpoints/admin/provider/test_get_providers.py @@ -0,0 +1,106 @@ +from unittest.mock import AsyncMock + +from httpx import AsyncClient +import pytest +import pytest_asyncio + +from api.dependencies import get_providers_use_case_factory +from api.domain.userinfo.errors import UserIsNotAdminError +from api.tests.helpers import create_token +from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory +from api.utils.variables import EndpointRoute + +URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}" + + +@pytest.mark.asyncio(loop_scope="session") +class TestGetProviders: + @pytest_asyncio.fixture(autouse=True) + async def setup(self, db_session): + self.admin_user = UserSQLFactory(admin_user=True) + self.token = await create_token(db_session, name="admin_token", user=self.admin_user) + + async def test_happy_path_without_params(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user) + for i in range(1, 8): + ProviderSQLFactory(router=router, model_name=f"model_{i}") + await db_session.flush() + + response = await client.get( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + ) + + assert response.status_code == 200, response.text + data = response.json() + assert data["object"] == "list" + assert data["total"] == 7 + assert data["offset"] == 0 + assert data["limit"] == 10 + assert len(data["data"]) == 7 + + async def test_happy_path_with_params(self, client: AsyncClient, db_session): + router = RouterSQLFactory(user=self.admin_user) + for i in range(1, 8): + ProviderSQLFactory(router=router, model_name=f"model_{i}") + await db_session.flush() + + response = await client.get( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + params={"offset": 3, "limit": 3}, + ) + + assert response.status_code == 200, response.text + data = response.json() + assert data["object"] == "list" + assert data["total"] == 7 + assert data["offset"] == 3 + assert data["limit"] == 3 + assert len(data["data"]) == 3 + + async def test_pagination_limit_should_be_less_than_100(self, client: AsyncClient): + response = await client.get( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + params={"offset": 0, "limit": 101}, + ) + + assert response.status_code == 422, response.text + assert response.json().get("detail")[0]["msg"] == "Input should be less than or equal to 100" + + @pytest.mark.parametrize( + "use_case_result,expected_status,expected_detail", + [ + ( + UserIsNotAdminError(), + 403, + "User has no admin rights.", + ), + ], + ) + async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail): + mock_use_case = AsyncMock() + mock_use_case.execute.return_value = use_case_result + app.dependency_overrides[get_providers_use_case_factory] = lambda: mock_use_case + + response = await client.get( + url=URL, + headers={"Authorization": f"Bearer {self.token.token}"}, + ) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail + + @pytest.mark.parametrize( + "headers,expected_status,expected_detail", + [ + ({}, 401, "Not authenticated"), + ({"Authorization": "Bearer invalid-token"}, 403, "Invalid API key."), + ], + ) + async def test_auth(self, client: AsyncClient, headers, expected_status, expected_detail): + response = await client.get(url=URL, headers=headers) + + assert response.status_code == expected_status + assert response.json().get("detail") == expected_detail diff --git a/api/tests/integration/endpoints/admin/router/test_create_router.py b/api/tests/integration/endpoints/admin/router/test_create_router.py index 7fd4c0685..27adc7699 100644 --- a/api/tests/integration/endpoints/admin/router/test_create_router.py +++ b/api/tests/integration/endpoints/admin/router/test_create_router.py @@ -34,11 +34,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(create_router_use_case_factory, None) - async def test_happy_path(self, client: AsyncClient, db_session): await db_session.flush() diff --git a/api/tests/integration/endpoints/admin/router/test_delete_router.py b/api/tests/integration/endpoints/admin/router/test_delete_router.py index 16e8ea0b6..8dc750509 100644 --- a/api/tests/integration/endpoints/admin/router/test_delete_router.py +++ b/api/tests/integration/endpoints/admin/router/test_delete_router.py @@ -21,11 +21,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(delete_router_use_case_factory, None) - async def test_happy_path(self, client: AsyncClient, db_session): router = RouterSQLFactory(user=self.admin_user) await db_session.flush() diff --git a/api/tests/integration/endpoints/admin/router/test_get_router.py b/api/tests/integration/endpoints/admin/router/test_get_router.py index d2f62ff5d..e2e850f27 100644 --- a/api/tests/integration/endpoints/admin/router/test_get_router.py +++ b/api/tests/integration/endpoints/admin/router/test_get_router.py @@ -22,11 +22,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(get_one_router_use_case_factory, None) - async def test_happy_path(self, client: AsyncClient, db_session): router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION) await db_session.flush() diff --git a/api/tests/integration/endpoints/admin/router/test_get_routers.py b/api/tests/integration/endpoints/admin/router/test_get_routers.py index 89e3edd01..1d3b7f6fc 100644 --- a/api/tests/integration/endpoints/admin/router/test_get_routers.py +++ b/api/tests/integration/endpoints/admin/router/test_get_routers.py @@ -20,11 +20,6 @@ async def setup(self, db_session): self.admin_user = UserSQLFactory(admin_user=True) self.token = await create_token(db_session, name="admin_token", user=self.admin_user) - @pytest_asyncio.fixture(autouse=True) - async def cleanup_overrides(self, app): - yield - app.dependency_overrides.pop(get_routers_use_case_factory, None) - async def test_happy_path_without_params(self, client: AsyncClient, db_session): RouterSQLFactory(user=self.admin_user, name="router_1") RouterSQLFactory(user=self.admin_user, name="router_2") @@ -56,6 +51,7 @@ async def test_happy_path_with_params(self, client: AsyncClient, db_session): RouterSQLFactory(user=self.admin_user, name="router_5") RouterSQLFactory(user=self.admin_user, name="router_6") RouterSQLFactory(user=self.admin_user, name="router_7") + expected_routers_ordered_by_name = ["router_4", "router_5", "router_6"] await db_session.flush() response = await client.get( @@ -71,7 +67,7 @@ async def test_happy_path_with_params(self, client: AsyncClient, db_session): assert data["offset"] == 3 assert data["limit"] == 3 assert len(data["data"]) == 3 - assert returned_names == ["router_4", "router_5", "router_6"] + assert returned_names == expected_routers_ordered_by_name async def test_pagination_limit_should_be_less_than_100(self, client: AsyncClient, db_session): response = await client.get( diff --git a/api/tests/integration/endpoints/test_models.py b/api/tests/integration/endpoints/test_models.py index ec674a485..c086d4411 100644 --- a/api/tests/integration/endpoints/test_models.py +++ b/api/tests/integration/endpoints/test_models.py @@ -51,9 +51,9 @@ async def test_get_models_happy_path(self, client: AsyncClient, db_session): LimitSQLFactory(role=user_with_routers.role, router=router_1) LimitSQLFactory(role=user_with_routers.role, router=router_2) + await db_session.flush() user_1_token = await create_token(db_session, name="my_token", user=user_with_routers) response = await client.get(url=f"/v1{EndpointRoute.MODELS}", headers={"Authorization": f"Bearer {user_1_token.token}"}) - await db_session.flush() assert response.status_code == 200, f"error: retrieve models ({response.status_code})" models = Models(data=[Model(**model) for model in response.json()["data"]]) assert isinstance(models, Models) diff --git a/api/tests/integration/factories.py b/api/tests/integration/factories.py index 5bc3948fb..692752841 100644 --- a/api/tests/integration/factories.py +++ b/api/tests/integration/factories.py @@ -100,7 +100,7 @@ def email(self): fake = Faker("fr_FR") domain = fake.free_email_domain() clean_name = name_ascii.lower().replace(" ", ".") - return f"{clean_name}@e{domain}" + return f"{clean_name}@{domain}" class Params: admin_user = factory.Trait(role=factory.SubFactory(RoleSQLFactory, admin=True), priority=10) @@ -208,8 +208,6 @@ class Meta: sqlalchemy_session_persistence = "flush" id = None - role_id = factory.SelfAttribute("role.id") - router_id = factory.SelfAttribute("router.id") type = fuzzy.FuzzyChoice([LimitType.TPM, LimitType.TPD, LimitType.RPM, LimitType.RPD]) value = fuzzy.FuzzyInteger(100, 10000) created = factory.LazyFunction(datetime.now) diff --git a/api/tests/integration/postgres/test_postgresproviderrepository.py b/api/tests/integration/postgres/test_postgresproviderrepository.py index 563735d71..666495eb0 100644 --- a/api/tests/integration/postgres/test_postgresproviderrepository.py +++ b/api/tests/integration/postgres/test_postgresproviderrepository.py @@ -1,8 +1,13 @@ +from datetime import datetime, timedelta + import pytest from sqlalchemy import select +from api.domain import SortOrder +from api.domain.key.entities import MASTER_USER_ID from api.domain.model.entities import Metric, ModelType from api.domain.provider import Provider, ProviderAlreadyExistsError, ProviderCarbonFootprintZone, ProviderType +from api.domain.provider.entities import ProviderSortField from api.infrastructure.postgres import PostgresProviderRepository from api.sql.models import Provider as ProviderTable from api.tests.integration.factories import ( @@ -87,6 +92,159 @@ async def test_create_provider_should_return_provider_already_exists_when_same_u assert result.model_name == "duplicate-provider" +@pytest.mark.asyncio(loop_scope="session") +class TestGetOneProvider: + async def test_get_one_provider_should_return_provider_when_it_exists(self, repository, db_session): + # Arrange + provider = ProviderSQLFactory(type=ProviderType.ALBERT, url="http://test.com/", model_name="target-provider", qos_metric=None) + await db_session.flush() + + # Act + result = await repository.get_one_provider(provider.id) + + # Assert + assert isinstance(result, Provider) + assert result.id == provider.id + assert result.router_id == provider.router_id + assert result.user_id == provider.user_id + assert result.type == ProviderType.ALBERT + assert result.url == "http://test.com/" + assert result.model_name == "target-provider" + assert result.max_context_length == provider.max_context_length + assert result.vector_size == provider.vector_size + + async def test_get_one_provider_should_return_none_when_provider_does_not_exist(self, repository, db_session): + # Act + result = await repository.get_one_provider(provider_id=999999) + + # Assert + assert result is None + + async def test_get_one_provider_should_map_null_user_id_to_master_user_id(self, repository, db_session): + # Arrange + router = RouterSQLFactory(user=None) + provider = ProviderSQLFactory(router=router, user=None) + await db_session.flush() + + # Act + result = await repository.get_one_provider(provider.id) + + # Assert + assert isinstance(result, Provider) + assert result.user_id == MASTER_USER_ID + + +@pytest.mark.asyncio(loop_scope="session") +class TestGetProvidersPage: + async def test_returns_correct_page_with_limit_and_offset(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + ProviderSQLFactory(router=router, model_name="provider_a") + ProviderSQLFactory(router=router, model_name="provider_b") + ProviderSQLFactory(router=router, model_name="provider_c") + await db_session.flush() + + # Act + result = await repository.get_providers_page(router_id=None, limit=2, offset=0, sort_by=ProviderSortField.ID, sort_order=SortOrder.ASC) + + # Assert + assert result.total == 3 + assert len(result.data) == 2 + + async def test_total_is_consistent_across_pages(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + for i in range(6): + ProviderSQLFactory(router=router, model_name=f"provider_{i}") + await db_session.flush() + + # Act + first_page = await repository.get_providers_page(router_id=None, limit=4, offset=0) + second_page = await repository.get_providers_page(router_id=None, limit=4, offset=4) + + # Assert + assert first_page.total == second_page.total + assert first_page.total == 6 + assert len(second_page.data) == 2 + + async def test_sort_by_id_asc(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + ProviderSQLFactory(id=4003, router=router, model_name="provider_c") + ProviderSQLFactory(id=4001, router=router, model_name="provider_a") + ProviderSQLFactory(id=4002, router=router, model_name="provider_b") + await db_session.flush() + + # Act + result = await repository.get_providers_page(router_id=None, limit=10, offset=0, sort_by=ProviderSortField.ID, sort_order=SortOrder.ASC) + + # Assert + returned_ids = [p.id for p in result.data] + assert returned_ids == [4001, 4002, 4003] + + async def test_sort_by_model_name_asc(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + ProviderSQLFactory(router=router, model_name="provider_c") + ProviderSQLFactory(router=router, model_name="provider_a") + ProviderSQLFactory(router=router, model_name="provider_b") + await db_session.flush() + + # Act + result = await repository.get_providers_page( + router_id=None, limit=10, offset=0, sort_by=ProviderSortField.MODEL_NAME, sort_order=SortOrder.ASC + ) + + # Assert + returned_names = [p.model_name for p in result.data] + assert returned_names == ["provider_a", "provider_b", "provider_c"] + + async def test_sort_by_created_date_desc(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + ProviderSQLFactory(router=router, model_name="oldest", created=datetime.now() - timedelta(days=10)) + ProviderSQLFactory(router=router, model_name="newest", created=datetime.now()) + ProviderSQLFactory(router=router, model_name="middle", created=datetime.now() - timedelta(hours=1)) + await db_session.flush() + + # Act + result = await repository.get_providers_page(router_id=None, limit=10, offset=0, sort_by=ProviderSortField.CREATED, sort_order=SortOrder.DESC) + + # Assert + returned_names = [p.model_name for p in result.data] + assert returned_names == ["newest", "middle", "oldest"] + + async def test_returns_empty_page_when_offset_exceeds_total(self, repository, db_session): + # Arrange + router = RouterSQLFactory() + ProviderSQLFactory(router=router) + await db_session.flush() + + # Act + result = await repository.get_providers_page(router_id=None, limit=10, offset=100) + + # Assert + assert result.data == [] + assert result.total == 1 + + async def test_filter_by_router_id_returns_only_providers_for_that_router(self, repository, db_session): + # Arrange + router_1 = RouterSQLFactory() + router_2 = RouterSQLFactory() + ProviderSQLFactory(router=router_1, model_name="provider_r1_1") + ProviderSQLFactory(router=router_1, model_name="provider_r1_2") + ProviderSQLFactory(router=router_2, model_name="provider_r2_1") + await db_session.flush() + + # Act + result = await repository.get_providers_page(router_id=router_1.id, limit=10, offset=0) + + # Assert + assert result.total == 2 + assert len(result.data) == 2 + assert all(provider.router_id == router_1.id for provider in result.data) + + @pytest.mark.asyncio(loop_scope="session") class TestDeleteProvider: async def test_delete_provider_should_return_the_deleted_provider_when_provider_exists(self, repository, db_session): diff --git a/api/tests/integration/postgres/test_postgresrouterrepository.py b/api/tests/integration/postgres/test_postgresrouterrepository.py index eb97abdeb..703c4529f 100644 --- a/api/tests/integration/postgres/test_postgresrouterrepository.py +++ b/api/tests/integration/postgres/test_postgresrouterrepository.py @@ -3,9 +3,10 @@ import pytest from sqlalchemy import select +from api.domain import SortField, SortOrder from api.domain.key.entities import MASTER_USER_ID from api.domain.model import ModelType as RouterType -from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterSortField, SortOrder +from api.domain.router.entities import Router, RouterLoadBalancingStrategy from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError from api.infrastructure.postgres import PostgresRouterRepository from api.sql.models import Provider as ProviderTable @@ -73,8 +74,8 @@ async def test_get_all_routers_should_return_all_routers(self, repository, db_se assert len(result_routers) == 3 router_names = {r.name for r in result_routers} assert router_names == {router_1.name, router_2.name, router_3.name} + result_router_1 = next(r for r in result_routers if r.name == router_1.name) - result_router_1 = result_routers[0] first_provider_router_1 = router_1.provider[0] assert result_router_1.type == RouterType.TEXT_GENERATION assert result_router_1.providers == 2 @@ -307,7 +308,7 @@ async def test_returns_correct_page_with_limit_and_offset(self, repository, db_s await db_session.flush() # Act - result = await repository.get_routers_page(limit=2, offset=0, sort_by=RouterSortField.NAME, sort_order=SortOrder.ASC) + result = await repository.get_routers_page(limit=2, offset=0, sort_by=SortField.NAME, sort_order=SortOrder.ASC) # Assert assert result.total == 3 @@ -332,6 +333,7 @@ async def test_total_is_consistent_across_pages(self, repository, db_session): # Assert assert first_page.total == second_page.total + assert first_page.total == 6 assert len(second_page.data) == 2 async def test_sort_by_name_asc(self, repository, db_session): @@ -343,7 +345,7 @@ async def test_sort_by_name_asc(self, repository, db_session): await db_session.flush() # Act - result = await repository.get_routers_page(limit=10, offset=0, sort_by=RouterSortField.NAME, sort_order=SortOrder.ASC) + result = await repository.get_routers_page(limit=10, offset=0, sort_by=SortField.NAME, sort_order=SortOrder.ASC) # Assert returned_names = [r.name for r in result.data] @@ -358,7 +360,7 @@ async def test_sort_by_name_desc(self, repository, db_session): await db_session.flush() # Act - result = await repository.get_routers_page(limit=10, offset=0, sort_by=RouterSortField.NAME, sort_order=SortOrder.DESC) + result = await repository.get_routers_page(limit=10, offset=0, sort_by=SortField.NAME, sort_order=SortOrder.DESC) # Assert returned_names = [r.name for r in result.data] @@ -376,6 +378,7 @@ async def test_returns_empty_page_when_offset_exceeds_total(self, repository, db # Assert assert result.data == [] + assert result.total == 1 async def test_sort_by_id_asc(self, repository, db_session): # Arrange @@ -386,7 +389,7 @@ async def test_sort_by_id_asc(self, repository, db_session): await db_session.flush() # Act - result = await repository.get_routers_page(limit=10, offset=0, sort_by=RouterSortField.ID, sort_order=SortOrder.ASC) + result = await repository.get_routers_page(limit=10, offset=0, sort_by=SortField.ID, sort_order=SortOrder.ASC) # Assert returned_ids = [r.id for r in result.data] @@ -401,7 +404,7 @@ async def test_sort_by_created_date_desc(self, repository, db_session): await db_session.flush() # Act - result = await repository.get_routers_page(limit=10, offset=0, sort_by=RouterSortField.CREATED, sort_order=SortOrder.DESC) + result = await repository.get_routers_page(limit=10, offset=0, sort_by=SortField.CREATED, sort_order=SortOrder.DESC) # Assert returned_names = [r.name for r in result.data] diff --git a/api/tests/integration/test_createapp.py b/api/tests/integration/test_createapp.py index 244395b31..0669808e7 100644 --- a/api/tests/integration/test_createapp.py +++ b/api/tests/integration/test_createapp.py @@ -1,17 +1,18 @@ from collections.abc import AsyncGenerator -from types import SimpleNamespace from httpx import ASGITransport, AsyncClient import pytest +import pytest_asyncio from api.app import create_app +from api.schemas.core.configuration import Configuration, Dependencies, Settings from api.utils.variables import EndpointRoute, RouterName @pytest.fixture(scope="class") -def test_configuration(): - return SimpleNamespace( - settings=SimpleNamespace( +def test_configuration() -> Configuration: + return Configuration.model_construct( + settings=Settings.model_construct( app_title="test", swagger_summary=None, swagger_version="0.0.0", @@ -27,11 +28,11 @@ def test_configuration(): hidden_routers=[RouterName.MODELS], monitoring_prometheus_enabled=False, ), - dependencies=SimpleNamespace(sentry=None), + dependencies=Dependencies.model_construct(sentry=None), ) -@pytest.fixture(scope="class") +@pytest_asyncio.fixture(scope="class") async def client(test_configuration) -> AsyncGenerator[AsyncClient, None]: app = create_app(test_configuration, skip_lifespan=True) @@ -44,14 +45,14 @@ async def client(test_configuration) -> AsyncGenerator[AsyncClient, None]: @pytest.mark.asyncio(loop_scope="session") class TestCreateApp: - async def test_reach_swagger_with_non_default_url_configuration_is_reachable(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_reach_swagger_with_non_default_url_configuration_is_reachable(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url=test_configuration.settings.swagger_docs_url) # Assert assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}" - async def test_redoc_with_non_default_url_configuration_is_reachable(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_redoc_with_non_default_url_configuration_is_reachable(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url=test_configuration.settings.swagger_redoc_url) @@ -65,28 +66,28 @@ async def test_exposed_openapi_schema_is_reachable(self, client: AsyncClient): # Assert assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}" - async def test_enabled_router_is_reachable(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_enabled_router_is_reachable(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url=f"/v1{EndpointRoute.ME_INFO}") # Assert assert response.status_code == 401, f"Expected 401, got {response.status_code}: {response.text}" - async def test_disabled_router_is_unreachable(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_disabled_router_is_unreachable(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url=f"/v1/{test_configuration.settings.disabled_routers[0]}") # Assert assert response.status_code == 404, f"Expected 404, got {response.status_code}: {response.text}" - async def test_hidden_router_is_reachable(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_hidden_router_is_reachable(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url=f"/v1/{test_configuration.settings.hidden_routers[0]}") # Assert assert response.status_code == 401, f"Expected 401, got {response.status_code}: {response.text}" - async def test_hidden_router_is_not_in_exposed_openapi_schema(self, client: AsyncClient, test_configuration: SimpleNamespace): + async def test_hidden_router_is_not_in_exposed_openapi_schema(self, client: AsyncClient, test_configuration: Configuration): # Act response = await client.get(url="/openapi.json") diff --git a/api/tests/unit/test_utils/test_carbon.py b/api/tests/unit/test_utils/test_carbon.py index a59f535ca..37a919c34 100644 --- a/api/tests/unit/test_utils/test_carbon.py +++ b/api/tests/unit/test_utils/test_carbon.py @@ -47,10 +47,12 @@ def test_get_carbon_footprint_return_footprint(self, mocker): mocked_electricity_mix = mocker.patch("api.utils.carbon.electricity_mixes.find_electricity_mix") mocked_electricity_mix.return_value = SimpleNamespace(adpe=1, pe=2, gwp=3, wue=4) mocked_compute_llm_impacts = mocker.patch("api.utils.carbon.compute_llm_impacts") - mocked_compute_llm_impacts.return_value = dict_to_namespace({ - "energy": {"value": 1}, - "gwp": {"value": 3}, - }) + mocked_compute_llm_impacts.return_value = dict_to_namespace( + { + "energy": {"value": 1}, + "gwp": {"value": 3}, + } + ) active_params = 1 total_params = 1 model_zone = ProviderCarbonFootprintZone.WOR diff --git a/api/tests/unit/use_case/admin/providers/test_getoneproviderusecase.py b/api/tests/unit/use_case/admin/providers/test_getoneproviderusecase.py new file mode 100644 index 000000000..5730bdcf6 --- /dev/null +++ b/api/tests/unit/use_case/admin/providers/test_getoneproviderusecase.py @@ -0,0 +1,83 @@ +from unittest.mock import AsyncMock + +import pytest + +from api.domain.provider.errors import ProviderNotFoundError +from api.domain.userinfo.errors import UserIsNotAdminError +from api.tests.unit.use_case.factories import ProviderFactory, UserInfoFactory +from api.use_cases.admin.providers._getoneproviderusecase import GetOneProviderCommand, GetOneProviderUseCase, GetOneProviderUseCaseSuccess + + +@pytest.fixture +def provider_repository(): + return AsyncMock() + + +@pytest.fixture +def user_info_repository(): + return AsyncMock() + + +@pytest.fixture +def use_case(provider_repository, user_info_repository): + return GetOneProviderUseCase(provider_repository=provider_repository, user_info_repository=user_info_repository) + + +@pytest.fixture +def admin_user_info(): + return UserInfoFactory(id=1, admin=True) + + +@pytest.fixture +def unauthorized_user_info(): + return UserInfoFactory(id=3, without_permission=True, limits=[]) + + +@pytest.fixture +def sample_provider(): + return ProviderFactory(id=42, user_id=1) + + +class TestGetOneProviderUseCase: + @pytest.mark.asyncio + async def test_should_return_provider_when_user_is_admin_and_provider_exists( + self, use_case, provider_repository, user_info_repository, admin_user_info, sample_provider + ): + # Arrange + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.provider_repository.get_one_provider.return_value = sample_provider + + # Act + result = await use_case.execute(command=GetOneProviderCommand(user_id=admin_user_info.id, provider_id=42)) + + # Assert + assert isinstance(result, GetOneProviderUseCaseSuccess) + assert result.provider == sample_provider + user_info_repository.get_user_info.assert_called_once_with(user_id=admin_user_info.id) + provider_repository.get_one_provider.assert_called_once_with(42) + + @pytest.mark.asyncio + async def test_should_return_provider_not_found_error_when_provider_does_not_exist(self, use_case, provider_repository, admin_user_info): + # Arrange + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.provider_repository.get_one_provider.return_value = None + + # Act + result = await use_case.execute(command=GetOneProviderCommand(user_id=admin_user_info.id, provider_id=99)) + + # Assert + assert isinstance(result, ProviderNotFoundError) + assert result.provider_id == 99 + provider_repository.get_one_provider.assert_called_once_with(99) + + @pytest.mark.asyncio + async def test_should_return_user_is_not_admin_error_when_user_is_not_admin(self, use_case, provider_repository, unauthorized_user_info): + # Arrange + use_case.user_info_repository.get_user_info.return_value = unauthorized_user_info + + # Act + result = await use_case.execute(command=GetOneProviderCommand(user_id=unauthorized_user_info.id, provider_id=42)) + + # Assert + assert isinstance(result, UserIsNotAdminError) + provider_repository.get_one_provider.assert_not_called() diff --git a/api/tests/unit/use_case/admin/providers/test_getprovidersusecase.py b/api/tests/unit/use_case/admin/providers/test_getprovidersusecase.py new file mode 100644 index 000000000..bd22ee56e --- /dev/null +++ b/api/tests/unit/use_case/admin/providers/test_getprovidersusecase.py @@ -0,0 +1,98 @@ +from unittest.mock import AsyncMock + +import pytest + +from api.domain import SortOrder +from api.domain.provider.entities import ProviderPage, ProviderSortField +from api.domain.userinfo.errors import UserIsNotAdminError +from api.tests.unit.use_case.factories import ProviderFactory, UserInfoFactory +from api.use_cases.admin.providers import GetProvidersCommand, GetProvidersUseCase, GetProvidersUseCaseSuccess + + +@pytest.fixture +def provider_repository(): + return AsyncMock() + + +@pytest.fixture +def user_info_repository(): + return AsyncMock() + + +@pytest.fixture +def use_case(provider_repository, user_info_repository): + return GetProvidersUseCase(provider_repository=provider_repository, user_info_repository=user_info_repository) + + +@pytest.fixture +def admin_user_info(): + return UserInfoFactory(id=1, admin=True) + + +@pytest.fixture +def unauthorized_user_info(): + return UserInfoFactory(id=3, without_permission=True, limits=[]) + + +@pytest.fixture +def sample_providers(): + return [ProviderFactory(id=1, user_id=1), ProviderFactory(id=2, user_id=1)] + + +@pytest.fixture +def sample_command(): + return GetProvidersCommand(user_id=1, router_id=None, offset=0, limit=10, sort_by=ProviderSortField.ID, sort_order=SortOrder.ASC) + + +class TestGetProvidersUseCase: + @pytest.mark.asyncio + async def test_should_return_providers_when_user_is_admin( + self, use_case, provider_repository, user_info_repository, admin_user_info, sample_providers, sample_command + ): + # Arrange + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.provider_repository.get_providers_page.return_value = ProviderPage(total=2, data=sample_providers) + + # Act + result = await use_case.execute(command=sample_command) + + # Assert + assert isinstance(result, GetProvidersUseCaseSuccess) + assert result.page.data == sample_providers + assert result.page.total == 2 + user_info_repository.get_user_info.assert_called_once_with(user_id=admin_user_info.id) + + @pytest.mark.asyncio + async def test_should_return_user_is_not_admin_error_when_user_is_not_an_admin(self, use_case, provider_repository, unauthorized_user_info): + # Arrange + use_case.user_info_repository.get_user_info.return_value = unauthorized_user_info + + # Act + result = await use_case.execute( + command=GetProvidersCommand( + user_id=unauthorized_user_info.id, router_id=None, offset=0, limit=10, sort_by=ProviderSortField.ID, sort_order=SortOrder.ASC + ) + ) + + # Assert + assert isinstance(result, UserIsNotAdminError) + provider_repository.get_providers_page.assert_not_called() + + @pytest.mark.asyncio + async def test_should_forward_pagination_params_to_repository(self, use_case, provider_repository, admin_user_info, sample_providers): + # Arrange + use_case.user_info_repository.get_user_info.return_value = admin_user_info + use_case.provider_repository.get_providers_page.return_value = ProviderPage(total=2, data=sample_providers) + command = GetProvidersCommand(user_id=1, router_id=42, offset=5, limit=20, sort_by=ProviderSortField.MODEL_NAME, sort_order=SortOrder.DESC) + + # Act + await use_case.execute(command=command) + + # Assert + provider_repository.get_providers_page.assert_called_once_with( + router_id=42, + limit=20, + offset=5, + sort_by=ProviderSortField.MODEL_NAME, + sort_order=SortOrder.DESC, + ) diff --git a/api/tests/unit/use_case/admin/routers/test_getroutersusecase.py b/api/tests/unit/use_case/admin/routers/test_getroutersusecase.py index f6816cd9b..c27206742 100644 --- a/api/tests/unit/use_case/admin/routers/test_getroutersusecase.py +++ b/api/tests/unit/use_case/admin/routers/test_getroutersusecase.py @@ -2,7 +2,8 @@ import pytest -from api.domain.router.entities import RouterPage, RouterSortField, SortOrder +from api.domain import SortField, SortOrder +from api.domain.router.entities import RouterPage from api.domain.userinfo.errors import UserIsNotAdminError from api.tests.unit.use_case.factories import RouterFactory, UserInfoFactory from api.use_cases.admin.routers import GetRoutersCommand, GetRoutersUseCase, GetRoutersUseCaseSuccess @@ -40,7 +41,7 @@ def sample_routers(): @pytest.fixture def sample_command(): - return GetRoutersCommand(user_id=1, offset=0, limit=10, sort_by=RouterSortField.ID, sort_order=SortOrder.ASC) + return GetRoutersCommand(user_id=1, offset=0, limit=10, sort_by=SortField.ID, sort_order=SortOrder.ASC) class TestGetRoutersUseCase: @@ -57,8 +58,8 @@ async def test_should_return_routers_when_user_is_admin( # Assert assert isinstance(result, GetRoutersUseCaseSuccess) - assert result.routers == sample_routers - assert result.total == 2 + assert result.router_page.data == sample_routers + assert result.router_page.total == 2 user_info_repository.get_user_info.assert_called_once_with(user_id=admin_user_info.id) @pytest.mark.asyncio @@ -70,7 +71,7 @@ async def test_should_return_cannot_read_routers_error_when_user_is_not_an_admin # Act result = await use_case.execute( - command=GetRoutersCommand(user_id=unauthorized_user_info.id, offset=0, limit=10, sort_by=RouterSortField.ID, sort_order=SortOrder.ASC) + command=GetRoutersCommand(user_id=unauthorized_user_info.id, offset=0, limit=10, sort_by=SortField.ID, sort_order=SortOrder.ASC) ) # Assert @@ -82,7 +83,7 @@ async def test_should_forward_pagination_params_to_repository(self, use_case, ro # Arrange use_case.user_info_repository.get_user_info.return_value = admin_user_info use_case.router_repository.get_routers_page.return_value = RouterPage(total=2, data=sample_routers) - command = GetRoutersCommand(user_id=1, offset=5, limit=20, sort_by=RouterSortField.NAME, sort_order=SortOrder.DESC) + command = GetRoutersCommand(user_id=1, offset=5, limit=20, sort_by=SortField.NAME, sort_order=SortOrder.DESC) # Act await use_case.execute(command=command) @@ -91,6 +92,6 @@ async def test_should_forward_pagination_params_to_repository(self, use_case, ro router_repository.get_routers_page.assert_called_once_with( limit=20, offset=5, - sort_by=RouterSortField.NAME, + sort_by=SortField.NAME, sort_order=SortOrder.DESC, ) diff --git a/api/use_cases/admin/providers/__init__.py b/api/use_cases/admin/providers/__init__.py index 7a50c756f..b4938a504 100644 --- a/api/use_cases/admin/providers/__init__.py +++ b/api/use_cases/admin/providers/__init__.py @@ -1,5 +1,7 @@ from ._createproviderusecase import CreateProviderCommand, CreateProviderUseCase, CreateProviderUseCaseSuccess from ._deleteproviderusecase import DeleteProviderCommand, DeleteProviderUseCase, DeleteProviderUseCaseSuccess +from ._getoneproviderusecase import GetOneProviderCommand, GetOneProviderUseCase, GetOneProviderUseCaseSuccess +from ._getprovidersusecase import GetProvidersCommand, GetProvidersUseCase, GetProvidersUseCaseSuccess __all__ = [ "CreateProviderCommand", @@ -8,4 +10,10 @@ "DeleteProviderCommand", "DeleteProviderUseCase", "DeleteProviderUseCaseSuccess", + "GetOneProviderCommand", + "GetOneProviderUseCase", + "GetOneProviderUseCaseSuccess", + "GetProvidersCommand", + "GetProvidersUseCase", + "GetProvidersUseCaseSuccess", ] diff --git a/api/use_cases/admin/providers/_getoneproviderusecase.py b/api/use_cases/admin/providers/_getoneproviderusecase.py new file mode 100644 index 000000000..1566a835f --- /dev/null +++ b/api/use_cases/admin/providers/_getoneproviderusecase.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass + +from api.domain.provider import Provider, ProviderRepository +from api.domain.provider.errors import ProviderNotFoundError +from api.domain.userinfo import UserInfoRepository +from api.domain.userinfo.errors import UserIsNotAdminError + + +@dataclass +class GetOneProviderCommand: + user_id: int + provider_id: int + + +@dataclass +class GetOneProviderUseCaseSuccess: + provider: Provider + + +type GetOneProviderUseCaseResult = GetOneProviderUseCaseSuccess | ProviderNotFoundError | UserIsNotAdminError + + +class GetOneProviderUseCase: + def __init__(self, provider_repository: ProviderRepository, user_info_repository: UserInfoRepository): + self.provider_repository = provider_repository + self.user_info_repository = user_info_repository + + async def execute( + self, + command: GetOneProviderCommand, + ) -> GetOneProviderUseCaseResult: + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) + + if not user_info.is_admin: + return UserIsNotAdminError() + + provider = await self.provider_repository.get_one_provider(command.provider_id) + + if not provider: + return ProviderNotFoundError(command.provider_id) + return GetOneProviderUseCaseSuccess(provider=provider) diff --git a/api/use_cases/admin/providers/_getprovidersusecase.py b/api/use_cases/admin/providers/_getprovidersusecase.py new file mode 100644 index 000000000..8c77e9f39 --- /dev/null +++ b/api/use_cases/admin/providers/_getprovidersusecase.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +from api.domain import SortOrder +from api.domain.provider import ProviderRepository +from api.domain.provider.entities import ProviderPage, ProviderSortField +from api.domain.userinfo import UserInfoRepository +from api.domain.userinfo.errors import UserIsNotAdminError + + +@dataclass +class GetProvidersCommand: + user_id: int + router_id: int | None + offset: int + limit: int + sort_by: ProviderSortField + sort_order: SortOrder + + +@dataclass +class GetProvidersUseCaseSuccess: + page: ProviderPage + + +type GetProvidersUseCaseResult = GetProvidersUseCaseSuccess | UserIsNotAdminError + + +class GetProvidersUseCase: + def __init__(self, provider_repository: ProviderRepository, user_info_repository: UserInfoRepository): + self.provider_repository = provider_repository + self.user_info_repository = user_info_repository + + async def execute( + self, + command: GetProvidersCommand, + ) -> GetProvidersUseCaseResult: + user_info = await self.user_info_repository.get_user_info(user_id=command.user_id) + + if not user_info.is_admin: + return UserIsNotAdminError() + + providers_page = await self.provider_repository.get_providers_page( + router_id=command.router_id, limit=command.limit, offset=command.offset, sort_by=command.sort_by, sort_order=command.sort_order + ) + + return GetProvidersUseCaseSuccess(page=providers_page) diff --git a/api/use_cases/admin/routers/_getroutersusecase.py b/api/use_cases/admin/routers/_getroutersusecase.py index 369d8b4f8..e6f69cc63 100644 --- a/api/use_cases/admin/routers/_getroutersusecase.py +++ b/api/use_cases/admin/routers/_getroutersusecase.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from api.domain import SortField, SortOrder from api.domain.router import RouterRepository -from api.domain.router.entities import Router, RouterSortField, SortOrder +from api.domain.router.entities import RouterPage from api.domain.userinfo import UserInfoRepository from api.domain.userinfo.errors import UserIsNotAdminError @@ -11,14 +12,13 @@ class GetRoutersCommand: user_id: int offset: int limit: int - sort_by: RouterSortField + sort_by: SortField sort_order: SortOrder @dataclass class GetRoutersUseCaseSuccess: - routers: list[Router] - total: int + router_page: RouterPage type GetRoutersUseCaseResult = GetRoutersUseCaseSuccess | UserIsNotAdminError @@ -45,4 +45,4 @@ async def execute( sort_order=command.sort_order, ) - return GetRoutersUseCaseSuccess(routers=router_page.data, total=router_page.total) + return GetRoutersUseCaseSuccess(router_page=router_page) diff --git a/playground/app/features/usage/state.py b/playground/app/features/usage/state.py index a31dbe092..a5d8913f8 100644 --- a/playground/app/features/usage/state.py +++ b/playground/app/features/usage/state.py @@ -82,15 +82,17 @@ async def load_entities(self): def usage_rows(self) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for row in self.entities: - rows.append({ - "date": row.created, - "endpoint": row.endpoint, - "key": row.key, - "model": row.model, - "tokens": "" if row.total_tokens == 0 else f"{row.prompt_tokens} → {row.completion_tokens}", - "cost": "" if row.cost == 0.0 or row.cost is None else f"{row.cost:.4f}", - "kgCO2eq": "" if row.kgco2eq is None else f"{round(row.kgco2eq, 5)}", - }) + rows.append( + { + "date": row.created, + "endpoint": row.endpoint, + "key": row.key, + "model": row.model, + "tokens": "" if row.total_tokens == 0 else f"{row.prompt_tokens} → {row.completion_tokens}", + "cost": "" if row.cost == 0.0 or row.cost is None else f"{row.cost:.4f}", + "kgCO2eq": "" if row.kgco2eq is None else f"{round(row.kgco2eq, 5)}", + } + ) return rows ############################################################