8
8
from fastapi import FastAPI , Request
9
9
from fastapi .responses import JSONResponse , StreamingResponse
10
10
from pydantic import BaseModel
11
+ from asyncio import CancelledError , Task
11
12
12
13
13
14
class ChatCompletionRequest (BaseModel ):
@@ -35,14 +36,27 @@ def __init__(self, interpreter):
35
36
# Setup routes
36
37
self .app .post ("/chat/completions" )(self .chat_completion )
37
38
39
+ # Add a field to track the current request task
40
+ self ._current_request : Optional [Task ] = None
41
+
38
42
async def chat_completion (self , request : Request ):
39
43
"""Main chat completion endpoint"""
44
+ # Cancel any existing request
45
+ if self ._current_request and not self ._current_request .done ():
46
+ self ._current_request .cancel ()
47
+ try :
48
+ await self ._current_request
49
+ except CancelledError :
50
+ pass
51
+
40
52
body = await request .json ()
53
+ if self .interpreter .debug :
54
+ print ("Request body:" , body )
41
55
try :
42
56
req = ChatCompletionRequest (** body )
43
57
except Exception as e :
44
- print ("Validation error:" , str (e )) # Debug print
45
- print ("Request body:" , body ) # Print the request body
58
+ print ("Validation error:" , str (e ))
59
+ print ("Request body:" , body )
46
60
raise
47
61
48
62
# Filter out system message
@@ -75,18 +89,6 @@ async def _stream_response(self):
75
89
delta ["function_call" ] = choice .delta .function_call
76
90
if choice .delta .tool_calls is not None :
77
91
pass
78
- # Convert tool_calls to dict representation
79
- # delta["tool_calls"] = [
80
- # {
81
- # "index": tool_call.index,
82
- # "id": tool_call.id,
83
- # "type": tool_call.type,
84
- # "function": {
85
- # "name": tool_call.function.name,
86
- # "arguments": tool_call.function.arguments
87
- # }
88
- # } for tool_call in choice.delta.tool_calls
89
- # ]
90
92
91
93
choices .append (
92
94
{
@@ -108,11 +110,16 @@ async def _stream_response(self):
108
110
data ["system_fingerprint" ] = chunk .system_fingerprint
109
111
110
112
yield f"data: { json .dumps (data )} \n \n "
111
- except asyncio .CancelledError :
112
- # Set stop flag when stream is cancelled
113
- self .interpreter ._stop_flag = True
113
+
114
+ except CancelledError :
115
+ # Handle cancellation gracefully
116
+ print ("Request cancelled - cleaning up..." )
117
+
114
118
raise
119
+ except Exception as e :
120
+ print (f"Error in stream: { str (e )} " )
115
121
finally :
122
+ # Always send DONE message and cleanup
116
123
yield "data: [DONE]\n \n "
117
124
118
125
def run (self ):
0 commit comments