Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ def __init__(self):

def add_tool(self, tool: Tool) -> None:
self.chat.add_tool(tool)

def enable_tool(self, tool_name: str) -> None:
try:
self.chat.enable_tool(tool_name)
except Exception as e:
return False
return True

def disable_tool(self, tool_name: str) -> None:
Comment on lines +59 to +66
Copy link

Copilot AI Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type should be annotated as -> bool to reflect that this method returns True/False, improving type clarity.

Suggested change
def enable_tool(self, tool_name: str) -> None:
try:
self.chat.enable_tool(tool_name)
except Exception as e:
return False
return True
def disable_tool(self, tool_name: str) -> None:
def enable_tool(self, tool_name: str) -> bool:
try:
self.chat.enable_tool(tool_name)
except Exception as e:
return False
return True
def disable_tool(self, tool_name: str) -> bool:

Copilot uses AI. Check for mistakes.
try:
self.chat.disable_tool(tool_name)
except Exception as e:
return False
return True

def get_tools(self) -> list:
return self.chat.get_tools()

async def process_query(self, user_prompt: str) -> str:
user_role = {"role": "user", "content": user_prompt}
Expand All @@ -68,10 +85,13 @@ async def process_query(self, user_prompt: str) -> str:
assistant_message = choices[0].get("message", {})
messages.append(assistant_message)

tools_used = set()
# Handle the case where tool_calls might be missing or not a list
while assistant_message.get("tool_calls"):
await self.chat.process_tool_calls(assistant_message, messages.append)

used_tools = await self.chat.process_tool_calls(assistant_message, messages.append)
for tool in used_tools:
tools_used.add(tool)

response = await self.chat.send_messages(messages)
if not (response and response.get("choices", None)):
break
Expand All @@ -85,7 +105,7 @@ async def process_query(self, user_prompt: str) -> str:
self.history.append(assistant_message)

pretty_print("History", self.history)
return result
return result, tools_used


# Global agent instance for backwards compatibility
Expand Down
4 changes: 2 additions & 2 deletions src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from api.routes import router, agent, session_manager
from api.routes import router, agent_instance, session_manager
mimetypes.add_type("application/javascript", ".js")

@asynccontextmanager
Expand All @@ -15,7 +15,7 @@ async def lifespan(app: FastAPI):
config_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config', 'mcp.json')
await session_manager.discovery(config_path)
for tool in session_manager.tools:
agent.add_tool(tool)
agent_instance.add_tool(tool)
print("Agent and tools initialized successfully.")
except Exception as e:
print(f"Error initializing MCP tools: {str(e)}")
Expand Down
24 changes: 24 additions & 0 deletions src/api/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
from pydantic import BaseModel
from typing import List, Optional

class QueryRequest(BaseModel):
query: str


class QueryResponse(BaseModel):
response: str
used_tools: Optional[List[str]] = []


class ToolInfo(BaseModel):
name: str
description: str
enabled: bool
parameters: dict


class ToolsListResponse(BaseModel):
tools: List[ToolInfo]


class ToolToggleRequest(BaseModel):
tool_name: str
enabled: bool


class ToolToggleResponse(BaseModel):
tool_name: str
enabled: bool
message: str
59 changes: 53 additions & 6 deletions src/api/routes.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,66 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException

from api.auth import get_api_key
from api.models import QueryRequest, QueryResponse
from agent import Agent
from api.models import QueryRequest, QueryResponse, ToolsListResponse, ToolToggleRequest, ToolToggleResponse, ToolInfo
from agent import agent_instance
from core.mcp.sessions_manager import MCPSessionManager

agent = Agent()
session_manager = MCPSessionManager()

router = APIRouter(prefix="/api", dependencies=[Depends(get_api_key)])

@router.post("/ask", response_model=QueryResponse)
async def ask_agent(request: QueryRequest) -> QueryResponse:
try:
response = await agent.process_query(request.query)
return QueryResponse(response=response)
response, used_tools = await agent_instance.process_query(request.query)
return QueryResponse(
response=response,
used_tools=list(used_tools)
)
except Exception as e:
return QueryResponse(response=f"Sorry, I encountered an error: {str(e)}")


@router.get("/tools", response_model=ToolsListResponse)
async def list_tools() -> ToolsListResponse:
try:
tools_info = agent_instance.get_tools()
tools = [
ToolInfo(
name=info.name,
description=info.description,
enabled=info.enabled,
parameters=info.parameters
)
for info in tools_info
]
return ToolsListResponse(tools=tools)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error listing tools: {str(e)}")


@router.post("/tools/toggle", response_model=ToolToggleResponse)
async def toggle_tool(request: ToolToggleRequest) -> ToolToggleResponse:
try:
if request.enabled:
success = agent_instance.enable_tool(request.tool_name)
action = "enabled"
else:
success = agent_instance.disable_tool(request.tool_name)
action = "disabled"

