
Commit 3b043ee

Revert "feat: add the endpoint compatible with OpenAI client protocols"
This reverts commit bcea8fe.
1 parent a4ff387 commit 3b043ee

File tree

10 files changed: +60 −566 lines changed

app/api/routers/generative.py

Lines changed: 8 additions & 142 deletions
@@ -1,24 +1,16 @@
-import json
 import logging
-import time
-import uuid
 import app.api.globals as cms_globals

-from typing import Union, Iterable, AsyncGenerator
 from typing_extensions import Annotated
 from fastapi import APIRouter, Depends, Request, Body, Query
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
-from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
-from app.domain import Tags, OpenAIChatRequest, OpenAIChatResponse, PromptMessage, PromptRole
+from fastapi.responses import PlainTextResponse, StreamingResponse
+from app.domain import Tags
 from app.model_services.base import AbstractModelService
-from app.utils import get_settings, get_prompt_from_messages
+from app.utils import get_settings
 from app.api.utils import get_rate_limiter
-from app.api.dependencies import validate_tracking_id

 PATH_GENERATE = "/generate"
 PATH_GENERATE_ASYNC = "/stream/generate"
-PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"

 router = APIRouter()
 config = get_settings()
@@ -39,8 +31,6 @@ def generate_text(
     request: Request,
     prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
     max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
-    temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> PlainTextResponse:
     """
@@ -50,27 +40,13 @@ def generate_text(
         request (Request): The request object.
         prompt (str): The prompt to be sent to the model.
         max_tokens (int): The maximum number of tokens to generate.
-        temperature (float): The temperature of the generated text.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
         model_service (AbstractModelService): The model service dependency.

     Returns:
         PlainTextResponse: A response containing the generated text.
     """

-    tracking_id = tracking_id or str(uuid.uuid4())
-    if prompt:
-        return PlainTextResponse(
-            model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature),
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_200_OK,
-        )
-    else:
-        return PlainTextResponse(
-            _empty_prompt_error(),
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_400_BAD_REQUEST,
-        )
+    return PlainTextResponse(model_service.generate(prompt, max_tokens=max_tokens))


 @router.post(
@@ -84,8 +60,6 @@ async def generate_text_stream(
     request: Request,
     prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
     max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512,
-    temperature: Annotated[float, Query(description="The temperature of the generated text", gt=0.0, lt=1.0)] = 0.7,
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
     model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
 ) -> StreamingResponse:
     """
@@ -95,121 +69,13 @@ async def generate_text_stream(
         request (Request): The request object.
         prompt (str): The prompt to be sent to the model.
         max_tokens (int): The maximum number of tokens to generate.
-        temperature (float): The temperature of the generated text.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
         model_service (AbstractModelService): The model service dependency.

     Returns:
         StreamingResponse: A streaming response containing the text generated in near real-time.
     """

-    tracking_id = tracking_id or str(uuid.uuid4())
-    if prompt:
-        return StreamingResponse(
-            model_service.generate_async(prompt, max_tokens=max_tokens, temperature=temperature),
-            media_type="text/event-stream",
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_200_OK,
-        )
-    else:
-        return StreamingResponse(
-            _empty_prompt_error(),
-            media_type="text/event-stream",
-            headers={"x-cms-tracking-id": tracking_id},
-            status_code=HTTP_400_BAD_REQUEST,
-        )
-
-
-@router.post(
-    "/v1/chat/completions",
-    tags=[Tags.Generative.name],
-    response_model=None,
-    dependencies=[Depends(cms_globals.props.current_active_user)],
-    description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
-)
-def generate_chat_completions(
-    request: Request,
-    request_data: Annotated[OpenAIChatRequest, Body(
-        description="OpenAI-like completion request", media_type="application/json"
-    )],
-    tracking_id: Union[str, None] = Depends(validate_tracking_id),
-    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
-) -> Union[StreamingResponse, JSONResponse]:
-    """
-    Generate chat response based on messages, mimicking OpenAI's /v1/chat/completions endpoint.
-
-    Args:
-        request (Request): The request object.
-        request_data (OpenAIChatRequest): The request data containing model, messages, and stream.
-        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
-        model_service (AbstractModelService): The model service dependency.
-
-    Returns:
-        StreamingResponse: A OpenAI-like response containing the text generated in near real-time.
-        JSONResponse: A response containing an error message if the prompt messages are empty.
-    """
-
-    messages = request_data.messages
-    stream = request_data.stream
-    max_tokens = request_data.max_tokens
-    temperature = request_data.temperature
-
-    if not messages:
-        error_response = {
-            "error": {
-                "message": "No prompt messages provided",
-                "type": "invalid_request_error",
-                "param": "messages",
-                "code": "missing_field",
-            }
-        }
-        return JSONResponse(content=error_response, status_code=HTTP_400_BAD_REQUEST)
-
-    async def _stream(p: str, mt: int, t: float) -> AsyncGenerator:
-        data = {
-            "id": tracking_id or str(uuid.uuid4()),
-            "object": "chat.completion.chunk",
-            "choices": [{"delta": {"role": PromptRole.ASSISTANT.value}}],
-        }
-        yield f"data: {json.dumps(data)}\n\n"
-        async for chunk in model_service.generate_async(p, max_tokens=mt, temperature=t):
-            data = {
-                "choices": [
-                    {
-                        "delta": {"content": chunk}
-                    }
-                ],
-                "object": "chat.completion.chunk",
-            }
-            yield f"data: {json.dumps(data)}\n\n"
-        yield "data: [DONE]\n\n"
-
-    prompt = get_prompt_from_messages(model_service.tokenizer, messages)  # type: ignore
-    if stream:
-        return StreamingResponse(
-            _stream(prompt, max_tokens, temperature),
-            media_type="text/event-stream"
-        )
-    else:
-        generated_text = model_service.generate(prompt, max_tokens=max_tokens, temperature=temperature)
-        completion = OpenAIChatResponse(
-            id=str(uuid.uuid4()),
-            object="chat.completion",
-            created=int(time.time()),
-            model=model_service.model_name,
-            choices=[
-                {
-                    "index": 0,
-                    "message": PromptMessage(
-                        role=PromptRole.ASSISTANT,
-                        content=generated_text,
-                    ),
-                    "finish_reason": "stop",
-                }
-            ]
-        )
-        return JSONResponse(content=jsonable_encoder(completion))
-
-
-def _empty_prompt_error() -> Iterable[str]:
-    yield "ERROR: No prompt text provided\n"
+    return StreamingResponse(
+        model_service.generate_async(prompt, max_tokens=max_tokens),
+        media_type="text/event-stream"
+    )

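Note: after this revert, callers only have the plain-text /generate and /stream/generate routes shown above; there is no OpenAI-compatible /v1/chat/completions to point an OpenAI client at. Below is a minimal sketch of calling the surviving routes with httpx, assuming the service is reachable at http://localhost:8000 and that authentication, if enabled, is handled separately.

import httpx

BASE_URL = "http://localhost:8000"  # assumption: a local deployment; adjust to your instance

def generate(prompt: str, max_tokens: int = 512) -> str:
    # POST the raw prompt as text/plain; max_tokens travels as a query parameter,
    # mirroring the Body and Query annotations in the diff above.
    response = httpx.post(
        f"{BASE_URL}/generate",
        params={"max_tokens": max_tokens},
        content=prompt,
        headers={"Content-Type": "text/plain"},
        timeout=60.0,
    )
    response.raise_for_status()
    return response.text

def generate_stream(prompt: str, max_tokens: int = 512) -> None:
    # The streaming route returns a text/event-stream body; print chunks as they arrive.
    with httpx.stream(
        "POST",
        f"{BASE_URL}/stream/generate",
        params={"max_tokens": max_tokens},
        content=prompt,
        headers={"Content-Type": "text/plain"},
        timeout=None,
    ) as response:
        for chunk in response.iter_text():
            print(chunk, end="", flush=True)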
app/api/utils.py

Lines changed: 14 additions & 25 deletions
@@ -286,7 +286,6 @@ async def init_vllm_engine(app: FastAPI,
     """

     try:
-        # Import necessary vLLM components
         from vllm.utils import FlexibleArgumentParser
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -299,19 +298,16 @@ async def init_vllm_engine(app: FastAPI,
         )
         from vllm import SamplingParams, TokensPrompt
     except ImportError:
-        # Raise a custom exception if vLLM is not installed
-        raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.")
+        logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")

     parser = FlexibleArgumentParser()
     parser = make_arg_parser(parser)
     args = parser.parse_args([])
     validate_parsed_serve_args(args)
-
     args.model = model_dir_path
     args.dtype = "float16"
     args.served_model_name = [model_name]
-    args.max_model_len = 2048  # The default batched length (2048) needs to be higher than max_model_len.
-    # args.tokenizer = model_dir_path  # Uncomment if your tokenizer is in a different path or needs explicit setting.
+    # args.tokenizer = model_dir_path
     args.log_level = log_level

     exit_stack = contextlib.AsyncExitStack()
@@ -321,44 +317,37 @@ async def init_vllm_engine(app: FastAPI,
             disable_frontend_multiprocessing=True,
         )
     )
-
     tokenizer = await engine.get_tokenizer()
     vllm_config = await engine.get_vllm_config()
     model_config = await engine.get_model_config()
-
     await init_app_state(engine, vllm_config, app.state, args)

     async def generate_text(
         request: Request,
         prompt: Annotated[str, Body(description="The prompt to be sent to the model", media_type="text/plain")],
         max_tokens: Annotated[int, Query(description="The maximum number of tokens to generate", gt=0)] = 512
     ) -> StreamingResponse:
-        """
-        Custom endpoint for streaming text generation.
-        This endpoint takes a raw text prompt and streams back the generated text.
-        It applies a chat template to the prompt internally for model compatibility.
-        """
         messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

         params = SamplingParams(max_tokens=max_tokens)
-
         conversation, _ = parse_chat_messages(messages, model_config, tokenizer, content_format="string")  # type: ignore
-        prompt_tokens = apply_hf_chat_template(  # type: ignore
-            tokenizer,
-            conversation=conversation,
-            tools=None,
-            add_generation_prompt=True,
-            continue_final_message=False,
-            chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
-            tokenize=True,
+        prompt = TokensPrompt(
+            prompt_token_ids=apply_hf_chat_template(  # type: ignore
+                tokenizer,
+                conversation=conversation,
+                tools=None,
+                add_generation_prompt=True,
+                continue_final_message=False,
+                chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\nUser: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\nAssistant: {{ message['content'] }}\n{% endif %}\n{% endfor %}\nAssistant:",
+                tokenize=True,
+            )
         )
-        prompt_obj = TokensPrompt(prompt_token_ids=prompt_tokens)

         async def _stream() -> AsyncGenerator[bytes, None]:
            start = 0
-            async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt_obj, sampling_params=params):
+            async for output in engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params):
                text = output.outputs[0].text
-                yield text[start:].encode("utf-8")
+                yield text[start:]  # type: ignore
                start = len(text)

        return StreamingResponse(_stream(), media_type="text/event-stream")

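Note: both sides of the refactored hunk above feed the same inline Jinja chat template into apply_hf_chat_template; the revert only stops storing the token IDs in an intermediate variable before wrapping them in TokensPrompt. As a rough, assumption-laden sketch (plain jinja2 rather than the vLLM helper, and without the tokenization step), this is approximately the prompt text that template renders for a single user turn.

from jinja2 import Template  # assumes jinja2 is available (it is a dependency of vLLM/transformers chat templating)

# The inline template from the diff above, reproduced verbatim as a Python string.
CHAT_TEMPLATE = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "User: {{ message['content'] }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "Assistant: {{ message['content'] }}\n"
    "{% endif %}\n"
    "{% endfor %}\n"
    "Assistant:"
)

# parse_chat_messages(..., content_format="string") flattens the content parts into a plain
# string, so a single-turn conversation looks roughly like this once parsed:
conversation = [{"role": "user", "content": "What is sepsis?"}]

rendered = Template(CHAT_TEMPLATE).render(messages=conversation)
print(rendered)
# Prints, modulo the blank lines the block tags leave behind:
#   User: What is sepsis?
#   Assistant: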
app/domain.py

Lines changed: 0 additions & 27 deletions
@@ -167,30 +167,3 @@ class Doc(BaseModel):
     text: str = Field(description="The text from which the entities are extracted")
     ents: List[Entity] = Field(description="The list of extracted entities")
     title: Optional[str] = Field(default=None, description="The headline of the text")
-
-
-class PromptRole(Enum):
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-    TOOL = "tool"
-
-
-class PromptMessage(BaseModel):
-    role: PromptRole = Field(description="The role who generates the message")
-    content: str = Field(description="The actual text of the message")
-
-
-class OpenAIChatRequest(BaseModel):
-    messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
-    stream: bool = Field(..., description="Whether to stream the response")
-    max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
-    temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
-
-
-class OpenAIChatResponse(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: str
-    choices: List

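Note: the four models deleted here defined the wire format of the reverted /v1/chat/completions endpoint. Below is a self-contained sketch that reproduces the removed request schema (copied from the deleted lines above, so it does not describe the post-revert app/domain.py) and validates an OpenAI-style payload with pydantic.

from enum import Enum
from typing import List
from pydantic import BaseModel, Field

# Reproduced from the deleted code above, purely for illustration.
class PromptRole(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

class PromptMessage(BaseModel):
    role: PromptRole = Field(description="The role who generates the message")
    content: str = Field(description="The actual text of the message")

class OpenAIChatRequest(BaseModel):
    messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
    stream: bool = Field(..., description="Whether to stream the response")
    max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
    temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)

# An OpenAI-style payload parses cleanly into the removed schema:
payload = {
    "messages": [{"role": "user", "content": "Summarise this discharge note."}],
    "stream": False,
}
request = OpenAIChatRequest(**payload)
print(request.max_tokens, request.temperature)  # 512 0.7 (defaults applied)

The defaults and constraints (max_tokens=512, 0.0 <= temperature <= 1.0) come straight from the deleted Field definitions.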
app/model_services/huggingface_llm_model.py

Lines changed: 5 additions & 21 deletions
@@ -168,8 +168,6 @@ def init_model(self) -> None:
             logger.warning("Model service is already initialised and can be initialised only once")
         else:
             self._model, self._tokenizer = self.load_model(self._model_pack_path)
-            if non_default_device_is_available(get_settings().DEVICE):
-                self._model.to(get_settings().DEVICE)
             if self._enable_trainer:
                 logger.error("Trainers are not yet implemented for HuggingFace Generative models")

@@ -193,20 +191,13 @@ def annotate(self, text: str) -> List[Annotation]:
     def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
         raise NotImplementedError("Batch annotation is not yet implemented for HuggingFace Generative models")

-    def generate(
-        self,
-        prompt: str,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        **kwargs: Any
-    ) -> str:
+    def generate(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> str:
         """
         Generates text based on the prompt.

         Args:
             prompt (str): The prompt for the text generation
             max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
-            temperature (float): The temperature for the text generation. Defaults to 0.7.
             **kwargs (Any): Additional keyword arguments to be passed to this method.

         Returns:
@@ -223,8 +214,8 @@ def generate(
             inputs=inputs.input_ids,
             attention_mask=inputs.attention_mask,
             max_new_tokens=max_tokens,
-            do_sample=False,
-            temperature=temperature,
+            do_sample=True,
+            temperature=0.7,
             top_p=0.9,
         )

@@ -236,20 +227,13 @@ def generate(

         return generated_text

-    async def generate_async(
-        self,
-        prompt: str,
-        max_tokens: int = 512,
-        temperature: float = 0.7,
-        **kwargs: Any
-    ) -> AsyncIterable:
+    async def generate_async(self, prompt: str, max_tokens: int = 512, **kwargs: Any) -> AsyncIterable:
         """
         Asynchronously generates text stream based on the prompt.

         Args:
             prompt (str): The prompt for the text generation.
             max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
-            temperature (float): The temperature for the text generation. Defaults to 0.7.
             **kwargs (Any): Additional keyword arguments to be passed to the model loader.

         Returns:
@@ -273,7 +257,7 @@ async def generate_async(
             streamer=streamer,
             max_new_tokens=max_tokens,
             do_sample=True,
-            temperature=temperature,
+            temperature=0.7,
             top_p=0.9,
         )

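Note: the net effect on the HuggingFace service is that temperature is no longer caller-supplied; generate() and generate_async() go back to sampling with a fixed temperature of 0.7 and top_p of 0.9, and init_model() no longer moves the model to a non-default device. Below is a standalone sketch of the equivalent transformers call, using an arbitrary small checkpoint rather than the service's own model pack.

# Not the service code: a sketch of the sampling arguments the post-revert generate()
# hard-codes (do_sample=True, temperature=0.7, top_p=0.9). The model name is an
# assumption chosen only for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumption: any causal LM checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The patient presented with", return_tensors="pt")
output_ids = model.generate(
    inputs=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=32,
    do_sample=True,      # sampling enabled, as on the post-revert side of the hunk
    temperature=0.7,     # fixed value rather than a per-request parameter
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,  # not in the diff; avoids GPT-2's missing-pad-token warning
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))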