Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/badges/coverage.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"schemaVersion":1,"label":"coverage","message":"50.59%","color":"red"}
{"schemaVersion":1,"label":"coverage","message":"51.03%","color":"red"}
77 changes: 34 additions & 43 deletions api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@
from api.infrastructure.postgres import PostgresKeyRepository, PostgresProviderRepository, PostgresRouterRepository, PostgresUserInfoRepository
from api.schemas.core.context import RequestContext
from api.use_cases.admin.providers import CreateProviderUseCase, DeleteProviderUseCase
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase
from api.use_cases.admin.routers import CreateRouterUseCase, DeleteRouterUseCase, GetOneRouterUseCase, GetRoutersUseCase, UpdateRouterUseCase
from api.use_cases.models import GetModelsUseCase
from api.utils.configuration import configuration
from api.utils.context import global_context, request_context


def get_request_context() -> ContextVar[RequestContext]:
return request_context


def get_master_key() -> str:
return configuration.settings.auth_master_key


# databases
async def get_postgres_session() -> AsyncGenerator[AsyncSession]:
session_factory = global_context.postgres_session_factory
async with session_factory() as postgres_session:
Expand All @@ -28,10 +37,7 @@ async def get_postgres_session() -> AsyncGenerator[AsyncSession]:
raise


def get_request_context() -> ContextVar[RequestContext]:
return request_context


# repositories
def _router_repository(session: AsyncSession) -> PostgresRouterRepository:
return PostgresRouterRepository(postgres_session=session, app_title=configuration.settings.app_title)

Expand All @@ -40,6 +46,11 @@ def _user_info_repository(session: AsyncSession) -> PostgresUserInfoRepository:
return PostgresUserInfoRepository(postgres_session=session)


def get_key_repository(postgres_session: AsyncSession = Depends(get_postgres_session)) -> KeyRepository:
return PostgresKeyRepository(postgres_session=postgres_session)


# models use cases
def get_models_use_case(
postgres_session: AsyncSession = Depends(get_postgres_session),
request_context: RequestContext = Depends(get_request_context),
Expand All @@ -51,45 +62,33 @@ def get_models_use_case(
)


def create_provider_use_case_factory(
postgres_session: AsyncSession = Depends(get_postgres_session),
) -> CreateProviderUseCase:
return CreateProviderUseCase(
router_repository=_router_repository(postgres_session),
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
provider_gateway=ModelProviderGateway(),
user_info_repository=_user_info_repository(postgres_session),
)
# routers use cases
def get_one_router_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetOneRouterUseCase:
return GetOneRouterUseCase(router_repository=_router_repository(postgres_session), user_info_repository=_user_info_repository(postgres_session))


def get_one_router_use_case_factory(
postgres_session: AsyncSession = Depends(get_postgres_session),
) -> GetOneRouterUseCase:
return GetOneRouterUseCase(
router_repository=_router_repository(postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)


def get_routers_use_case_factory(
postgres_session: AsyncSession = Depends(get_postgres_session),
) -> GetRoutersUseCase:
return GetRoutersUseCase(
router_repository=_router_repository(postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)
def get_routers_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> GetRoutersUseCase:
return GetRoutersUseCase(router_repository=_router_repository(postgres_session), user_info_repository=_user_info_repository(postgres_session))


def create_router_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateRouterUseCase:
return CreateRouterUseCase(
router_repository=_router_repository(postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)
return CreateRouterUseCase(router_repository=_router_repository(postgres_session), user_info_repository=_user_info_repository(postgres_session))


def delete_router_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> DeleteRouterUseCase:
return DeleteRouterUseCase(
return DeleteRouterUseCase(router_repository=_router_repository(postgres_session), user_info_repository=_user_info_repository(postgres_session))


def update_router_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> UpdateRouterUseCase:
return UpdateRouterUseCase(router_repository=_router_repository(postgres_session), user_info_repository=_user_info_repository(postgres_session))


# providers use cases
def create_provider_use_case_factory(postgres_session: AsyncSession = Depends(get_postgres_session)) -> CreateProviderUseCase:
return CreateProviderUseCase(
router_repository=_router_repository(postgres_session),
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
provider_gateway=ModelProviderGateway(),
user_info_repository=_user_info_repository(postgres_session),
)

Expand All @@ -99,11 +98,3 @@ def delete_provider_use_case_factory(postgres_session: AsyncSession = Depends(ge
provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
user_info_repository=_user_info_repository(postgres_session),
)


def get_key_repository(postgres_session: AsyncSession = Depends(get_postgres_session)) -> KeyRepository:
return PostgresKeyRepository(postgres_session=postgres_session)


def get_master_key() -> str:
return configuration.settings.auth_master_key
8 changes: 8 additions & 0 deletions api/domain/router/_routerrepository.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ async def create_router(
@abstractmethod
async def delete_router(self, router_id: int) -> Router | None:
pass

@abstractmethod
async def get_aliases(self, filtered_aliases: list[str] | None = None) -> list[str]:
pass

@abstractmethod
async def update_router(self, router_to_update: Router) -> Router | RouterNameAlreadyExistsError:
pass
26 changes: 22 additions & 4 deletions api/domain/router/entities.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from dataclasses import dataclass
from enum import Enum
from enum import StrEnum

from pydantic import BaseModel, Field

from api.domain.model import ModelType as RouterType


class RouterLoadBalancingStrategy(str, Enum):
class RouterLoadBalancingStrategy(StrEnum):
SHUFFLE = "shuffle"
LEAST_BUSY = "least_busy"


class RouterSortField(str, Enum):
class RouterSortField(StrEnum):
ID = "id"
NAME = "name"
CREATED = "created"


class SortOrder(str, Enum):
class SortOrder(StrEnum):
ASC = "asc"
DESC = "desc"

Expand All @@ -42,3 +42,21 @@ class Router(BaseModel):
providers: int = Field(default=0, description="Number of providers in the router.") # fmt: off
created: int = Field(..., description="Time of creation, as Unix timestamp.") # fmt: off
updated: int = Field(..., description="Time of last update, as Unix timestamp.") # fmt: off

def with_name(self, name: str) -> "Router":
return self.model_copy(update={"name": name})

def with_type(self, router_type: RouterType) -> "Router":
return self.model_copy(update={"type": router_type})

def with_load_balancing_strategy(self, strategy: RouterLoadBalancingStrategy) -> "Router":
return self.model_copy(update={"load_balancing_strategy": strategy})

def with_cost_prompt_tokens(self, prompt_tokens: float) -> "Router":
return self.model_copy(update={"cost_prompt_tokens": prompt_tokens})

def with_cost_completion_tokens(self, completion_tokens: float) -> "Router":
return self.model_copy(update={"cost_completion_tokens": completion_tokens})

def with_aliases(self, aliases: list[str]) -> "Router":
return self.model_copy(update={"aliases": aliases})
2 changes: 1 addition & 1 deletion api/endpoints/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

router = APIRouter(prefix="/v1", tags=[RouterName.ADMIN.title()])

from . import organizations, providers, roles, routers, tokens, users # noqa: F401 E402
from . import organizations, providers, roles, tokens, users # noqa: F401 E402
40 changes: 0 additions & 40 deletions api/endpoints/admin/routers.py

This file was deleted.

Loading