Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
from openai import AsyncOpenAI
from agents import Agent, OpenAIChatCompletionsModel, Runner, RunResult
from agents.tool import FunctionTool
from openai import AsyncOpenAI # type: ignore
from agents import Agent, OpenAIChatCompletionsModel, Runner, RunResult, set_tracing_disabled # type: ignore
from agents.tool import FunctionTool # type: ignore
from mcp.types import Tool
from core.tools import ToolManager
from mcp_client import MCPClient

set_tracing_disabled(True)


async def convert_to_sdk_tool(tools_schema: list[Tool], mcp_clients: dict[str, MCPClient]) -> list[FunctionTool]:
converted_tools = []
for tool in tools_schema:
Expand All @@ -31,7 +34,7 @@ class AgentService:
def __init__(self, model: str, api_key: str, base_url: str | None = None, clients=None):
self.model = model
self.api_key = api_key
self.messages = []
self.messages = [] # type: ignore

self.client = AsyncOpenAI(
api_key=api_key,
Expand All @@ -44,7 +47,8 @@ def __init__(self, model: str, api_key: str, base_url: str | None = None, client
model=OpenAIChatCompletionsModel(
model=model,
openai_client=self.client
),
)

)

async def chat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,12 @@ def get_completions(self, document, complete_event):


class CliApp:
session: PromptSession

def __init__(self, agent: CliChat):
self.agent = agent
self.resources = []
self.prompts = []
self.resources: list[str] = []
self.prompts: list = []

