Skip to content

Commit 17fc6aa

Browse files
benjaminpiliaBenjamin PILIA
andauthored
feat(refacto): get provider endpoint toward clean architecture (#767)
* BSR: clean integration tests * Refacto get provider into clean architecture * BSR: remove begin_nested in create_router * fix created and updated in get provider * refacto get providers into clean architecture * Minor fix * BSR: fix TestGetRoutersPage tests * Update unit coverage badge --------- Co-authored-by: Benjamin PILIA <benjamin.pilia@protonmail.com> Co-authored-by: benjaminpilia <benjaminpilia@users.noreply.github.com>
1 parent ede5b37 commit 17fc6aa

37 files changed

+961
-245
lines changed

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"51.03%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"50.10%","color":"red"}

api/dependencies.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from api.infrastructure.model import ModelProviderGateway
99
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
1010
from api.schemas.core.context import RequestContext
11-
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase
11+
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase, GetOneProviderUseCase, GetProvidersUseCase
1212
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
1313
from api.use_cases.models import GetModelsUseCase
1414
from api.utils.configuration import configuration
@@ -98,3 +98,17 @@ def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(ge
9898
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
9999
user_info_repository=_user_info_repository(postgres_session),
100100
)
101+
102+
103+
def get_one_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneProviderUseCase:
104+
return GetOneProviderUseCase(
105+
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
106+
user_info_repository=_user_info_repository(postgres_session),
107+
)
108+
109+
110+
def get_providers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetProvidersUseCase:
111+
return GetProvidersUseCase(
112+
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
113+
user_info_repository=_user_info_repository(postgres_session),
114+
)

api/domain/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from dataclasses import dataclass
2+
from enum import StrEnum
3+
from typing import Generic, TypeVar
4+
5+
6+
class SortField(StrEnum):
7+
ID = "id"
8+
NAME = "name"
9+
CREATED = "created"
10+
11+
12+
class SortOrder(StrEnum):
13+
ASC = "asc"
14+
DESC = "desc"
15+
16+
17+
T = TypeVar("T")
18+
19+
20+
@dataclass
21+
class EntitiesPage(Generic[T]):
22+
total: int
23+
data: list[T]

api/domain/provider/_providerrepository.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC, abstractmethod
22

3+
from api.domain import SortOrder
34
from api.domain.model.entities import Metric
4-
from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderType
5+
from api.domain.provider.entities import Provider, ProviderCarbonFootprintZone, ProviderPage, ProviderSortField, ProviderType
56
from api.domain.provider.errors import ProviderAlreadyExistsError
67

78

@@ -32,3 +33,21 @@ async def delete_provider(
3233
provider_id: int,
3334
) -> Provider | None:
3435
pass
36+
37+
@abstractmethod
38+
async def get_one_provider(
39+
self,
40+
provider_id: int,
41+
) -> Provider | None:
42+
pass
43+
44+
@abstractmethod
45+
async def get_providers_page(
46+
self,
47+
router_id: int | None,
48+
limit: int,
49+
offset: int,
50+
sort_by: ProviderSortField = ProviderSortField.ID,
51+
sort_order: SortOrder = SortOrder.ASC,
52+
) -> ProviderPage:
53+
pass

api/domain/provider/entities.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pycountry
55
from pydantic import Field, constr
66

7+
from api.domain import EntitiesPage
78
from api.domain.model.entities import ModelType
89
from api.schemas import BaseModel
910
from api.schemas.core.models import Metric
@@ -57,6 +58,15 @@ class ProviderType(str, Enum):
5758
}
5859

5960

61+
class ProviderSortField(str, Enum):
62+
ID = "id"
63+
MODEL_NAME = "model_name"
64+
CREATED = "created"
65+
66+
67+
ProviderPage = EntitiesPage["Provider"]
68+
69+
6070
class Provider(BaseModel):
6171
object: Literal["provider"] = "provider"
6272
id: int = Field(..., description="Provider ID.") # fmt: off

api/domain/router/_routerrepository.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC, abstractmethod
22

3+
from api.domain import SortField, SortOrder
34
from api.domain.model import ModelType as RouterType
4-
from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage, RouterSortField, SortOrder
5+
from api.domain.router.entities import Router, RouterLoadBalancingStrategy, RouterPage
56
from api.domain.router.errors import RouterAliasAlreadyExistsError, RouterNameAlreadyExistsError
67

78

