Commit fd08852

feat(models): decouple healthcheck from other capabilities

1 parent 4ddbc35 commit fd08852

13 files changed: +115 −161 lines

api/clients/model/_albertmodelprovider.py
Lines changed: 0 additions & 25 deletions

@@ -1,11 +1,7 @@
 import logging
-from urllib.parse import urljoin
-
-import httpx
 
 from api.schemas.admin.providers import ProviderType
 from api.schemas.core.models import ProviderEndpoints
-from api.utils.variables import EndpointRoute
 
 from ._basemodelprovider import BaseModelProvider
 
@@ -45,24 +41,3 @@ def __init__(
             model_active_params=model_active_params,
         )
         self.type = ProviderType.ALBERT
-
-    async def get_max_context_length(self) -> int | None:
-        url = urljoin(base=str(self.url), url=self.ENDPOINT_TABLE.get_endpoint(endpoint=EndpointRoute.MODELS).lstrip("/"))
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
-                response.raise_for_status()
-        except Exception as e:
-            # TODO: remove exc_info=True and return error instead of exception
-            logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
-            raise AssertionError(f"Model is not reachable ({e}).")
-
-        data = response.json()["data"]
-        models = [model for model in data if model["id"] == self.model_name or self.model_name in model["aliases"]]
-        assert len(models) == 1, f"Model not found ({self.model_name})."
-
-        model = models[0]
-        max_context_length = model.get("max_context_length")
-
-        return max_context_length

api/clients/model/_basemodelprovider.py
Lines changed: 41 additions & 10 deletions

@@ -11,6 +11,7 @@
 import httpx
 from redis.asyncio import Redis as AsyncRedis
 
+from api.infrastructure.fastapi.schemas.models import Models
 from api.schemas.admin.providers import ProviderType
 from api.schemas.audio import AudioTranscription, CreateAudioTranscription
 from api.schemas.chat import ChatCompletionChunk, CreateChatCompletion
@@ -70,27 +71,50 @@ def import_module(type: ProviderType) -> "type[BaseModelProvider]":
 
         return getattr(module, f"{type.capitalize()}ModelProvider")
 
-    @staticmethod
-    async def get_max_context_length() -> int | None:
+    async def healthcheck(self, redis_client: AsyncRedis) -> bool:
+        """
+        Check if the model provider is healthy.
+        """
+        request_content = RequestContent(endpoint=EndpointRoute.MODELS, method="GET")
+        response = await self.forward_request(request_content=request_content, redis_client=redis_client)
+
+        if response.status_code != 200:
+            return False
+
+        data = response.json()["data"]
+        models = [model for model in data if model["id"] == self.model_name or self.model_name in model["aliases"]]
+
+        if not models:
+            return False
+
+        return True
+
+    async def get_max_context_length(self, redis_client: AsyncRedis) -> int | None:
         """
         Get the max context length of the model provider to store in the database. Useful
         to check provider consistency.
         """
-        pass
+        request_content = RequestContent(endpoint=EndpointRoute.MODELS, method="GET")
+        response = await self.forward_request(request_content=request_content, redis_client=redis_client)
+
+        if response.status_code != 200:
+            return None
+
+        data = response.json()
+        return data["max_context_length"]
 
-    async def get_vector_size(self) -> int | None:
+    async def get_vector_size(self, redis_client: AsyncRedis) -> int | None:
        if self.ENDPOINT_TABLE.embeddings is None:
            return None
 
-        url = urljoin(base=self.url, url=self.ENDPOINT_TABLE.embeddings.lstrip("/"))
+        request_content = RequestContent(endpoint=EndpointRoute.EMBEDDINGS, method="POST", json={"model": self.model_name, "input": "hello world"})
+        response = await self.forward_request(request_content=request_content, redis_client=redis_client)
 
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url=url, headers=self.headers, json={"model": self.model_name, "input": "hello world"}, timeout=self.timeout)
-        assert response.status_code == 200, f"Model is not reachable ({response.status_code} - {response.text})."
+        if response.status_code != 200:
+            return None
 
-        data = response.json()["data"]
+        data = response.json()
         vector_size = len(data[0]["embedding"])
-
         return vector_size
 
     def _get_usage(self, request_content: RequestContent, response_data: dict | list[dict], request_latency: float | None = 0.0) -> Usage | None:
@@ -203,6 +227,13 @@ def _format_response(self, request_content: RequestContent, response: httpx.Resp
                 response_data=response_data,
             ).model_dump()
 
+        elif request_content.endpoint == EndpointRoute.MODELS:
+            response_data = Models.build_from(
+                provider_type=self.type,
+                request_content=request_content,
+                response_data=response_data,
+            ).model_dump()
+
         elif request_content.endpoint == EndpointRoute.RERANK:
             response_data = Reranks.build_from(
                 provider_type=self.type,

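This file is the heart of the commit: reachability (healthcheck), context-length lookup, and vector-size probing become separate capabilities on the base class, all routed through forward_request instead of ad-hoc httpx calls. A minimal caller sketch, assuming an already-constructed provider and Redis client (probe_provider is hypothetical, not part of the commit):

from redis.asyncio import Redis as AsyncRedis

async def probe_provider(provider, redis_client: AsyncRedis) -> None:
    # healthcheck() answers "is the provider up and serving this model?" as a
    # bool, instead of raising AssertionError the way the old per-provider
    # get_max_context_length() implementations did.
    if not await provider.healthcheck(redis_client=redis_client):
        print(f"{provider.model_name}: unhealthy")
        return

    # The remaining capabilities degrade to None on failure rather than raising.
    max_context_length = await provider.get_max_context_length(redis_client=redis_client)
    vector_size = await provider.get_vector_size(redis_client=redis_client)
    print(f"{provider.model_name}: max_context_length={max_context_length}, vector_size={vector_size}")
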
api/clients/model/_mistralmodelprovider.py
Lines changed: 0 additions & 25 deletions

@@ -1,11 +1,7 @@
 import logging
-from urllib.parse import urljoin
-
-import httpx
 
 from api.schemas.admin.providers import ProviderType
 from api.schemas.core.models import ProviderEndpoints
-from api.utils.variables import EndpointRoute
 
 from ._basemodelprovider import BaseModelProvider
 
@@ -45,24 +41,3 @@ def __init__(
             timeout=timeout,
         )
         self.type = ProviderType.MISTRAL
-
-    async def get_max_context_length(self) -> int | None:
-        url = urljoin(base=str(self.url), url=self.ENDPOINT_TABLE.get_endpoint(endpoint=EndpointRoute.MODELS).lstrip("/"))
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
-                response.raise_for_status()
-
-        except Exception as e:
-            logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
-            raise AssertionError(f"Model is not reachable ({e}).")
-
-        data = response.json()["data"]
-        models = [model for model in data if model["id"] == self.model_name]
-        assert len(models) == 1, f"Model not found ({self.model_name})."
-
-        model = models[0]
-        max_context_length = model.get("max_context_length")
-
-        return max_context_length

api/clients/model/_openaimodelprovider.py
Lines changed: 0 additions & 24 deletions

@@ -1,11 +1,7 @@
 import logging
-from urllib.parse import urljoin
-
-import httpx
 
 from api.schemas.admin.providers import ProviderType
 from api.schemas.core.models import ProviderEndpoints
-from api.utils.variables import EndpointRoute
 
 from ._basemodelprovider import BaseModelProvider
 
@@ -45,23 +41,3 @@ def __init__(
             timeout=timeout,
         )
         self.type = ProviderType.OPENAI
-
-    async def get_max_context_length(self) -> int | None:
-        url = urljoin(base=str(self.url), url=self.ENDPOINT_TABLE.get_endpoint(endpoint=EndpointRoute.MODELS).lstrip("/"))
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
-                response.raise_for_status()
-        except Exception as e:
-            logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
-            raise AssertionError(f"Model is not reachable ({e}).")
-
-        data = response.json()["data"]
-        models = [model for model in data if model["id"] == self.model_name]
-        assert len(models) == 1, f"Model not found ({self.model_name})."
-
-        model = models[0]
-        max_context_length = model.get("max_context_length")
-
-        return max_context_length

api/clients/model/_teimodelprovider.py
Lines changed: 0 additions & 20 deletions

@@ -1,11 +1,7 @@
 import logging
-from urllib.parse import urljoin
-
-import httpx
 
 from api.schemas.admin.providers import ProviderType
 from api.schemas.core.models import ProviderEndpoints
-from api.utils.variables import EndpointRoute
 
 from ._basemodelprovider import BaseModelProvider
 
@@ -45,19 +41,3 @@ def __init__(
             model_active_params=model_active_params,
         )
         self.type = ProviderType.TEI
-
-    async def get_max_context_length(self) -> int | None:
-        url = urljoin(base=self.url, url=self.ENDPOINT_TABLE.get_endpoint(endpoint=EndpointRoute.MODELS).lstrip("/"))
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
-                response.raise_for_status()
-        except Exception as e:
-            logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
-            raise AssertionError(f"Model is not reachable ({e}).")
-
-        data = response.json()
-        assert self.model_name == data["model_id"], f"Model not found ({self.model_name})."
-        max_context_length = data.get("max_input_length")
-        return max_context_length

api/clients/model/_vllmmodelprovider.py
Lines changed: 0 additions & 24 deletions

@@ -1,11 +1,7 @@
 import logging
-from urllib.parse import urljoin
-
-import httpx
 
 from api.schemas.admin.providers import ProviderType
 from api.schemas.core.models import ProviderEndpoints
-from api.utils.variables import EndpointRoute
 
 from ._basemodelprovider import BaseModelProvider
 
@@ -45,23 +41,3 @@ def __init__(
             model_active_params=model_active_params,
        )
         self.type = ProviderType.VLLM
-
-    async def get_max_context_length(self) -> int | None:
-        url = urljoin(base=self.url, url=self.ENDPOINT_TABLE.get_endpoint(endpoint=EndpointRoute.MODELS).lstrip("/"))
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(url=url, headers=self.headers, timeout=self.timeout)
-                response.raise_for_status()
-        except Exception as e:
-            logger.error(f"Error getting max context length for {self.model_name}: {e}", exc_info=True)
-            raise AssertionError(f"Model is not reachable ({e}).")
-
-        data = response.json()
-        models = [model for model in data["data"] if model["id"] == self.model_name]
-        assert len(models) == 1, f"Model not found ({self.model_name})."
-
-        model = models[0]
-        max_context_length = model.get("max_model_len")
-
-        return max_context_length

api/dependencies.py
Lines changed: 11 additions & 1 deletion

@@ -2,6 +2,7 @@
 from contextvars import ContextVar
 
 from fastapi import Depends
+from redis.asyncio import Redis as AsyncRedis
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from api.domain.key import KeyRepository
@@ -28,6 +29,14 @@ async def get_postgres_session() -> AsyncGenerator[AsyncSession]:
         raise
 
 
+async def get_redis_client() -> AsyncGenerator[AsyncRedis]:
+    client = AsyncRedis(connection_pool=global_context.redis_pool)
+
+    yield client
+
+    await client.aclose()
+
+
 def get_request_context() -> ContextVar[RequestContext]:
     return request_context
 
@@ -53,11 +62,12 @@ def get_models_use_case(
 
 def create_provider_use_case_factory(
     postgres_session: AsyncSession = Depends(get_postgres_session),
+    redis_client: AsyncRedis = Depends(get_redis_client),
 ) -> CreateProviderUseCase:
     return CreateProviderUseCase(
         router_repository=_router_repository(postgres_session),
         provider_repository=PostgresProviderRepository(postgres_session=postgres_session),
-        provider_gateway=ModelProviderGateway(),
+        provider_gateway=ModelProviderGateway(redis_client=redis_client),
         user_info_repository=_user_info_repository(postgres_session),
     )
 

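For context, get_redis_client above follows FastAPI's yield-style dependency pattern: code before the yield runs when a request needs the dependency, code after it runs at teardown. A self-contained sketch of the same pattern, with an illustrative pool URL and route that are not from this repository:

from collections.abc import AsyncGenerator

from fastapi import Depends, FastAPI
from redis.asyncio import ConnectionPool
from redis.asyncio import Redis as AsyncRedis

app = FastAPI()
pool = ConnectionPool.from_url("redis://localhost:6379/0")  # stand-in for global_context.redis_pool

async def get_redis() -> AsyncGenerator[AsyncRedis, None]:
    client = AsyncRedis(connection_pool=pool)
    yield client           # injected into the route while the request runs
    await client.aclose()  # teardown: releases the client, keeps the shared pool open

@app.get("/ping")
async def ping(redis_client: AsyncRedis = Depends(get_redis)) -> dict:
    return {"redis": await redis_client.ping()}
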
api/domain/provider/entities.py
Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Literal
 
 import pycountry
@@ -14,7 +14,7 @@
 ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str)
 
 
-class ProviderType(str, Enum):
+class ProviderType(StrEnum):
     ALBERT = "albert"
     OPENAI = "openai"
     MISTRAL = "mistral"

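The Enum-to-StrEnum switch (Python 3.11+) keeps comparisons identical but changes what str() and f-strings produce, which matters once members end up in URLs, logs, or JSON. A quick sketch of the difference, with illustrative class names:

from enum import Enum, StrEnum

class OldStyle(str, Enum):
    ALBERT = "albert"

class NewStyle(StrEnum):
    ALBERT = "albert"

assert OldStyle.ALBERT == "albert" and NewStyle.ALBERT == "albert"  # equality is unchanged
assert str(NewStyle.ALBERT) == "albert"           # StrEnum: str() yields the raw value
assert str(OldStyle.ALBERT) == "OldStyle.ALBERT"  # str/Enum mixin: str() yields the member name
assert f"{NewStyle.ALBERT}" == "albert"           # formatting follows str() here
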
api/helpers/models/_modelregistry.py
Lines changed: 2 additions & 2 deletions

@@ -492,9 +492,9 @@ async def create_provider(
             model_total_params=model_total_params,
             model_active_params=model_active_params,
         )
-        max_context_length = await provider.get_max_context_length()
+        max_context_length = await provider.get_max_context_length(redis_client=self.redis_client)
         if router.type == ModelType.TEXT_EMBEDDINGS_INFERENCE:
-            vector_size = await provider.get_vector_size()
+            vector_size = await provider.get_vector_size(redis_client=self.redis_client)
         else:
             vector_size = None
 

api/infrastructure/fastapi/schemas/models.py
Lines changed: 37 additions & 8 deletions

@@ -1,18 +1,19 @@
-from enum import Enum
-from typing import Literal
+from enum import StrEnum
+from typing import Annotated, Literal
 
 from pydantic import Field
 
-from api.domain.model import Model as ModelEntity
+from api.domain.provider.entities import ProviderType
 from api.schemas import BaseModel
+from api.schemas.core.models import RequestContent
 
 
 class ModelCosts(BaseModel):
     prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)")
     completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)")
 
 
-class ModelType(str, Enum):
+class ModelType(StrEnum):
     AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition"
     IMAGE_TEXT_TO_TEXT = "image-text-to-text"
     IMAGE_TO_TEXT = "image-to-text"
@@ -21,10 +22,38 @@ class ModelType(str, Enum):
     TEXT_CLASSIFICATION = "text-classification"
 
 
-class Model(ModelEntity):
-    object: Literal["model"] = "model"
+class Model(BaseModel):
+    object: Annotated[Literal["model"], Field("model", description="Type of the object.")]
+    id: Annotated[str, Field(..., description="The model identifier, which can be referenced in the API endpoints.")]
+    type: Annotated[ModelType | None, Field(default=None, description="The type of the model, which can be used to identify the model type.", examples=["text-generation"])]  # fmt: off
+    aliases: Annotated[list[str], Field(default_factory=list, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]])]  # fmt: off
+    created: Annotated[int, Field(..., description="Time of creation, as Unix timestamp.")]
+    owned_by: Annotated[str, Field(..., description="The organization that owns the model.")]
+    max_context_length: Annotated[int | None, Field(default=None, description="Maximum amount of tokens a context could contain. Makes sure it is the same for all models.")]  # fmt: off
+    costs: Annotated[ModelCosts, Field(default_factory=ModelCosts, description="Costs of the model.")]
+
+    @classmethod
+    def build_from(cls, provider_type: ProviderType, request_content: RequestContent, response_data: dict) -> "Model":
+        match provider_type:
+            case ProviderType.ALBERT:
+                return cls(**response_data)
+
+            case ProviderType.TEI:
+                return cls(id=response_data["model_id"], created=0, owned_by="tei", max_context_length=response_data["max_input_length"])
+
+            case ProviderType.MISTRAL:
+                return cls(**response_data)
+
+            case ProviderType.OPENAI:
+                return cls(**response_data)
+
+            case ProviderType.VLLM:
+                return cls(max_context_length=response_data["max_model_len"], **response_data)
+
+            case _:
+                raise NotImplementedError(f"Provider {provider_type} not implemented")
 
 
 class Models(BaseModel):
-    object: Literal["list"] = "list"
-    data: list[Model]
+    object: Annotated[Literal["list"], Field("list", description="Type of the object.")]
+    data: Annotated[list[Model], Field(..., description="List of models.")]

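A hypothetical round trip through the new Model.build_from adapter, normalizing a TEI-style payload into the shared schema; the payload values are invented, and the import paths assume the module locations used elsewhere in this diff:

from api.domain.provider.entities import ProviderType
from api.infrastructure.fastapi.schemas.models import Model
from api.schemas.core.models import RequestContent
from api.utils.variables import EndpointRoute

# Invented TEI-style payload; the field names match the TEI branch above.
tei_payload = {"model_id": "BAAI/bge-m3", "max_input_length": 8192}

model = Model.build_from(
    provider_type=ProviderType.TEI,
    request_content=RequestContent(endpoint=EndpointRoute.MODELS, method="GET"),
    response_data=tei_payload,
)
assert model.id == "BAAI/bge-m3"
assert model.max_context_length == 8192
assert model.owned_by == "tei" and model.created == 0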