if not success:
raise HTTPException(
status_code=404,
detail=f"Tool '{request.tool_name}' not found"
)

return ToolToggleResponse(
tool_name=request.tool_name,
enabled=request.enabled,
message=f"Tool '{request.tool_name}' has been {action}"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling tool: {str(e)}")
34 changes: 31 additions & 3 deletions src/core/llm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,35 @@ class Chat:
def __init__(self, tool_list: List[Tool] = []):
self.chat_client: ChatClient = ChatClient()
self.tool_map = {tool.name: tool for tool in tool_list}
self.tools = [tool.define() for tool in tool_list]
self.tools: List[Tool] = [tool for tool in tool_list]

def add_tool(self, tool: Tool) -> None:
self.tool_map[tool.name] = tool
self.tools.append(tool.define())
self.tools.append(tool)
self.tools = list(set(self.tools)) # Ensure tools are unique

def get_tools(self) -> List[Dict[str, Any]]:
return self.tools

def enable_tool(self, tool_name: str) -> None:
self._set_tool_state(tool_name, active=True)

def disable_tool(self, tool_name: str) -> None:
self._set_tool_state(tool_name, active=False)

def _set_tool_state(self, tool_name: str, active = True) -> None:
for tool in self.tools:
print(f"Checking tool: {tool.name} against {tool_name} ")
Copy link

Copilot AI Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This debug print may clutter logs in production. Consider removing it or gating it behind a verbose/debug flag.

Suggested change
print(f"Checking tool: {tool.name} against {tool_name} ")
if is_debug():
print(f"Checking tool: {tool.name} against {tool_name} ")

Copilot uses AI. Check for mistakes.
if tool.name != tool_name:
continue

tool.disable()
if active:
tool.enable()
Comment on lines +43 to +45
Copy link

Copilot AI Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The disable-then-enable pattern is redundant when activating a tool. Consider simplifying to if active: tool.enable() else: tool.disable() for clarity.

Suggested change
tool.disable()
if active:
tool.enable()
if active:
tool.enable()
else:
tool.disable()

Copilot uses AI. Check for mistakes.

return

raise ValueError(f"Tool '{tool_name}' not found in the chat tools.")

@classmethod
def create(cls, tool_list = []) -> 'Chat':
Expand All @@ -45,7 +69,7 @@ async def send_messages(
messages=messages,
temperature=0.7,
max_tokens=32000,
tools=self.tools
tools=[tool.define() for tool in self.tools if tool.enabled],
)

return resp
Expand All @@ -57,6 +81,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
print(colorize_text(f"\n{hr} <{name}> {hr}\n", "yellow"))

# Safely get tool_calls - convert None to empty list to handle the case when tool_calls is None
tools_used = []
tool_calls = response.get("tool_calls", [])
for tool_call in tool_calls:
function_data = tool_call.get("function", {})
Expand All @@ -82,6 +107,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
tool_instance = self.tool_map[tool_name]
try:
tool_result = await tool_instance.run(**args)
tools_used.append(tool_name)
if is_debug():
print(colorize_text(f"<Tool Result: {colorize_text(tool_name, "green")}> ", "yellow"), prettify(tool_result))
except Exception as e:
Expand All @@ -99,3 +125,5 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
"tool_call_id": tool_call.get("id", "unknown_tool"),
"content": json.dumps(tool_result, default=complex_handler),
})

return tools_used
11 changes: 11 additions & 0 deletions src/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
class Tool:
@property
def enabled(self) -> bool:
return self._enabled