self.completer = UnifiedCompleter()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,35 @@ def __init__(

self.doc_client: MCPClient = doc_client

async def run(self, query: str) -> str:
"""Override run method to process resources before sending to agent"""
# Process query to extract resources and inject document content
await self._process_query(query)

# Call agent service without passing query since _process_query already added to messages
response = await self.agent_serve.chat(
query="", # Empty because _process_query already added enhanced prompt to messages
mcp_clients=self.clients,
)

return response.final_output

async def list_prompts(self) -> list[Prompt]:
return await self.doc_client.list_prompts()

async def list_docs_ids(self) -> list[str]:
return await self.doc_client.read_resource("docs://documents")
resource = await self.doc_client.read_resource("docs://documents")
# The docs://documents resource returns a JSON list
if isinstance(resource, list):
return resource
return []

async def get_doc_content(self, doc_id: str) -> str:
return await self.doc_client.read_resource(f"docs://documents/{doc_id}")
resource = await self.doc_client.read_resource(f"docs://{doc_id}")
# Extract text from the resource object
if hasattr(resource, 'text'):
return resource.text
return str(resource)

async def get_prompt(
self, command: str, doc_id: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ class ToolManager:
@classmethod
async def get_all_tools(cls, clients: dict[str, MCPClient]) -> list[Tool]:
"""Gets all tools from the provided clients."""
tools = []
all_tools = []
for client in clients.values():
tool_models = await client.list_tools()
tools = tool_models if tool_models else []
return tools
if tool_models: # Only extend if tool_models is not empty
all_tools.extend(tool_models)
return all_tools

@classmethod
async def _find_client_with_tool(
Expand All @@ -30,7 +31,7 @@ async def _find_client_with_tool(
@classmethod
def execute_tool_dynamically(cls, tool_name, mcp_client: MCPClient):
"""Execute a simulated database query."""
async def execute_tool(ctx: ToolContext, args: str) -> CallToolResult:
async def execute_tool(ctx: ToolContext, args: str):
parsed_args = json.loads(args)
result = await mcp_client.call_tool(tool_name, parsed_args)
return result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import sys
import os
import subprocess
import time
from dotenv import load_dotenv, find_dotenv
from contextlib import AsyncExitStack

Expand All @@ -26,42 +28,65 @@
)


async def start_mcp_server():
"""Start the MCP server in a separate process"""
print("Starting MCP server...")
process = subprocess.Popen(
["uv", "run", "mcp_server.py"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# Give the server time to start
time.sleep(2)
return process


async def main():
server_scripts = sys.argv[1:]
clients = {}

# command, args = ("uv", ["run", "mcp_server.py"])
server_url = "http://localhost:8000/mcp/"
# Start the MCP server
server_process = await start_mcp_server()

try:
# Connect to the HTTP-based MCP server
server_url = "http://localhost:8000/mcp/"

async with AsyncExitStack() as stack:
doc_client = await stack.enter_async_context(
MCPClient(server_url=server_url)
)
clients["doc_client"] = doc_client

async with AsyncExitStack() as stack:
doc_client = await stack.enter_async_context(
MCPClient(server_url=server_url)
)
clients["doc_client"] = doc_client
for i, server_script in enumerate(server_scripts):
client_id = f"client_{i}_{server_script}"
client = await stack.enter_async_context(
MCPClient(command="uv", args=["run", server_script])
)
clients[client_id] = client

for i, server_script in enumerate(server_scripts):
client_id = f"client_{i}_{server_script}"
client = await stack.enter_async_context(
MCPClient(command="uv", args=["run", server_script])
agent_service = AgentService(
model=llm_model,
api_key=llm_api_key,
base_url=llm_base_url,
clients=clients
)
clients[client_id] = client

agent_service = AgentService(
model=llm_model,
api_key=llm_api_key,
base_url=llm_base_url,
clients=clients
)

chat = CliChat(
doc_client=doc_client,
clients=clients,
agent_serve=agent_service,
)

cli = CliApp(chat)
await cli.initialize()
await cli.run()

chat = CliChat(
doc_client=doc_client,
clients=clients,
agent_serve=agent_service,
)

cli = CliApp(chat)
await cli.initialize()
await cli.run()

finally:
# Clean up the server process
if server_process:
server_process.terminate()
server_process.wait()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,38 @@ async def main():
) as _client:
# Example usage:
# Retrieve and print available tools to verify the client implementation.
print("Listing tools...")
tools = await _client.list_tools()
print("Available Tools:", tools)
for tool in tools:
print(f"Tool: {tool.name}")
print(f"Description: {tool.description}")
if tool.inputSchema:
print("Parameters:")
properties = tool.inputSchema.get('properties', {})
required = tool.inputSchema.get('required', [])
for param_name, param_info in properties.items():
param_type = param_info.get('type', 'unknown')
param_desc = param_info.get('description', 'No description')
required_str = " (required)" if param_name in required else " (optional)"
print(f" - {param_name}: {param_type}{required_str} - {param_desc}")

print("-" * 50) # Add separator between tools

print("=" * 50) # Add separator after tools section


# Example usage:
# Test reading the document list
print("Reading document list...")
doc_list = await _client.read_resource("docs://documents")
print("Document List:")
print(f" Documents: {doc_list}")

# Test reading individual documents
print("\nReading individual documents...")
for doc_id in doc_list:
doc_contents = await _client.read_resource(f"docs://{doc_id}")
print(f" {doc_id}: {doc_contents.text}")

if __name__ == "__main__":
if sys.platform == "win32":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
def read_document(
doc_id: str = Field(description="Id of the document to read")
):
print(f"Reading document tool called with {doc_id}...")
if doc_id not in docs:
raise ValueError(f"Doc with id {doc_id} not found")

Expand All @@ -37,6 +38,7 @@ def edit_document(
new_str: str = Field(
description="The new text to insert in place of the old text.")
):
print(f"Editing document tool called with {doc_id}...")
if doc_id not in docs:
raise ValueError(f"Doc with id {doc_id} not found")

Expand All @@ -49,6 +51,7 @@ def edit_document(
mime_type="application/json"
)
def list_docs() -> list[str]:
print(f"Listing resources called")
return list(docs.keys())

# TODO: Write a resource to return the contents of a particular doc
Expand All @@ -57,10 +60,17 @@ def list_docs() -> list[str]:
mime_type="text/plain"
)
def get_doc(doc_id: str) -> str:
print(f"Getting document resource called with {doc_id}")
return docs[doc_id]

# TODO: Write a prompt to rewrite a doc in markdown format
# TODO: Write a prompt to summarize a doc


mcp_app = mcp.streamable_http_app()


if __name__ == "__main__":
import uvicorn
print("Starting MCP server...")
uvicorn.run("mcp_server:mcp_app", host="0.0.0.0", port=8000, reload=True)