Skip to content

Commit fab113b

Browse files
committed
feat: add the text embedding endpoint for LLM serving
1 parent ff429f4 commit fab113b

File tree

6 files changed

+315
-13
lines changed

6 files changed

+315
-13
lines changed

app/api/routers/generative.py

Lines changed: 103 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@
1010
from fastapi import APIRouter, Depends, Request, Body, Query
1111
from fastapi.encoders import jsonable_encoder
1212
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
13-
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
14-
from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
13+
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
14+
from app.domain import (
15+
Tags,
16+
OpenAIChatRequest,
17+
OpenAIChatResponse,
18+
OpenAIEmbeddingsRequest,
19+
OpenAIEmbeddingsResponse,
20+
PromptMessage,
21+
PromptRole,
22+
)
1523
from app.model_services.base import AbstractModelService
1624
from app.utils import get_settings, get_prompt_from_messages
1725
from app.api.utils import get_rate_limiter
@@ -21,6 +29,7 @@
2129
PATH_GENERATE = "/generate"
2230
PATH_GENERATE_ASYNC = "/stream/generate"
2331
PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
32+
PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
2433

2534
router = APIRouter()
2635
config = get_settings()
@@ -134,7 +143,7 @@ async def generate_text_stream(
134143

135144
@router.post(
136145
PATH_OPENAI_COMPLETIONS,
137-
tags=[Tags.Generative.name],
146+
tags=[Tags.OpenAICompatible.name],
138147
response_model=None,
139148
dependencies=[Depends(cms_globals.props.current_active_user)],
140149
description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
@@ -162,6 +171,7 @@ def generate_chat_completions(
162171
"""
163172

164173
messages = request_data.messages
174+
model = model_service.model_name if request_data.model != model_service.model_name else request_data.model
165175
stream = request_data.stream
166176
max_tokens = request_data.max_tokens
167177
temperature = request_data.temperature
@@ -224,7 +234,7 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
224234
id=tracking_id,
225235
object="chat.completion",
226236
created=int(time.time()),
227-
model=model_service.model_name,
237+
model=model,
228238
choices=[
229239
{
230240
"index": 0,
@@ -239,14 +249,100 @@ async def _stream(prompt: str, max_tokens: int, temperature: float) -> AsyncGene
239249
return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})
240250

241251

252+
@router.post(
    PATH_OPENAI_EMBEDDINGS,
    tags=[Tags.OpenAICompatible.name],
    response_model=None,
    dependencies=[Depends(cms_globals.props.current_active_user)],
    description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
)
def embed_texts(
    request: Request,
    request_data: Annotated[OpenAIEmbeddingsRequest, Body(
        description="Text(s) to be embedded", media_type="application/json"
    )],
    tracking_id: Union[str, None] = Depends(validate_tracking_id),
    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> JSONResponse:
    """
    Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.

    Args:
        request (Request): The request object.
        request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
        model_service (AbstractModelService): The model service dependency.

    Returns:
        JSONResponse: A response containing the embeddings of the text(s), or an OpenAI-style
            error payload with status 500 when the model service does not support embeddings
            or fails while creating them. Every response carries the `x-cms-tracking-id` header.
    """

    tracking_id = tracking_id or str(uuid.uuid4())
    headers = {"x-cms-tracking-id": tracking_id}

    def _not_supported_response() -> JSONResponse:
        # Status code is kept as in the original implementation for backward compatibility.
        return JSONResponse(
            content={
                "error": {
                    "message": "Model does not support embeddings",
                    "type": "invalid_request_error",
                    "param": "model",
                    "code": "model_not_supported",
                }
            },
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            headers=headers,
        )

    create_embeddings = getattr(model_service, "create_embeddings", None)
    if create_embeddings is None:
        return _not_supported_response()

    # The response always reports the name of the model actually being served; the
    # requested model name is not used to select a model.
    model = model_service.model_name

    input_text = request_data.input
    input_texts = [input_text] if isinstance(input_text, str) else input_text

    try:
        embeddings_data = [
            {"object": "embedding", "embedding": embedding, "index": i}
            for i, embedding in enumerate(create_embeddings(input_texts))
        ]
        response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
        return JSONResponse(content=jsonable_encoder(response), headers=headers)
    except NotImplementedError:
        # A service that inherits the base `create_embeddings` passes the attribute check above
        # but raises here, so report it as "not supported" rather than as a generic failure.
        return _not_supported_response()
    except Exception as e:
        # `logger.exception` records both the message and the traceback in one call.
        logger.exception("Failed to create embeddings")
        return JSONResponse(
            content={
                "error": {
                    "message": f"Failed to create embeddings: {e}",
                    "type": "server_error",
                    "code": "internal_error",
                }
            },
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            headers=headers,
        )
336+
337+
242338
def _empty_prompt_error() -> Iterable[str]:
243339
yield "ERROR: No prompt text provided\n"
244340

245341

246342
def _send_usage_metrics(handler: str, prompt_token_num: int, completion_token_num: int) -> None:
    """
    Records prompt, completion and total token usage for the given handler.

    Args:
        handler (str): The value of the `handler` label attached to each observation.
        prompt_token_num (int): The number of tokens in the prompt.
        completion_token_num (int): The number of tokens in the completion.
    """

    total_token_num = prompt_token_num + completion_token_num
    observations = (
        (cms_prompt_tokens, "Sent prompt tokens usage: %s", prompt_token_num),
        (cms_completion_tokens, "Sent completion tokens usage: %s", completion_token_num),
        (cms_total_tokens, "Sent total tokens usage: %s", total_token_num),
    )
    # Each metric is observed and then logged, in prompt -> completion -> total order.
    for metric, message, token_num in observations:
        metric.labels(handler=handler).observe(token_num)
        logger.debug(message, token_num)

app/domain.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import List, Optional, Dict, Any
2+
from typing import List, Optional, Dict, Any, Union
33

44
from fastapi import HTTPException
55
from starlette.status import HTTP_400_BAD_REQUEST
@@ -27,6 +27,7 @@ class Tags(str, Enum):
2727
Evaluating = "Evaluate the deployed model with trainer export"
2828
Authentication = "Authenticate registered users"
2929
Generative = "Generate text based on the input prompt"
30+
OpenAICompatible = "Compatible with OpenAI APIs"
3031

3132

3233
class TagsStreamable(str, Enum):
@@ -185,6 +186,7 @@ class OpenAIChatRequest(BaseModel):
185186
messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
186187
stream: bool = Field(..., description="Whether to stream the response")
187188
max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
189+
model: str = Field(..., description="The name of the model used for generating the completion")
188190
temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
189191

190192

@@ -194,3 +196,14 @@ class OpenAIChatResponse(BaseModel):
194196
created: int = Field(..., description="The timestamp when the completion was generated")
195197
model: str = Field(..., description="The name of the model used for generating the completion")
196198
choices: List = Field(..., description="The generated messages and their metadata")
199+
200+
201+
class OpenAIEmbeddingsRequest(BaseModel):
    """The request body of the OpenAI-compatible embeddings endpoint."""

    # Either a single text or a list of texts.
    input: Union[str, List[str]] = Field(..., description="Input text or list of texts to embed")
    model: str = Field(..., description="The name of the model used for creating the embeddings")
204+
205+
206+
class OpenAIEmbeddingsResponse(BaseModel):
    """The response body of the OpenAI-compatible embeddings endpoint."""

    object: str = Field(..., description="The type of the response")
    # Each entry is a free-form dictionary describing one embedding.
    data: List[Dict[str, Any]] = Field(..., description="List of embedding objects")
    model: str = Field(..., description="The name of the model used for creating the embeddings")

app/model_services/base.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from abc import ABC, abstractmethod
3-
from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable
3+
from typing import Any, List, Iterable, Tuple, final, Optional, Generic, TypeVar, Protocol, AsyncIterable, Union
44
from app.config import Settings
55
from app.domain import ModelCard, Annotation
66

@@ -17,7 +17,7 @@ def tracker_client(self) -> Any:
1717
T = TypeVar("T", bound=_TrainerCommon)
1818

1919
class AbstractModelService(ABC, Generic[T]):
20-
"""An abstract base class defining the common interface for all model services."""
20+
"""An abstract base class defining the common interface for NER model services."""
2121

2222
@abstractmethod
2323
def __init__(self, config: Settings, *args: Any, **kwargs: Any) -> None:
@@ -200,6 +200,29 @@ def generate_async(self, prompt: str, *args: Any, **kwargs: Any) -> AsyncIterabl
200200

201201
raise NotImplementedError
202202

203+
def create_embeddings(
    self, text: Union[str, List[str]], *args: Any, **kwargs: Any
) -> Union[List[float], List[List[float]]]:
    """
    Produces embedding vector(s) for the given text(s).

    This base implementation is a placeholder which always raises; model services
    that support embeddings are expected to override it.

    Args:
        text (Union[str, List[str]]): A single text or a list of texts to be embedded.
        *args (Any): Additional positional arguments to be passed to this method.
        **kwargs (Any): Additional keyword arguments to be passed to this method.

    Returns:
        Union[List[float], List[List[float]]]: The embedding vector(s) for the text(s).

    Raises:
        NotImplementedError: If the method is not implemented by the subclass.
    """

    raise NotImplementedError
225+
203226
def train_supervised(self, *args: Any, **kwargs: Any) -> Tuple[bool, str, str]:
204227
"""
205228
Initiates supervised training on the model.

app/model_services/huggingface_llm_model.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
import logging
33
import asyncio
4+
import torch
45
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable
6+
from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union
67
from transformers import (
78
AutoModelForCausalLM,
89
AutoTokenizer,
@@ -307,3 +308,50 @@ async def generate_async(
307308
return
308309
finally:
309310
logger.debug("Chat response generation completed")
311+
312+
def create_embeddings(
    self,
    text: Union[str, List[str]],
    *args: Any,
    **kwargs: Any
) -> Union[List[float], List[List[float]]]:
    """
    Creates embeddings for a given text or list of texts using the model's hidden states.

    The last hidden states are averaged over the tokens selected by the attention mask,
    yielding one vector per input text.

    Args:
        text (Union[str, List[str]]): The text(s) to be embedded.
        *args (Any): Additional positional arguments to be passed to this method.
        **kwargs (Any): Additional keyword arguments to be passed to this method.

    Returns:
        Union[List[float], List[List[float]]]: A single embedding vector if `text` is a string,
            otherwise a list of embedding vectors, one per input text.
    """

    self.model.eval()

    # NOTE(review): `padding=True` presumably requires the tokenizer to have a padding token
    # configured when several texts are batched -- confirm this is set where the tokenizer is loaded.
    inputs = self.tokenizer(
        text,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Use the same configured device for both the availability check and the transfer.
    if non_default_device_is_available(self._config.DEVICE):
        inputs = inputs.to(self._config.DEVICE)

    with torch.no_grad():
        outputs = self.model(**inputs, output_hidden_states=True)

    # Pool in float32 so the result does not depend on the model's compute dtype and can
    # always be converted to Python floats.
    last_hidden_state = outputs.hidden_states[-1].float()
    attention_mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_state.dtype)
    sum_hidden_states = (last_hidden_state * attention_mask).sum(dim=1)
    # Clamp the token count so a row without any attended token yields zeros instead of NaN.
    num_tokens = attention_mask.sum(dim=1).clamp(min=1.0)
    embeddings = sum_hidden_states / num_tokens

    results = embeddings.cpu().tolist()
    return results[0] if isinstance(text, str) else results

tests/app/api/test_serving_hf_llm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def llm_app(llm_model_service):
3131

3232
@pytest.fixture(scope="function")
3333
def client(llm_model_service):
34+
llm_model_service.model_name = "HuggingFace LLM model"
3435
llm_model_service.generate.return_value = "Yeah."
36+
llm_model_service.create_embeddings.return_value = [[1.0, 2.0, 3.0]]
3537
app = get_generative_server(config, msd_overwritten=lambda: llm_model_service)
3638
app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None
3739
client = TestClient(app)
@@ -82,6 +84,7 @@ async def test_generate_chat_completions(llm_model_service, llm_app):
8284
"content": "Who are you?"
8385
}
8486
],
87+
"model": "HuggingFace LLM model",
8588
"stream": True,
8689
"max_tokens": 128,
8790
"temperature": 0.7
@@ -98,3 +101,22 @@ async def test_generate_chat_completions(llm_model_service, llm_app):
98101
assert response.text.startswith("data:")
99102
assert "id" in response.text
100103
assert "chat.completion.chunk" in response.text
104+
105+
106+
def test_create_embeddings(client):
    """The embeddings endpoint wraps the model service's vectors in an OpenAI-style list response."""
    payload = json.dumps({"input": ["Alright"], "model": "HuggingFace LLM model"})
    expected_embedding = {"object": "embedding", "embedding": [1.0, 2.0, 3.0], "index": 0}
    expected_body = {
        "object": "list",
        "data": [expected_embedding],
        "model": "HuggingFace LLM model",
    }

    response = client.post(
        "/v1/embeddings",
        data=payload,
        headers={"Content-Type": "application/json"},
    )

    assert response.status_code == 200
    assert response.headers["content-type"] == "application/json"
    assert response.json() == expected_body

0 commit comments

Comments
 (0)