 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, Field
 from typing import Dict, List, Any, Optional, Generator, Tuple
+import json

 from ..logger import get_logger
 from ..logger.logger import get_request_count
@@ -75,6 +76,36 @@ class BatchGenerationResponse(BaseModel):
     responses: List[str]


+def format_chat_messages(messages: List[ChatMessage]) -> str:
+    """
+    Format a list of chat messages into a prompt string that the model can understand
+
+    Args:
+        messages: List of ChatMessage objects with role and content
+
+    Returns:
+        Formatted prompt string
+    """
+    formatted_messages = []
+
+    for msg in messages:
+        role = msg.role.strip().lower()
+
+        if role == "system":
+            # System messages get special formatting
+            formatted_messages.append(f"# System Instruction\n{msg.content}\n")
+        elif role == "user":
+            formatted_messages.append(f"User: {msg.content}")
+        elif role == "assistant":
+            formatted_messages.append(f"Assistant: {msg.content}")
+        else:
+            # Default formatting for other roles
+            formatted_messages.append(f"{role.capitalize()}: {msg.content}")
+
+    # Join all messages with newlines
+    return "\n\n".join(formatted_messages)
+
+
 @router.post("/generate", response_model=GenerationResponse)
 async def generate_text(request: GenerationRequest) -> GenerationResponse:
     """
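For reference, a quick sketch of what the new format_chat_messages helper produces for a short conversation (the ChatMessage values below are made-up example inputs, not part of this change):

# Hypothetical usage of format_chat_messages; ChatMessage is the request model
# with `role` and `content` fields.
messages = [
    ChatMessage(role="system", content="You are a helpful assistant."),
    ChatMessage(role="user", content="What does this endpoint do?"),
]
prompt = format_chat_messages(messages)
print(prompt)
# # System Instruction
# You are a helpful assistant.
#
#
# User: What does this endpoint do?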
@@ -105,8 +136,8 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
         # Merge model-specific params with request params
         generation_params.update(model_params)

-        # Generate text
-        generated_text = model_manager.generate_text(
+        # Generate text - properly await the async call
+        generated_text = await model_manager.generate_text(
             prompt=request.prompt,
             system_prompt=request.system_prompt,
             **generation_params
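The generation endpoints now await model_manager.generate_text, which assumes the manager exposes it as a coroutine. A minimal sketch of one way that can be done when the underlying generation call is blocking (this is an assumption about the manager's internals, not the project's actual implementation):

import asyncio
from typing import Optional

class ModelManager:
    def _generate_sync(self, prompt: str, system_prompt: Optional[str] = None, **params) -> str:
        # Hypothetical placeholder for the real blocking model call.
        return f"[generated text for {prompt!r}]"

    async def generate_text(self, prompt: str, system_prompt: Optional[str] = None, **params) -> str:
        # Run the blocking call in a worker thread so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self._generate_sync(prompt, system_prompt, **params)
        )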
@@ -123,46 +154,51 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:

 @router.post("/chat", response_model=ChatResponse)
 async def chat_completion(request: ChatRequest) -> ChatResponse:
-    """Chat completion endpoint similar to OpenAI's API"""
+    """
+    Chat completion API that formats messages into a prompt and returns the response
+    """
     if not model_manager.current_model:
         raise HTTPException(status_code=400, detail="No model is currently loaded")

+    # Format messages into a prompt
+    formatted_prompt = format_chat_messages(request.messages)
+
+    # If streaming is requested, return a streaming response
+    if request.stream:
+        return StreamingResponse(
+            stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p),
+            media_type="text/event-stream"
+        )
+
     try:
-        # Format messages into a prompt
-        formatted_prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])
-
-        if request.stream:
-            # Return a streaming response
-            return StreamingResponse(
-                stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p),
-                media_type="text/event-stream"
-            )
-
         # Get model-specific generation parameters
         model_params = get_model_generation_params(model_manager.current_model)

-        # Update with request parameters
+        # Prepare generation parameters
         generation_params = {
             "max_new_tokens": request.max_tokens,
             "temperature": request.temperature,
-            "top_p": request.top_p,
+            "top_p": request.top_p
         }

         # Merge model-specific params with request params
         generation_params.update(model_params)

-        # Generate text
-        response = model_manager.generate_text(
+        # Generate completion
+        generated_text = await model_manager.generate_text(
             prompt=formatted_prompt,
             **generation_params
         )

+        # Format response
         return ChatResponse(
             choices=[{
                 "message": {
                     "role": "assistant",
-                    "content": response
-                }
+                    "content": generated_text
+                },
+                "index": 0,
+                "finish_reason": "stop"
             }]
         )
     except Exception as e:
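For a quick manual check of the non-streaming path, the endpoint can be called with a small client like this (a sketch; the base URL, port, and any router prefix are assumptions about the deployment):

# Hypothetical client call for the /chat endpoint; adjust the URL to your setup.
import httpx

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": False,
}
resp = httpx.post("http://localhost:8000/chat", json=payload, timeout=120.0)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])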
@@ -177,7 +213,9 @@ async def generate_stream(
     top_p: float,
     system_prompt: Optional[str]
 ) -> Generator[str, None, None]:
-    """Generate text in a streaming fashion"""
+    """
+    Generate text in a streaming fashion
+    """
     try:
         # Get model-specific generation parameters
         model_params = get_model_generation_params(model_manager.current_model)
@@ -187,26 +225,26 @@ async def generate_stream(
             "max_new_tokens": max_tokens,
             "temperature": temperature,
             "top_p": top_p,
-            "stream": True
         }

         # Merge model-specific params with request params
         generation_params.update(model_params)

-        for token in model_manager.generate_stream(
+        # Stream tokens
+        async for token in model_manager.generate_stream(
             prompt=prompt,
             system_prompt=system_prompt,
             **generation_params
         ):
-            # Format as a server-sent event
-            yield f"data: {token}\n\n"
-
-        # End of stream marker
+            # Format as server-sent event
+            data = token.replace("\n", "\\n")
+            yield f"data: {data}\n\n"
+
+        # End of stream
         yield "data: [DONE]\n\n"
     except Exception as e:
         logger.error(f"Streaming generation failed: {str(e)}")
-        yield f"data: {{\"error\": \"{str(e)}\"}}\n\n"
-        yield "data: [DONE]\n\n"
+        yield f"data: [ERROR] {str(e)}\n\n"


 async def stream_chat(
@@ -215,7 +253,9 @@ async def stream_chat(
     temperature: float,
     top_p: float
 ) -> Generator[str, None, None]:
-    """Stream chat completion tokens"""
+    """
+    Stream chat completion
+    """
     try:
         # Get model-specific generation parameters
         model_params = get_model_generation_params(model_manager.current_model)
@@ -224,25 +264,27 @@ async def stream_chat(
         generation_params = {
             "max_new_tokens": max_tokens,
             "temperature": temperature,
-            "top_p": top_p,
-            "stream": True
+            "top_p": top_p
         }

         # Merge model-specific params with request params
         generation_params.update(model_params)

-        for token in model_manager.generate_stream(
+        # Generate streaming tokens
+        async for token in model_manager.generate_stream(
             prompt=formatted_prompt,
             **generation_params
         ):
-            # Format as a server-sent event with proper JSON structure
-            yield f'data: {{"choices": [{{"delta": {{"content": "{token}"}}}}]}}\n\n'
-
+            # Format as a server-sent event with the structure expected by chat clients
+            data = json.dumps({"role": "assistant", "content": token})
+            yield f"data: {data}\n\n"
+
         # End of stream marker
         yield "data: [DONE]\n\n"
     except Exception as e:
         logger.error(f"Chat streaming failed: {str(e)}")
-        yield f"data: {{\"error\": \"{str(e)}\"}}\n\n"
+        error_data = json.dumps({"error": str(e)})
+        yield f"data: {error_data}\n\n"
         yield "data: [DONE]\n\n"

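Because each streamed chunk is now a standalone JSON object per data: line (rather than the OpenAI-style choices/delta envelope), clients need to parse it accordingly. A minimal consumer sketch, again with an assumed local URL:

# Hypothetical streaming client: parses the {"role": ..., "content": ...} chunks
# emitted by stream_chat and stops at the [DONE] marker.
import json
import httpx

payload = {
    "messages": [{"role": "user", "content": "Tell me a short fact."}],
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": True,
}
with httpx.stream("POST", "http://localhost:8000/chat", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        if "error" in chunk:
            raise RuntimeError(chunk["error"])
        print(chunk.get("content", ""), end="", flush=True)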
@@ -270,7 +312,7 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResponse:

     responses = []
     for prompt in request.prompts:
-        generated_text = model_manager.generate_text(
+        generated_text = await model_manager.generate_text(
             prompt=prompt,
             system_prompt=request.system_prompt,
             **generation_params