@@ -19,7 +20,7 @@ async def get_routers_page(
1920
self,
2021
limit: int,
2122
offset: int,
22-
sort_by: RouterSortField = RouterSortField.ID,
23+
sort_by: SortField = SortField.ID,
2324
sort_order: SortOrder = SortOrder.ASC,
2425
) -> RouterPage:
2526
pass

api/domain/router/entities.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from dataclasses import dataclass
21
from enum import StrEnum
32

43
from pydantic import BaseModel, Field
54

5+
from api.domain import EntitiesPage
66
from api.domain.model import ModelType as RouterType
77

88

@@ -11,21 +11,7 @@ class RouterLoadBalancingStrategy(StrEnum):
1111
LEAST_BUSY = "least_busy"
1212

1313

14-
class RouterSortField(StrEnum):
15-
ID = "id"
16-
NAME = "name"
17-
CREATED = "created"
18-
19-
20-
class SortOrder(StrEnum):
21-
ASC = "asc"
22-
DESC = "desc"
23-
24-
25-
@dataclass
26-
class RouterPage:
27-
total: int
28-
data: list["Router"]
14+
RouterPage = EntitiesPage["Router"]
2915

3016

3117
class Router(BaseModel):

api/endpoints/admin/providers.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from typing import Literal
2-
3-
from fastapi import Body, Depends, Path, Query, Request, Security
4-
from fastapi.responses import JSONResponse, Response
1+
from fastapi import Body, Depends, Path, Request, Security
2+
from fastapi.responses import Response
53
from sqlalchemy.ext.asyncio import AsyncSession
64

75
from api.endpoints.admin import router
86
from api.helpers._accesscontroller import AccessController
97
from api.helpers.models import ModelRegistry
10-
from api.schemas.admin.providers import Provider, Providers, UpdateProvider
8+
from api.schemas.admin.providers import UpdateProvider
119
from api.schemas.admin.roles import PermissionType
1210
from api.utils.dependencies import get_model_registry, get_postgres_session
1311
from api.utils.variables import EndpointRoute
@@ -41,56 +39,3 @@ async def update_provider(
4139
)
4240

4341
return Response(status_code=204)
44-
45-
46-
@router.get(
47-
path=EndpointRoute.ADMIN_PROVIDERS + "/{provider}",
48-
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.PROVIDE_MODELS]))],
49-
status_code=200,
50-
response_model=Provider,
51-
)
52-
async def get_provider(
53-
request: Request,
54-
provider: int = Path(description="The ID of the provider to get."),
55-
postgres_session: AsyncSession = Depends(get_postgres_session),
56-
model_registry: ModelRegistry = Depends(get_model_registry),
57-
) -> JSONResponse:
58-
"""
59-
Get a model provider by router and provider IDs.
60-
"""
61-
providers = await model_registry.get_providers(router_id=router, provider_id=provider, postgres_session=postgres_session)
62-
provider = providers[0]
63-
64-
return JSONResponse(status_code=200, content=provider.model_dump())
65-
66-
67-
@router.get(
68-
path=EndpointRoute.ADMIN_PROVIDERS,
69-
dependencies=[Security(dependency=AccessController(permissions=[PermissionType.ADMIN, PermissionType.PROVIDE_MODELS]))],
70-
status_code=200,
71-
response_model=Providers,
72-
)
73-
async def get_providers(
74-
request: Request,
75-
router: int | None = Query(default=None, description="Filter providers by router ID."),
76-
offset: int = Query(default=0, ge=0, description="The offset of the tokens to get."),
77-
limit: int = Query(default=10, ge=1, le=100, description="The limit of the tokens to get."),
78-
order_by: Literal["id", "model_name", "created"] = Query(default="id", description="The field to order the tokens by."),
79-
order_direction: Literal["asc", "desc"] = Query(default="asc", description="The direction to order the tokens by."),
80-
postgres_session: AsyncSession = Depends(get_postgres_session),
81-
model_registry: ModelRegistry = Depends(get_model_registry),
82-
) -> JSONResponse:
83-
"""
84-
Get all model providers for a router.
85-
"""
86-
providers = await model_registry.get_providers(
87-
router_id=router,
88-
provider_id=None,
89-
postgres_session=postgres_session,
90-
offset=offset,
91-
limit=limit,
92-
order_by=order_by,
93-
order_direction=order_direction,
94-
)
95-
96-
return JSONResponse(status_code=200, content=Providers(data=providers).model_dump())

0 commit comments

Comments
 (0)