 import re
 import time
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Union, AsyncGenerator
+from typing import Any, Dict, List, Literal, Optional, Union, AsyncGenerator, Generator
 
 import fastapi
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-
-from beeai_framework.adapters.openai import OpenAIChatModel
-from beeai_framework.adapters.watsonx import WatsonxChatModel
-from beeai_framework.backend import (
-    ChatModelNewTokenEvent,
-    ChatModelSuccessEvent,
-    ChatModelErrorEvent,
-    UserMessage,
-    SystemMessage,
-    AssistantMessage,
-)
+from pydantic import BaseModel, Field
+import openai
+from ibm_watsonx_ai import Credentials
+from ibm_watsonx_ai.foundation_models import ModelInference
+from fastapi.concurrency import run_in_threadpool
 from beeai_server.api.dependencies import EnvServiceDependency
 
 
 router = fastapi.APIRouter()
 
 
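+# OpenAI-style tool-call schema used in non-streaming responses.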
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
+
 class ContentItem(BaseModel):
     type: Literal["text"] = "text"
     text: str
@@ -35,11 +39,8 @@ class ContentItem(BaseModel):
 class ChatCompletionMessage(BaseModel):
     role: Literal["system", "user", "assistant", "function", "tool"] = "assistant"
     content: Union[str, List[ContentItem]] = ""
-
-    def get_text_content(self) -> str:
-        if isinstance(self.content, str):
-            return self.content
-        return "".join(item.text for item in self.content if item.type == "text")
+    tool_calls: Optional[List[ToolCall]] = None
+    tool_call_id: Optional[str] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -56,11 +57,13 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
     response_format: Optional[Dict[str, Any]] = None
+    tools: Optional[List[Dict[str, Any]]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int = 0
-    message: ChatCompletionMessage = ChatCompletionMessage(role="assistant", content="")
+    message: ChatCompletionMessage
     finish_reason: Optional[str] = None
 
 
@@ -73,9 +76,27 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
 
 
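+# Streaming tool-call fragments arrive incrementally, so every field in the delta is optional.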
+class StreamFunctionCall(BaseModel):
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class StreamToolCall(BaseModel):
+    index: int
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: Optional[StreamFunctionCall] = None
+
+
+class ChatCompletionStreamDelta(BaseModel):
+    role: Optional[Literal["assistant"]] = None
+    content: Optional[str] = None
+    tool_calls: Optional[List[StreamToolCall]] = None
+
+
 class ChatCompletionStreamResponseChoice(BaseModel):
     index: int = 0
-    delta: ChatCompletionMessage = ChatCompletionMessage()
+    delta: ChatCompletionStreamDelta = Field(default_factory=ChatCompletionStreamDelta)
     finish_reason: Optional[str] = None
 
 
@@ -89,110 +110,121 @@ class ChatCompletionStreamResponse(BaseModel):
 
 
 @router.post("/chat/completions")
-async def create_chat_completion(
-    env_service: EnvServiceDependency,
-    request: ChatCompletionRequest,
-):
+async def create_chat_completion(env_service: EnvServiceDependency, request: ChatCompletionRequest):
     env = await env_service.list_env()
+    llm_api_base = env["LLM_API_BASE"]
+    llm_model = env["LLM_MODEL"]
 
-    is_rits = re.match(r"^https://[a-z0-9.-]+\.rits\.fmaas\.res\.ibm.com/.*$", env["LLM_API_BASE"])
-    is_watsonx = re.match(r"^https://[a-z0-9.-]+\.ml\.cloud\.ibm\.com.*?$", env["LLM_API_BASE"])
+    is_rits = re.match(r"^https://[a-z0-9.-]+\.rits\.fmaas\.res\.ibm.com/.*$", llm_api_base)
+    is_watsonx = re.match(r"^https://[a-z0-9.-]+\.ml\.cloud\.ibm\.com.*?$", llm_api_base)
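+    # LLM_API_BASE picks the backend: watsonx URLs go through ibm_watsonx_ai, everything else through the OpenAI client.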
 
-    llm = (
-        WatsonxChatModel(
-            model_id=env["LLM_MODEL"],
-            api_key=env["LLM_API_KEY"],
-            base_url=env["LLM_API_BASE"],
+    messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
+
+    if is_watsonx:
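+        # watsonx splits OpenAI's tool_choice: a plain string maps to tool_choice_option, a dict naming a function maps to tool_choice.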
+        watsonx_params = {}
+        if isinstance(request.tool_choice, str):
+            watsonx_params["tool_choice_option"] = request.tool_choice
+        elif isinstance(request.tool_choice, dict):
+            watsonx_params["tool_choice"] = request.tool_choice
+
+        model = ModelInference(
+            model_id=llm_model,
+            credentials=Credentials(url=llm_api_base, api_key=env["LLM_API_KEY"]),
             project_id=env.get("WATSONX_PROJECT_ID"),
             space_id=env.get("WATSONX_SPACE_ID"),
+            params={
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+            },
         )
-        if is_watsonx
-        else OpenAIChatModel(
-            env["LLM_MODEL"],
+
+        if request.stream:
+            return StreamingResponse(
+                _stream_watsonx_chat_completion(model, messages, request.tools, watsonx_params, request),
+                media_type="text/event-stream",
+            )
+        else:
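+            # model.chat is synchronous, so run it in a threadpool to keep the event loop free.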
+            response = await run_in_threadpool(model.chat, messages=messages, tools=request.tools, **watsonx_params)
+            choice = response["choices"][0]
+            return ChatCompletionResponse(
+                id=response.get("id", f"chatcmpl-{uuid.uuid4()}"),
+                created=response.get("created", int(time.time())),
+                model=request.model,
+                choices=[
+                    ChatCompletionResponseChoice(
+                        message=ChatCompletionMessage(**choice["message"]),
+                        finish_reason=choice.get("finish_reason"),
+                    )
+                ],
+            )
+    else:
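+        # Everything non-watsonx is treated as OpenAI-compatible; RITS additionally expects the API key in a RITS_API_KEY header.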
+        client = openai.AsyncOpenAI(
             api_key=env["LLM_API_KEY"],
-            base_url=env["LLM_API_BASE"],
-            extra_headers={"RITS_API_KEY": env["LLM_API_KEY"]} if is_rits else {},
+            base_url=llm_api_base,
+            default_headers={"RITS_API_KEY": env["LLM_API_KEY"]} if is_rits else {},
         )
-    )
-
-    messages = [
-        UserMessage(msg.get_text_content())
-        if msg.role == "user"
-        else SystemMessage(msg.get_text_content())
-        if msg.role == "system"
-        else AssistantMessage(msg.get_text_content())
-        for msg in request.messages
-        if msg.role in ["user", "system", "assistant"]
-    ]
-
-    if request.stream:
-        return StreamingResponse(stream_chat_completion(llm, messages, request), media_type="text/event-stream")
-
-    output = await llm.create(
-        messages=messages,
-        temperature=request.temperature,
-        maxTokens=request.max_tokens,
-        response_format=request.response_format,
-    )
-
-    return ChatCompletionResponse(
-        id=f"chatcmpl-{str(uuid.uuid4())}",
-        created=int(time.time()),
-        model=request.model,
-        choices=[
-            ChatCompletionResponseChoice(
-                message=ChatCompletionMessage(content=output.get_text_content()),
-                finish_reason=output.finish_reason,
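+        # Pass the OpenAI-shaped request straight through, overriding only the model with the configured LLM_MODEL.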
+        params = {**request.model_dump(exclude_none=True), "model": llm_model}
+
+        if request.stream:
+            stream = await client.chat.completions.create(**params)
+            return StreamingResponse(_stream_openai_chat_completion(stream), media_type="text/event-stream")
+        else:
+            response = await client.chat.completions.create(**params)
+            openai_choice = response.choices[0]
+            return ChatCompletionResponse(
+                id=response.id,
+                created=response.created,
+                model=response.model,
+                choices=[
+                    ChatCompletionResponseChoice(
+                        index=openai_choice.index,
+                        message=ChatCompletionMessage(**openai_choice.message.model_dump()),
+                        finish_reason=openai_choice.finish_reason,
+                    )
+                ],
             )
-        ],
-    )
 
 
-async def stream_chat_completion(
-    llm: OpenAIChatModel,
-    messages: List[Union[UserMessage, SystemMessage, AssistantMessage]],
+def _stream_watsonx_chat_completion(
+    model: ModelInference,
+    messages: List[Dict],
+    tools: Optional[List],
+    watsonx_params: Dict,
     request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
+) -> Generator[str, None, None]:
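+    # One completion id and timestamp are shared by every chunk, matching OpenAI's streaming format.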
+    completion_id = f"chatcmpl-{str(uuid.uuid4())}"
+    created_time = int(time.time())
     try:
-        completion_id = f"chatcmpl-{str(uuid.uuid4())}"
-
-        async for event, _ in llm.create(
-            messages=messages,
-            stream=True,
-            temperature=request.temperature,
-            maxTokens=request.max_tokens,
-            response_format=request.response_format,
-        ):
-            if isinstance(event, ChatModelNewTokenEvent):
-                yield f"""data: {
-                    json.dumps(
-                        ChatCompletionStreamResponse(
-                            id=completion_id,
-                            created=int(time.time()),
-                            model=request.model,
-                            choices=[
-                                ChatCompletionStreamResponseChoice(
-                                    delta=ChatCompletionMessage(content=event.value.get_text_content())
-                                )
-                            ],
-                        ).model_dump()
-                    )
-                }\n\n"""
-            elif isinstance(event, ChatModelSuccessEvent):
-                yield f"""data: {
-                    json.dumps(
-                        ChatCompletionStreamResponse(
-                            id=completion_id,
-                            created=int(time.time()),
-                            model=request.model,
-                            choices=[ChatCompletionStreamResponseChoice(finish_reason=event.value.finish_reason)],
-                        ).model_dump()
+        for chunk in model.chat_stream(messages=messages, tools=tools, **watsonx_params):
+            choice = chunk["choices"][0]
+            response_chunk = ChatCompletionStreamResponse(
+                id=completion_id,
+                created=created_time,
+                model=request.model,
+                choices=[
+                    ChatCompletionStreamResponseChoice(
+                        delta=ChatCompletionStreamDelta(**choice.get("delta", {})),
+                        finish_reason=choice.get("finish_reason"),
                     )
-                }\n\n"""
-                return
-            elif isinstance(event, ChatModelErrorEvent):
-                raise event.error
+                ],
+            )
+            yield f"data: {response_chunk.model_dump_json(exclude_none=True)}\n\n"
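+            # Stop once the backend reports a finish_reason; the closing sentinel is emitted in the finally block.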
+            if choice.get("finish_reason"):
+                break
+    except Exception as e:
+        yield f"data: {json.dumps({'error': {'message': str(e), 'type': type(e).__name__}})}\n\n"
+    finally:
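+        # Always terminate the SSE stream with the [DONE] sentinel expected by OpenAI clients.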
+        yield "data: [DONE]\n\n"
+
+
+async def _stream_openai_chat_completion(stream: AsyncGenerator) -> AsyncGenerator[str, None]:
+    try:
+        async for chunk in stream:
+            yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
     except Exception as e:
-        yield f"data: {json.dumps(dict(error=dict(message=str(e), type=type(e).__name__)))}\n\n"
+        yield f"data: {json.dumps({'error': {'message': str(e), 'type': type(e).__name__}})}\n\n"
     finally:
         yield "data: [DONE]\n\n"