Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"request": "launch",
"django": true,
"module": "mcp_bridge.main",
"pythonArgs": ["-Xutf8"]
}
]
}
}
127 changes: 113 additions & 14 deletions mcp_bridge/openai_clients/streamChatCompletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
CreateChatCompletionRequest,
CreateChatCompletionStreamResponse,
Function1,
FinishReason1,
ChatCompletionToolChoiceOption1,
ChatCompletionToolChoiceOption,
)
from .utils import call_tool, chat_completion_add_tools
from mcp_bridge.models import SSEData
Expand All @@ -15,8 +18,12 @@
from mcp_bridge.tool_mappers import mcp2openai
from loguru import logger
from httpx_sse import aconnect_sse
import datetime
import os

from sse_starlette.sse import EventSourceResponse, ServerSentEvent
import json
import traceback


async def streaming_chat_completions(request: CreateChatCompletionRequest):
Expand All @@ -33,24 +40,52 @@ async def streaming_chat_completions(request: CreateChatCompletionRequest):
logger.error(e)


def validate_if_json_object_parsable(content: str):
    """Return True when *content* is a complete, parsable JSON document."""
    try:
        json.loads(content)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError, so this covers both.
        return False
    return True


def salvage_parsable_json_object(content: str):
    """Trim *content* back to its leading complete JSON value.

    Streamed tool-call arguments can arrive with trailing garbage or a
    truncated tail; this recovers the parsable leading portion so the
    tool call can still be executed.

    Returns the salvaged substring, or None when no leading portion of
    *content* parses as JSON at all (including empty input).
    """
    content = content.strip()
    if not content:
        return None
    try:
        # raw_decode parses exactly one JSON value and reports where it
        # ended — a single pass, replacing the original O(n^2)
        # truncate-one-char-and-reparse loop.
        _, end = json.JSONDecoder().raw_decode(content)
    except ValueError:
        # no prefix of the input forms a complete JSON value
        return None
    return content[:end]


async def chat_completions(request: CreateChatCompletionRequest):
"""performs a chat completion using the inference server"""

request.stream = True

request = await chat_completion_add_tools(request)
request = await chat_completion_add_tools(
request
) # Date: 2025/01/27 ChatMCP clear tools after first tool call.

fully_done = False
while not fully_done:
# json_data = request.model_dump_json(
# exclude_defaults=True, exclude_none=True, exclude_unset=True
# )
if request.tools:
request.tool_choice = ChatCompletionToolChoiceOption(
root=ChatCompletionToolChoiceOption1.auto
)

json_data = json.dumps(request.model_dump(
exclude_defaults=True, exclude_none=True, exclude_unset=True
))
json_data = json.dumps(
request.model_dump(
exclude_defaults=True,
exclude_none=True,
exclude_unset=True,
),
indent=4,
ensure_ascii=False,
)

# logger.debug(json_data)
logger.debug("Request JSON:\n%s" % json_data) # empty?

last: Optional[CreateChatCompletionStreamResponse] = None # last message

Expand All @@ -63,19 +98,40 @@ async def chat_completions(request: CreateChatCompletionRequest):
async with aconnect_sse(
client, "post", "/chat/completions", content=json_data
) as event_source:

# check if the content type is correct because the aiter_sse method
# will raise an exception if the content type is not correct
if "Content-Type" in event_source.response.headers:
if "Content-Type" in event_source.response.headers: # error here.
content_type = event_source.response.headers["Content-Type"]
if "text/event-stream" not in content_type:
logger.error(f"Unexpected Content-Type: {content_type}")
error_data = await event_source.response.aread()
logger.error(f"Request URL: {event_source.response.url}")
logger.error(f"Request Data: {json_data}")
logger.error(f"Response Status: {event_source.response.status_code}")
logger.error(f"Response Data: {error_data.decode(event_source.response.encoding or 'utf-8')}")
raise HTTPException(status_code=500, detail="Unexpected Content-Type")
log_dir = os.path.join(os.getcwd(), "logs")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
request_data_path = f"{log_dir}/request_data_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
request_data_formatted = json.dumps(
json.loads(json_data), indent=4, ensure_ascii=False
)
with open(request_data_path, "w+") as f:
f.write(request_data_formatted)
logger.error(f"Request Data saved to: {request_data_path}")
logger.error(f"Request Data:\n{request_data_formatted}")
logger.error(
f"Response Status: {event_source.response.status_code}"
)
error_data_decoded = error_data.decode(
event_source.response.encoding or "utf-8"
)
error_data_path = f"{log_dir}/error_data_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logger.error(f"Response Data saved to: {error_data_path}")
logger.error(f"Response Data:\n{error_data_decoded}")
with open(error_data_path, "w+") as f:
f.write(error_data_decoded)
raise HTTPException(
status_code=500, detail="Unexpected Content-Type"
)

# iterate over the SSE stream
async for sse in event_source.aiter_sse():
Expand All @@ -95,18 +151,26 @@ async def chat_completions(request: CreateChatCompletionRequest):

# for some reason openrouter uses uppercase for finish_reason
try:
data['choices'][0]['finish_reason'] = data['choices'][0]['finish_reason'].lower() # type: ignore
mjson_data = json.loads(data)

