Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from api.infrastructure.model import ModelProviderGateway
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
from api.schemas.core.context import RequestContext
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase, GetOneProviderUseCase, GetProvidersUseCase
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
from api.use_cases.models import GetModelsUseCase
from api.utils.configuration import configuration
Expand Down Expand Up @@ -98,3 +98,17 @@ def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(ge
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)


def get_one_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneProviderUseCase:
return GetOneProviderUseCase(
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)


def get_providers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetProvidersUseCase:
return GetProvidersUseCase(
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)
23 changes: 23 additions & 0 deletions api/domain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Generic, TypeVar


class SortField(StrEnum):
ID = "id"
NAME = "name"
CREATED = "created"


class SortOrder(StrEnum):
ASC = "asc"
DESC = "desc"


T = TypeVar("T")


@dataclass
class EntitiesPage(Generic[T]):
total: int
data: list[T]
21 changes: 20 additions & 1 deletion api/domain/provider/_providerrepository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod

from api.domain import SortOrder
from api.domain.model.entities import Metric
from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType
from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderPage, ProviderSortField, ProviderType
from api.domain.provider.errors import ProviderAlreadyExistsError


Expand Down Expand Up @@ -32,3 +33,21 @@ async def delete_provider(
provider_id: int,
) -> Provider | None:
pass

@abstractmethod
async def get_one_provider(
self,
provider_id: int,
) -> Provider | None:
pass

@abstractmethod
async def get_providers_page(
self,
router_id: int | None,
limit: int,
offset: int,
sort_by: ProviderSortField = ProviderSortField.ID,
sort_order: SortOrder = SortOrder.ASC,
) -> ProviderPage:
pass
10 changes: 10 additions & 0 deletions api/domain/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pycountry
from pydantic import Field, constr

from api.domain import EntitiesPage
from api.domain.model.entities import ModelType
from api.schemas import BaseModel
from api.schemas.core.models import Metric
Expand Down Expand Up @@ -57,6 +58,15 @@ class ProviderType(str, Enum):
}


class ProviderSortField(str, Enum):
ID = "id"
MODEL_NAME = "model_name"
CREATED = "created"


ProviderPage = EntitiesPage["Provider"]


class Provider(BaseModel):
object: Literal["provider"] = "provider"
id: int = Field(..., description="Provider ID.") # fmt: off
Expand Down
5 changes: 3 additions & 2 deletions api/domain/router/_routerrepository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod

from api.domain import SortField, SortOrder
from api.domain.model import ModelType as RouterType
from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage, RouterSortField, SortOrder
from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage
from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError


Expand All @@ -19,7 +20,7 @@ async def get_routers_page(
self,
limit: int,
offset: int,
sort_by: RouterSortField = RouterSortField.ID,
sort_by: SortField = SortField.ID,
sort_order: SortOrder = SortOrder.ASC,
) -> RouterPage:
pass
Expand Down
18 changes: 2 additions & 16 deletions api/domain/router/entities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass
from enum import StrEnum

from pydantic import BaseModel, Field

from api.domain import EntitiesPage
from api.domain.model import ModelType as RouterType


Expand All @@ -11,21 +11,7 @@ class RouterLoadBalancingStrategy(StrEnum):
LEAST_BUSY = "least_busy"


class RouterSortField(StrEnum):
ID = "id"
NAME = "name"
CREATED = "created"


class SortOrder(StrEnum):
ASC = "asc"
DESC = "desc"


@dataclass
class RouterPage:
total: int
data: list["Router"]
RouterPage = EntitiesPage["Router"]


class Router(BaseModel):
Expand Down
61 changes: 3 additions & 58 deletions api/endpoints/admin/providers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Literal

from fastapi import Body, Depends, Path, Query, Request, Security
from fastapi.responses import JSONResponse, Response
from fastapi import Body, Depends, Path, Request, Security
from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession

from api.endpoints.admin import router
from api.helpers._accesscontroller import AccessController
from api.helpers.models import ModelRegistry
from api.schemas.admin.providers import Provider, Providers, UpdateProvider
from api.schemas.admin.providers import UpdateProvider
from api.schemas.admin.roles import PermissionType
from api.utils.dependencies import get_model_registry, get_postgres_session
from api.utils.variables import EndpointRoute
Expand Down Expand Up @@ -41,56 +39,3 @@ async def update_provider(
)

return Response(status_code=204)


@router.get(
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}",
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.PROVIDE_MODELS]))],
status_code=200,
response_model=Provider,
)
async def get_provider(
request: Request,
provider: int = Path(description="The ID of the provider to get."),
postgres_session: AsyncSession = Depends(get_postgres_session),
model_registry: ModelRegistry = Depends(get_model_registry),
) -> JSONResponse:
"""
Get a model provider by router and provider IDs.
"""
providers = await model_registry.get_providers(router_id=router, provider_id=provider, postgres_session=postgres_session)
provider = providers[0]

return JSONResponse(status_code=200, content=provider.model_dump())


@router.get(
path=EndpointRoute.ADMIN_PROVIDERS,
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))],
status_code=200,
response_model=Providers,
)
async def get_providers(
request: Request,
router: int | None = Query(default=None, description="Filter providers by router ID."),
offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."),
limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."),
order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."),
order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."),
postgres_session: AsyncSession = Depends(get_postgres_session),
model_registry: ModelRegistry = Depends(get_model_registry),
) -> JSONResponse:
"""
Get all model providers for a router.
"""
providers = await model_registry.get_providers(
router_id=router,
provider_id=None,
postgres_session=postgres_session,
offset=offset,
limit=limit,
order_by=order_by,
order_direction=order_direction,
)

return JSONResponse(status_code=200, content=Providers(data=providers).model_dump())
Loading
Loading