2
2
import json
3
3
import os
4
4
import time
5
- from typing import Any , Dict , List , Optional
5
+ from typing import Any , Dict , List , Optional , Union
6
6
7
7
import uvicorn
8
8
from fastapi import FastAPI , Request
9
9
from fastapi .responses import JSONResponse , StreamingResponse
10
+ from pydantic import BaseModel
10
11
11
12
12
- class ChatCompletionRequest :
13
- def __init__ (
14
- self ,
15
- messages : List [ Dict [ str , str ]],
16
- stream : bool = False ,
17
- model : Optional [str ] = None ,
18
- ):
19
- self . messages = messages
20
- self . stream = stream
21
- self . model = model
13
class ChatCompletionRequest(BaseModel):
    """Pydantic model for an OpenAI-style ``/v1/chat/completions`` request body.

    Only ``messages`` and ``stream`` are actually consumed by the server; the
    remaining sampling parameters are declared so that standard OpenAI client
    payloads validate cleanly instead of being rejected.
    """

    # Chat history. Content values may be plain strings, content-part lists,
    # or None (e.g. assistant messages that carry only tool calls).
    messages: List[Dict[str, Union[str, list, None]]]
    # When True, the response is served as a server-sent-event stream.
    stream: bool = False
    model: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    tools: Optional[List[Dict[str, Any]]] = None
22
23
23
24
24
25
class Server :
@@ -33,30 +34,22 @@ def __init__(self, interpreter):
33
34
34
35
# Setup routes
35
36
self .app .post ("/v1/chat/completions" )(self .chat_completion )
36
- self .app .get ("/v1/models" )(self .list_models )
37
-
38
- async def list_models (self ):
39
- """List available models endpoint"""
40
- return {
41
- "data" : [
42
- {
43
- "id" : self .interpreter .model ,
44
- "object" : "model" ,
45
- "created" : int (time .time ()),
46
- "owned_by" : "open-interpreter" ,
47
- }
48
- ]
49
- }
50
37
51
38
async def chat_completion (self , request : Request ):
52
39
"""Main chat completion endpoint"""
53
40
body = await request .json ()
54
- req = ChatCompletionRequest (** body )
41
+ try :
42
+ req = ChatCompletionRequest (** body )
43
+ except Exception as e :
44
+ print ("Validation error:" , str (e )) # Debug print
45
+ print ("Request body:" , body ) # Print the request body
46
+ raise
47
+
48
+ # Filter out system message
49
+ req .messages = [msg for msg in req .messages if msg ["role" ] != "system" ]
55
50
56
51
# Update interpreter messages
57
- self .interpreter .messages = [
58
- {"role" : msg ["role" ], "content" : msg ["content" ]} for msg in req .messages
59
- ]
52
+ self .interpreter .messages = req .messages
60
53
61
54
if req .stream :
62
55
return StreamingResponse (
@@ -85,33 +78,54 @@ async def chat_completion(self, request: Request):
85
78
86
79
async def _stream_response (self ):
87
80
"""Stream the response in OpenAI-compatible format"""
88
- for chunk in self .interpreter .respond ():
89
- if chunk .get ("type" ) == "chunk" :
90
- data = {
91
- "id" : "chatcmpl-" + str (time .time ()),
92
- "object" : "chat.completion.chunk" ,
93
- "created" : int (time .time ()),
94
- "model" : self .interpreter .model ,
95
- "choices" : [
96
- {
97
- "index" : 0 ,
98
- "delta" : {"content" : chunk ["chunk" ]},
99
- "finish_reason" : None ,
100
- }
101
- ],
102
- }
103
- yield f"data: { json .dumps (data )} \n \n "
104
- await asyncio .sleep (0 )
81
+ async for chunk in self .interpreter .async_respond ():
82
+ # Convert tool_calls to dict if present
83
+ choices = []
84
+ for choice in chunk .choices :
85
+ delta = {}
86
+ if choice .delta :
87
+ if choice .delta .content is not None :
88
+ delta ["content" ] = choice .delta .content
89
+ if choice .delta .role is not None :
90
+ delta ["role" ] = choice .delta .role
91
+ if choice .delta .function_call is not None :
92
+ delta ["function_call" ] = choice .delta .function_call
93
+ if choice .delta .tool_calls is not None :
94
+ pass
95
+ # Convert tool_calls to dict representation
96
+ # delta["tool_calls"] = [
97
+ # {
98
+ # "index": tool_call.index,
99
+ # "id": tool_call.id,
100
+ # "type": tool_call.type,
101
+ # "function": {
102
+ # "name": tool_call.function.name,
103
+ # "arguments": tool_call.function.arguments
104
+ # }
105
+ # } for tool_call in choice.delta.tool_calls
106
+ # ]
107
+
108
+ choices .append (
109
+ {
110
+ "index" : choice .index ,
111
+ "delta" : delta ,
112
+ "finish_reason" : choice .finish_reason ,
113
+ }
114
+ )
115
+
116
+ data = {
117
+ "id" : chunk .id ,
118
+ "object" : chunk .object ,
119
+ "created" : chunk .created ,
120
+ "model" : chunk .model ,
121
+ "choices" : choices ,
122
+ }
123
+
124
+ if hasattr (chunk , "system_fingerprint" ):
125
+ data ["system_fingerprint" ] = chunk .system_fingerprint
126
+
127
+ yield f"data: { json .dumps (data )} \n \n "
105
128
106
- # Send final chunk
107
- data = {
108
- "id" : "chatcmpl-" + str (time .time ()),
109
- "object" : "chat.completion.chunk" ,
110
- "created" : int (time .time ()),
111
- "model" : self .interpreter .model ,
112
- "choices" : [{"index" : 0 , "delta" : {}, "finish_reason" : "stop" }],
113
- }
114
- yield f"data: { json .dumps (data )} \n \n "
115
129
yield "data: [DONE]\n \n "
116
130
117
131
def run (self ):
0 commit comments