
Commit 41d6ca2

feat(beeai-server): add function calling support (#906)

1 parent adca7a8 commit 41d6ca2

4 files changed: +247 -111 lines changed

apps/beeai-cli/src/beeai_cli/commands/agent.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ async def _run_agent(
                 )
                 break
             case RunFailedEvent():
-                console.print(format_error(event.run.error.code.value, event.run.error.message))
+                console.print(format_error(str(event.run.error.code), event.run.error.message))
             case ArtifactEvent():
                 if dump_files_path is None:
                     continue
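Side note on the change above: str() renders the error code whether it arrives as an enum member or as a plain string, while .value exists only on the enum. A minimal sketch (the ErrorCode enum here is hypothetical, not the SDK's actual type):

from enum import Enum


class ErrorCode(Enum):  # hypothetical stand-in for the type of event.run.error.code
    INVALID_INPUT = "invalid_input"


print(str(ErrorCode.INVALID_INPUT))  # "ErrorCode.INVALID_INPUT" -- still printable
print(str("invalid_input"))          # "invalid_input"
# "invalid_input".value would raise AttributeError, which is the failure the change avoids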

apps/beeai-server/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ dependencies = [
     "procrastinate==3.2.2",
     "sqlparse>=0.5.3",
     "pgvector>=0.4.1",
+    "ibm-watsonx-ai>=1.3.28",
 ]
 
 [tool.ruff]
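The new pin matters because the server route below imports the watsonx SDK directly. A quick sanity check, as a sketch (not part of the change):

# Confirm the dependency resolves to the two symbols llm.py imports below.
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import ModelInference

print(Credentials.__module__, ModelInference.__module__)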

apps/beeai-server/src/beeai_server/api/routes/llm.py

Lines changed: 142 additions & 110 deletions
@@ -5,28 +5,32 @@
 import re
 import time
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Union, AsyncGenerator
+from typing import Any, Dict, List, Literal, Optional, Union, AsyncGenerator, Generator
 
 import fastapi
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-
-from beeai_framework.adapters.openai import OpenAIChatModel
-from beeai_framework.adapters.watsonx import WatsonxChatModel
-from beeai_framework.backend import (
-    ChatModelNewTokenEvent,
-    ChatModelSuccessEvent,
-    ChatModelErrorEvent,
-    UserMessage,
-    SystemMessage,
-    AssistantMessage,
-)
+from pydantic import BaseModel, Field
+import openai
+from ibm_watsonx_ai import Credentials
+from ibm_watsonx_ai.foundation_models import ModelInference
+from fastapi.concurrency import run_in_threadpool
 from beeai_server.api.dependencies import EnvServiceDependency
 
 
 router = fastapi.APIRouter()
 
 
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
+
 class ContentItem(BaseModel):
     type: Literal["text"] = "text"
     text: str
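The two new models mirror the OpenAI tool-call shape, where arguments is a JSON-encoded string rather than a dict. A small sketch of how they serialize (the id and tool name are illustrative):

call = ToolCall(
    id="call_abc123",  # illustrative id
    function=FunctionCall(name="get_weather", arguments='{"city": "Prague"}'),
)
print(call.model_dump_json())
# {"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":"{\"city\": \"Prague\"}"}}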
@@ -35,11 +39,8 @@ class ContentItem(BaseModel):
 class ChatCompletionMessage(BaseModel):
     role: Literal["system", "user", "assistant", "function", "tool"] = "assistant"
     content: Union[str, List[ContentItem]] = ""
-
-    def get_text_content(self) -> str:
-        if isinstance(self.content, str):
-            return self.content
-        return "".join(item.text for item in self.content if item.type == "text")
+    tool_calls: Optional[List[ToolCall]] = None
+    tool_call_id: Optional[str] = None
 
 
 class ChatCompletionRequest(BaseModel):
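With tool_calls and tool_call_id added to ChatCompletionMessage, a full function-calling exchange can round-trip through this schema. A sketch of the message sequence a client sends back after executing a tool (names, ids, and payloads are illustrative):

turns = [
    ChatCompletionMessage(role="user", content="What's the weather in Prague?"),
    # assistant turn that requested the tool call
    ChatCompletionMessage(
        role="assistant",
        tool_calls=[
            ToolCall(id="call_abc123", function=FunctionCall(name="get_weather", arguments='{"city": "Prague"}'))
        ],
    ),
    # tool result, linked back to the request by tool_call_id
    ChatCompletionMessage(role="tool", tool_call_id="call_abc123", content='{"temp_c": 21}'),
]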
@@ -56,11 +57,13 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
     response_format: Optional[Dict[str, Any]] = None
+    tools: Optional[List[Dict[str, Any]]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int = 0
-    message: ChatCompletionMessage = ChatCompletionMessage(role="assistant", content="")
+    message: ChatCompletionMessage
     finish_reason: Optional[str] = None
 
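The two new request fields follow the OpenAI function-calling format. A sketch of a request body the route now accepts (the tool itself is illustrative; as the handler below shows, the upstream call uses env["LLM_MODEL"] regardless of the model field sent here):

request_body = {
    "model": "placeholder",  # substituted server-side with env["LLM_MODEL"]
    "messages": [{"role": "user", "content": "What's the weather in Prague?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    # a string is forwarded to watsonx as tool_choice_option, a dict as tool_choice (see the handler below)
    "tool_choice": "auto",
}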

@@ -73,9 +76,27 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
 
 
+class StreamFunctionCall(BaseModel):
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class StreamToolCall(BaseModel):
+    index: int
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: Optional[StreamFunctionCall] = None
+
+
+class ChatCompletionStreamDelta(BaseModel):
+    role: Optional[Literal["assistant"]] = None
+    content: Optional[str] = None
+    tool_calls: Optional[List[StreamToolCall]] = None
+
+
 class ChatCompletionStreamResponseChoice(BaseModel):
     index: int = 0
-    delta: ChatCompletionMessage = ChatCompletionMessage()
+    delta: ChatCompletionStreamDelta = Field(default_factory=ChatCompletionStreamDelta)
     finish_reason: Optional[str] = None
 
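Streamed tool calls arrive in fragments: an early delta typically carries the call id and function name, later deltas append slices of the arguments string, and index ties the fragments together. A sketch of the client-side accumulation these models support (chunk contents are illustrative):

chunks = [
    ChatCompletionStreamDelta(
        role="assistant",
        tool_calls=[
            StreamToolCall(index=0, id="call_abc123", function=StreamFunctionCall(name="get_weather", arguments=""))
        ],
    ),
    ChatCompletionStreamDelta(tool_calls=[StreamToolCall(index=0, function=StreamFunctionCall(arguments='{"city": '))]),
    ChatCompletionStreamDelta(tool_calls=[StreamToolCall(index=0, function=StreamFunctionCall(arguments='"Prague"}'))]),
]

arguments_by_index = {}
for delta in chunks:
    for call in delta.tool_calls or []:
        arguments_by_index[call.index] = arguments_by_index.get(call.index, "") + (call.function.arguments or "")

print(arguments_by_index[0])  # {"city": "Prague"}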

@@ -89,110 +110,121 @@ class ChatCompletionStreamResponse(BaseModel):
 
 
 @router.post("/chat/completions")
-async def create_chat_completion(
-    env_service: EnvServiceDependency,
-    request: ChatCompletionRequest,
-):
+async def create_chat_completion(env_service: EnvServiceDependency, request: ChatCompletionRequest):
     env = await env_service.list_env()
+    llm_api_base = env["LLM_API_BASE"]
+    llm_model = env["LLM_MODEL"]
 
-    is_rits = re.match(r"^https://[a-z0-9.-]+\.rits\.fmaas\.res\.ibm.com/.*$", env["LLM_API_BASE"])
-    is_watsonx = re.match(r"^https://[a-z0-9.-]+\.ml\.cloud\.ibm\.com.*?$", env["LLM_API_BASE"])
+    is_rits = re.match(r"^https://[a-z0-9.-]+\.rits\.fmaas\.res\.ibm.com/.*$", llm_api_base)
+    is_watsonx = re.match(r"^https://[a-z0-9.-]+\.ml\.cloud\.ibm\.com.*?$", llm_api_base)
 
-    llm = (
-        WatsonxChatModel(
-            model_id=env["LLM_MODEL"],
-            api_key=env["LLM_API_KEY"],
-            base_url=env["LLM_API_BASE"],
+    messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
+
+    if is_watsonx:
+        watsonx_params = {}
+        if isinstance(request.tool_choice, str):
+            watsonx_params["tool_choice_option"] = request.tool_choice
+        elif isinstance(request.tool_choice, dict):
+            watsonx_params["tool_choice"] = request.tool_choice
+
+        model = ModelInference(
+            model_id=llm_model,
+            credentials=Credentials(url=llm_api_base, api_key=env["LLM_API_KEY"]),
             project_id=env.get("WATSONX_PROJECT_ID"),
             space_id=env.get("WATSONX_SPACE_ID"),
+            params={
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+            },
         )
-        if is_watsonx
-        else OpenAIChatModel(
-            env["LLM_MODEL"],
+
+        if request.stream:
+            return StreamingResponse(
+                _stream_watsonx_chat_completion(model, messages, request.tools, watsonx_params, request),
+                media_type="text/event-stream",
+            )
+        else:
+            response = await run_in_threadpool(model.chat, messages=messages, tools=request.tools, **watsonx_params)
+            choice = response["choices"][0]
+            return ChatCompletionResponse(
+                id=response.get("id", f"chatcmpl-{uuid.uuid4()}"),
+                created=response.get("created", int(time.time())),
+                model=request.model,
+                choices=[
+                    ChatCompletionResponseChoice(
+                        message=ChatCompletionMessage(**choice["message"]),
+                        finish_reason=choice.get("finish_reason"),
+                    )
+                ],
+            )
+    else:
+        client = openai.AsyncOpenAI(
             api_key=env["LLM_API_KEY"],
-            base_url=env["LLM_API_BASE"],
-            extra_headers={"RITS_API_KEY": env["LLM_API_KEY"]} if is_rits else {},
+            base_url=llm_api_base,
+            default_headers={"RITS_API_KEY": env["LLM_API_KEY"]} if is_rits else {},
         )
-    )
-
-    messages = [
-        UserMessage(msg.get_text_content())
-        if msg.role == "user"
-        else SystemMessage(msg.get_text_content())
-        if msg.role == "system"
-        else AssistantMessage(msg.get_text_content())
-        for msg in request.messages
-        if msg.role in ["user", "system", "assistant"]
-    ]
-
-    if request.stream:
-        return StreamingResponse(stream_chat_completion(llm, messages, request), media_type="text/event-stream")
-
-    output = await llm.create(
-        messages=messages,
-        temperature=request.temperature,
-        maxTokens=request.max_tokens,
-        response_format=request.response_format,
-    )
-
-    return ChatCompletionResponse(
-        id=f"chatcmpl-{str(uuid.uuid4())}",
-        created=int(time.time()),
-        model=request.model,
-        choices=[
-            ChatCompletionResponseChoice(
-                message=ChatCompletionMessage(content=output.get_text_content()),
-                finish_reason=output.finish_reason,
-            )
-        ],
-    )
+        params = {**request.model_dump(exclude_none=True), "model": llm_model}
+
+        if request.stream:
+            stream = await client.chat.completions.create(**params)
+            return StreamingResponse(_stream_openai_chat_completion(stream), media_type="text/event-stream")
+        else:
+            response = await client.chat.completions.create(**params)
+            openai_choice = response.choices[0]
+            return ChatCompletionResponse(
+                id=response.id,
+                created=response.created,
+                model=response.model,
+                choices=[
+                    ChatCompletionResponseChoice(
+                        index=openai_choice.index,
+                        message=ChatCompletionMessage(**openai_choice.message.model_dump()),
+                        finish_reason=openai_choice.finish_reason,
+                    )
+                ],
+            )
 
 
-async def stream_chat_completion(
-    llm: OpenAIChatModel,
-    messages: List[Union[UserMessage, SystemMessage, AssistantMessage]],
+def _stream_watsonx_chat_completion(
+    model: ModelInference,
+    messages: List[Dict],
+    tools: Optional[List],
+    watsonx_params: Dict,
     request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
+) -> Generator[str, None, None]:
+    completion_id = f"chatcmpl-{str(uuid.uuid4())}"
+    created_time = int(time.time())
     try:
-        completion_id = f"chatcmpl-{str(uuid.uuid4())}"
-
-        async for event, _ in llm.create(
-            messages=messages,
-            stream=True,
-            temperature=request.temperature,
-            maxTokens=request.max_tokens,
-            response_format=request.response_format,
-        ):
-            if isinstance(event, ChatModelNewTokenEvent):
-                yield f"""data: {
-                    json.dumps(
-                        ChatCompletionStreamResponse(
-                            id=completion_id,
-                            created=int(time.time()),
-                            model=request.model,
-                            choices=[
-                                ChatCompletionStreamResponseChoice(
-                                    delta=ChatCompletionMessage(content=event.value.get_text_content())
-                                )
-                            ],
-                        ).model_dump()
-                    )
-                }\n\n"""
-            elif isinstance(event, ChatModelSuccessEvent):
-                yield f"""data: {
-                    json.dumps(
-                        ChatCompletionStreamResponse(
-                            id=completion_id,
-                            created=int(time.time()),
-                            model=request.model,
-                            choices=[ChatCompletionStreamResponseChoice(finish_reason=event.value.finish_reason)],
-                        ).model_dump()
-                    )
-                }\n\n"""
-                return
-            elif isinstance(event, ChatModelErrorEvent):
-                raise event.error
+        for chunk in model.chat_stream(messages=messages, tools=tools, **watsonx_params):
+            choice = chunk["choices"][0]
+            response_chunk = ChatCompletionStreamResponse(
+                id=completion_id,
+                created=created_time,
+                model=request.model,
+                choices=[
+                    ChatCompletionStreamResponseChoice(
+                        delta=ChatCompletionStreamDelta(**choice.get("delta", {})),
+                        finish_reason=choice.get("finish_reason"),
+                    )
+                ],
+            )
+            yield f"data: {response_chunk.model_dump_json(exclude_none=True)}\n\n"
+            if choice.get("finish_reason"):
+                break
+    except Exception as e:
+        yield f"data: {json.dumps({'error': {'message': str(e), 'type': type(e).__name__}})}\n\n"
+    finally:
+        yield "data: [DONE]\n\n"
+
+
+async def _stream_openai_chat_completion(stream: AsyncGenerator) -> AsyncGenerator[str, None]:
+    try:
+        async for chunk in stream:
+            yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
     except Exception as e:
-        yield f"data: {json.dumps(dict(error=dict(message=str(e), type=type(e).__name__)))}\n\n"
+        yield f"data: {json.dumps({'error': {'message': str(e), 'type': type(e).__name__}})}\n\n"
     finally:
         yield "data: [DONE]\n\n"