# Date: 2025/01/26 failed to lowercase finish_reason: string indices must be integers, not 'str'
if mjson_data["choices"][0].keys().__contains__("finish_reason"): # type: ignore
mjson_data["choices"][0]["finish_reason"] = mjson_data["choices"][0]["finish_reason"].lower() # type: ignore

data = json.dumps(mjson_data, ensure_ascii=False)
except Exception as e:
traceback.print_exc()
logger.debug(f"failed to lowercase finish_reason: {e}")

try:
parsed_data = CreateChatCompletionStreamResponse.model_validate_json(
data
parsed_data = (
CreateChatCompletionStreamResponse.model_validate_json(data)
)
except Exception as e:
logger.debug(data)
raise e


# add the delta to the response content
content = parsed_data.choices[0].delta.content
content = content if content is not None else ""
Expand Down Expand Up @@ -139,7 +203,9 @@ async def chat_completions(request: CreateChatCompletionRequest):
tool_call_id = id if tool_call_id == "" else tool_call_id

arg = parsed_data.choices[0].delta.tool_calls[0].function.arguments

tool_call_json += arg if arg is not None else ""
# Date: 2025/01/26 validate the tool call json.

# forward SSE messages to the client
logger.debug(f"{should_forward=}")
Expand All @@ -151,6 +217,27 @@ async def chat_completions(request: CreateChatCompletionRequest):
# save the last message
last = parsed_data

if tool_call_json:
if tool_call_json.strip().startswith("{"):
if validate_if_json_object_parsable(tool_call_json):
logger.debug(
f"tool call json '{tool_call_json}' is parsable now."
)
logger.debug("exiting message receive loop")
last.choices[0].finish_reason = FinishReason1.tool_calls
break
salvaged_json_object = salvage_parsable_json_object(
tool_call_json
)
if salvaged_json_object:
tool_call_json = salvaged_json_object
logger.debug(
f"tool call json '{tool_call_json}' is salvagable now."
)
logger.debug("salvaged json content:", tool_call_json)
logger.debug("exiting message receive loop")
last.choices[0].finish_reason = FinishReason1.tool_calls
break
# ideally we should check this properly
assert last is not None
assert last.choices[0].finish_reason is not None
Expand All @@ -165,6 +252,9 @@ async def chat_completions(request: CreateChatCompletionRequest):
f"{tool_call_name=} {tool_call_json=}"
) # this should not be error but its easier to debug

logger.debug("clearing tool contexts to prevent tool call loops")
request.tools = None

# add received message to the history
msg = ChatCompletionRequestMessage(
role="assistant",
Expand All @@ -181,6 +271,7 @@ async def chat_completions(request: CreateChatCompletionRequest):

#### MOST OF THIS IS COPY PASTED FROM CHAT_COMPLETIONS
# FIXME: this can probably be done in parallel using asyncio gather
# Date: 2025/01/26 decoding error?
tool_call_result = await call_tool(tool_call_name, tool_call_json)
if tool_call_result is None:
continue
Expand All @@ -207,6 +298,14 @@ async def chat_completions(request: CreateChatCompletionRequest):
)
)

# Date: 2025/01/26 crucial! we have to ensure the llm does not end up with infinite loop.

# request.messages.append(
# ChatCompletionRequestMessage.model_validate(
# {"role": "user", "content": "Do you consider you have done enough tool calls? If not, please continue the rest of the tool calls. If yes, please respond to the user and end the conversation."}
# )
# )

logger.debug("sending next iteration of chat completion request")

# when done, send the final event
Expand Down
15 changes: 11 additions & 4 deletions mcp_bridge/openai_clients/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@
from lmos_openai_types import CreateChatCompletionRequest
import mcp.types
import json
import traceback

from mcp_bridge.mcp_clients.McpClientManager import ClientManager
from mcp_bridge.tool_mappers import mcp2openai


async def chat_completion_add_tools(request: CreateChatCompletionRequest):
    """Attach the tools exposed by every running MCP client to *request*.

    Clients without an established session are skipped with an error log.
    When no client contributes any tool, ``request.tools`` is set to None
    rather than an empty list. Returns the (mutated) request.
    """
    request.tools = []
    logger.info("adding tools to request")

    for _, client in ClientManager.get_clients():
        # a `None` session means this client never started
        # NOTE(review): why the client is not running is still unresolved (2025/01/25)
        if client.session is None:
            logger.error(f"session is `None` for {client.name}")
            continue

        logger.debug(f"session ready for {client.name}")
        listing = await client.session.list_tools()
        request.tools.extend(mcp2openai(tool) for tool in listing.tools)

    if not request.tools:
        logger.info("this request loads no tools")
        # presumably an empty tool list upsets the inference server; send nothing instead
        # (a previous revision raised here instead of continuing without tools)
        request.tools = None
    return request


Expand All @@ -42,9 +48,10 @@ async def call_tool(
return None

try:
tool_call_args = json.loads(tool_call_json)
tool_call_args = json.loads(tool_call_json) # Date: 2025/01/26 cannot load this tool call json?
except json.JSONDecodeError:
logger.error(f"failed to decode json for {tool_call_name}")
traceback.print_exc()
return None

return await session.call_tool(tool_call_name, tool_call_args, timeout)