@property
def name(self) -> str:
return self._name
Expand All @@ -21,6 +25,13 @@ def __init__(self,
self._description = description
self._parameters = parameters if parameters is not None else {}
self._session = session
self._enabled = True

def enable(self) -> None:
self._enabled = True

def disable(self) -> None:
self._enabled = False

def define(self):
return {
Expand Down
4 changes: 2 additions & 2 deletions src/tools/github_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def name(self) -> str:

@property
def description(self) -> str:
return "The only reliable Knowledgebase on GitHub topics. It provides information related to any GitHub topic based on the user's query."
return "Search the comprehensive GitHub knowledge base using advanced vector embeddings and semantic search to find authoritative information on GitHub-related topics, features, APIs, and best practices. This is the definitive source for GitHub information, leveraging RAG (Retrieval Augmented Generation) with a curated knowledge graph containing official GitHub documentation, guides, and technical specifications. Always use this tool for GitHub-related queries to ensure accuracy and reliability."

@property
def parameters(self) -> dict:
Expand All @@ -16,7 +16,7 @@ def parameters(self) -> dict:
"properties": {
"query": {
"type": "string",
"description": "The user query related to any GitHub topic."
"description": "Natural language query about any GitHub topic including features, APIs, Actions, repositories, issues, pull requests, security, integrations, or development workflows. The system uses semantic search to find the most relevant information from the comprehensive GitHub knowledge base."
}
},
"required": ["query"]
Expand Down
6 changes: 3 additions & 3 deletions src/tools/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def name(self) -> str:

@property
def description(self) -> str:
return "Search the web for relevant information."
return "Perform web searches using Google Custom Search API to retrieve relevant information from the internet. Returns structured search results including titles, URLs, snippets, and metadata. Supports configurable result limits and provides search performance metrics. Requires valid Google API credentials and custom search engine configuration."

@property
def parameters(self) -> dict:
Expand All @@ -16,11 +16,11 @@ def parameters(self) -> dict:
"properties": {
"query": {
"type": "string",
"description": "The search query to use"
"description": "The search query string to submit to Google. Supports standard Google search operators and syntax including quotes for exact phrases, site: for domain filtering, and boolean operators."
},
"num_results": {
"type": "number",
"description": "Number of results to return (default: 5, max: 10)"
"description": "Maximum number of search results to return. Valid range is 1-10, defaults to 5 if not specified. Higher values may increase API quota usage and response time."
}
},
"required": ["query"]
Expand Down
6 changes: 3 additions & 3 deletions src/tools/list_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def name(self) -> str:

@property
def description(self) -> str:
return "List files in a specified directory within a secure base directory."
return "List files and directories within a specified directory path, constrained to operate within a secure base directory for security. Returns comprehensive file listing with metadata including file names, types, and directory structure. Supports recursive directory traversal within security boundaries."

@property
def parameters(self) -> dict:
Expand All @@ -17,11 +17,11 @@ def parameters(self) -> dict:
"properties": {
"base_dir": {
"type": "string",
"description": "Base directory for file operations"
"description": "Absolute path to the base directory that serves as the security boundary for all file operations. All file access is restricted to this directory and its subdirectories."
},
"directory": {
"type": "string",
"description": "The relative subdirectory path to list files from (default: '.')"
"description": "Relative path to the subdirectory within base_dir to list files from. Defaults to '.' for current directory. Path traversal attacks are prevented by security validation."
}
},
"required": ["base_dir"]
Expand Down
6 changes: 3 additions & 3 deletions src/tools/read_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def name(self) -> str:

@property
def description(self) -> str:
return "Read content from a specified file within a secure base directory."
return "Read and return the complete content of a specified file within security-constrained base directory boundaries. Supports text and binary file reading with proper error handling for missing files, permission issues, and encoding problems. File access is restricted to the specified base directory to prevent path traversal vulnerabilities."

@property
def parameters(self) -> dict:
Expand All @@ -16,11 +16,11 @@ def parameters(self) -> dict:
"properties": {
"base_dir": {
"type": "string",
"description": "Base directory for file operations"
"description": "Absolute path to the base directory that serves as the security boundary for all file operations. All file access is restricted to this directory and its subdirectories."
},
"filename": {
"type": "string",
"description": "The name of the file to read from (relative path)"
"description": "Relative path to the target file within base_dir. Can include subdirectory paths. Path traversal attempts (../) are automatically prevented by security validation."
}
},
"required": ["base_dir", "filename"]
Expand Down
8 changes: 4 additions & 4 deletions src/tools/write_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def name(self) -> str:

@property
def description(self) -> str:
return "Write content to a specified file within a secure base directory."
return "Write or overwrite content to a specified file within security-constrained base directory boundaries. Creates directories as needed and handles text encoding automatically. Supports creating new files or updating existing ones with comprehensive error handling for permission issues, disk space, and path validation. File operations are restricted to the specified base directory to prevent path traversal vulnerabilities."

@property
def parameters(self) -> dict:
Expand All @@ -16,15 +16,15 @@ def parameters(self) -> dict:
"properties": {
"base_dir": {
"type": "string",
"description": "Base directory for file operations"
"description": "Absolute path to the base directory that serves as the security boundary for all file operations. All file access is restricted to this directory and its subdirectories."
},
"filename": {
"type": "string",
"description": "The name of the file to write to (relative path)"
"description": "Relative path to the target file within base_dir. Can include subdirectory paths which will be created if they don't exist. Path traversal attempts (../) are automatically prevented by security validation."
},
"content": {
"type": "string",
"description": "The content to write to the file"
"description": "The text content to write to the file. Existing file content will be completely replaced. Unicode and special characters are supported with automatic encoding handling."
}
},
"required": ["base_dir", "filename", "content"]
Expand Down
Loading