Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
798 changes: 704 additions & 94 deletions README.md

Large diffs are not rendered by default.

Binary file added docs/chat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 4 additions & 27 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
load_dotenv(override=True)

import asyncio
from functools import wraps

from datetime import date

Expand All @@ -18,24 +17,6 @@

_agent_sessions = {}

def persist_session(func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(self, *args, **kwargs):
result = await func(self, *args, **kwargs)
if hasattr(self, 'session_id'):
_agent_sessions[self.session_id] = self
return result
return async_wrapper

@wraps(func)
def sync_wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if hasattr(self, 'session_id'):
_agent_sessions[self.session_id] = self
return result
return sync_wrapper

class Agent:
def __init__(self, session_id: str = "cli-session"):
self.tools = {
Expand All @@ -45,7 +26,7 @@ def __init__(self, session_id: str = "cli-session"):
WriteFile(),
ListFiles(),
}
self.chat = Chat.create(self.tools)
self.chat = Chat.create(self.tools, session_id)

# Enable all tools by default
for tool in self.chat.tools:
Expand All @@ -56,14 +37,12 @@ def __init__(self, session_id: str = "cli-session"):
self.history = []
self._update_system_prompt()

@persist_session
def _get_available_tools_text(self) -> str:
enabled_tools = [tool.name for tool in self.chat.tools if tool.enabled]
if not enabled_tools:
return "No tools are currently available."
return f"Currently available tools: {', '.join(enabled_tools)}"

@persist_session
def _update_system_prompt(self) -> None:
available_tools_text = self._get_available_tools_text()

Expand Down Expand Up @@ -114,12 +93,10 @@ def _update_system_prompt(self) -> None:
else:
self.history.insert(0, {"role": "system", "content": system_role})

@persist_session
def add_tool(self, tool: Tool) -> None:
self.chat.add_tool(tool)
self._update_system_prompt()

@persist_session
def enable_tool(self, tool_name: str) -> bool:
try:
self.chat.enable_tool(tool_name)
Expand All @@ -128,7 +105,6 @@ def enable_tool(self, tool_name: str) -> bool:
return False
return True

@persist_session
def disable_tool(self, tool_name: str) -> bool:
try:
self.chat.disable_tool(tool_name)
Expand All @@ -137,11 +113,9 @@ def disable_tool(self, tool_name: str) -> bool:
return False
return True

@persist_session
def get_tools(self) -> list:
return self.chat.get_tools()

@persist_session
async def process_query(self, user_prompt: str) -> str:
user_role = {"role": "user", "content": user_prompt}

Expand Down Expand Up @@ -189,8 +163,11 @@ def get_agent_instance(session_id: str = None) -> Agent:


def delete_agent_instance(session_id: str) -> bool:
from core.debug_capture import delete_debug_capture_instance

if session_id in _agent_sessions:
del _agent_sessions[session_id]
delete_debug_capture_instance(session_id)
return True
return False

Expand Down
25 changes: 16 additions & 9 deletions src/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
QueryRequest, QueryResponse, ToolsListResponse, ToolToggleRequest,
ToolToggleResponse, ToolInfo, DebugResponse, DebugRequest, NewSessionResponse
)
from core.debug_capture import debug_capture
from core.debug_capture import get_debug_capture_instance, get_all_debug_events, clear_all_debug_events, delete_debug_capture_instance
from core.mcp.sessions_manager import MCPSessionManager

session_router = APIRouter(prefix="/api")
Expand All @@ -18,6 +18,8 @@
async def new_session():
session_id = str(uuid.uuid4())
agent_instance = get_agent_instance(session_id)

get_debug_capture_instance(session_id)

try:
config_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config', 'mcp.json')
Expand All @@ -33,6 +35,8 @@ async def new_session():
@session_router.delete("/session/{session_id}")
async def delete_session(session_id: str):
if delete_agent_instance(session_id):
# Also clean up the debug capture instance for this session
delete_debug_capture_instance(session_id)
return Response(status_code=204)
else:
raise HTTPException(status_code=404, detail="Session not found")
Expand All @@ -41,7 +45,7 @@ async def delete_session(session_id: str):
@api_router.post("/ask", response_model=QueryResponse)
async def ask_agent(session_id: str, request: QueryRequest, agent_instance: Agent = Depends(get_agent_instance)) -> QueryResponse:
try:
debug_capture.set_session_id(session_id)
# Debug capture is now per-session, no need to set session_id
response, used_tools = await agent_instance.process_query(request.query)
return QueryResponse(
response=response,
Expand Down Expand Up @@ -99,7 +103,7 @@ async def toggle_tool(request: ToolToggleRequest, agent_instance: Agent = Depend
@api_router.get("/debug", response_model=DebugResponse)
async def get_debug_info(session_id: str) -> DebugResponse:
try:
events = debug_capture.get_events(session_id)
events = get_all_debug_events(session_id)
debug_events = [
{
"event_type": event["event_type"],
Expand All @@ -110,28 +114,31 @@ async def get_debug_info(session_id: str) -> DebugResponse:
}
for event in events
]
return DebugResponse(events=debug_events, enabled=debug_capture.is_enabled())
# For checking if enabled, use the specific session
capture = get_debug_capture_instance(session_id)
return DebugResponse(events=debug_events, enabled=capture.is_enabled())
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error retrieving debug info: {str(e)}")


@api_router.post("/debug/toggle", response_model=DebugResponse)
async def toggle_debug(request: DebugRequest) -> DebugResponse:
async def toggle_debug(session_id: str, request: DebugRequest) -> DebugResponse:
try:
capture = get_debug_capture_instance(session_id)
if request.enabled:
debug_capture.enable()
capture.enable()
else:
debug_capture.disable()
capture.disable()

return DebugResponse(events=[], enabled=debug_capture.is_enabled())
return DebugResponse(events=[], enabled=capture.is_enabled())
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling debug: {str(e)}")


@api_router.delete("/debug")
async def clear_debug_events(session_id: str) -> Response:
try:
debug_capture.clear_events(session_id)
clear_all_debug_events(session_id)
return Response(status_code=204)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error clearing debug events: {str(e)}")
6 changes: 3 additions & 3 deletions src/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def is_debug(self) -> bool:
set_debug = debugger.set_debug

# Import debug_capture after the debugger is set up to avoid circular imports
def get_debug_capture():
def get_debug_capture(session_id: str = "default"):
try:
from core.debug_capture import debug_capture
return debug_capture
from core.debug_capture import get_debug_capture_instance
return get_debug_capture_instance(session_id)
except ImportError:
return None

Expand Down
87 changes: 54 additions & 33 deletions src/core/debug_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ def __init__(
self,
event_type: DebugEventType,
message: str,
session_id: str,
data: Optional[Dict[str, Any]] = None,
timestamp: Optional[datetime] = None
):
self.event_type = event_type
self.message = message
self.data = data or {}
self.timestamp = timestamp or datetime.now()
self.session_id = DebugCapture.get_current_session_id()
self.session_id = session_id

def to_dict(self) -> Dict[str, Any]:
try:
Expand All @@ -129,19 +130,13 @@ def to_dict(self) -> Dict[str, Any]:
}

class DebugCapture:
_instance = None
_lock = threading.Lock()
_session_id = None

def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super(DebugCapture, cls).__new__(cls)
cls._instance._events = []
cls._instance._max_events = 1000
cls._instance._enabled = False
return cls._instance

def __init__(self, session_id: str):
self.session_id = session_id
self._events = []
self._max_events = 1000
self._enabled = False

def enable(self):
self._enabled = True
Expand All @@ -152,12 +147,8 @@ def disable(self):
def is_enabled(self) -> bool:
return self._enabled

def set_session_id(self, session_id: str):
DebugCapture._session_id = session_id

@classmethod
def get_current_session_id(cls) -> Optional[str]:
return cls._session_id
def get_current_session_id(self) -> str:
return self.session_id

def capture_event(
self,
Expand All @@ -168,30 +159,22 @@ def capture_event(
if not self._enabled:
return

# Safely serialize the data to prevent encoding issues
safe_data = safe_serialize(data) if data else {}

event = DebugEvent(event_type, message, safe_data)
event = DebugEvent(event_type, message, self.session_id, safe_data)

with self._lock:
self._events.append(event)
# Keep only the last max_events
if len(self._events) > self._max_events:
self._events = self._events[-self._max_events:]

def get_events(self, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
def get_events(self) -> List[Dict[str, Any]]:
with self._lock:
events = self._events
if session_id:
events = [e for e in events if e.session_id == session_id]
return [event.to_dict() for event in events]
return [event.to_dict() for event in self._events]

def clear_events(self, session_id: Optional[str] = None):
def clear_events(self):
with self._lock:
if session_id:
self._events = [e for e in self._events if e.session_id != session_id]
else:
self._events = []
self._events = []

def capture_llm_request(self, payload: Dict[str, Any]):
self.capture_event(
Expand Down Expand Up @@ -242,4 +225,42 @@ def capture_mcp_result(self, tool_name: str, result: Any):
{"tool_name": tool_name, "result": result}
)

debug_capture = DebugCapture()

# Global session management
_debug_sessions = {}

def get_debug_capture_instance(session_id: str) -> DebugCapture:
if not session_id:
raise ValueError("Session ID must be provided to get debug capture instance.")

if session_id not in _debug_sessions:
_debug_sessions[session_id] = DebugCapture(session_id)

return _debug_sessions[session_id]

def delete_debug_capture_instance(session_id: str) -> bool:
if session_id in _debug_sessions:
del _debug_sessions[session_id]
return True
return False

def get_all_debug_events(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
if session_id:
if session_id in _debug_sessions:
return _debug_sessions[session_id].get_events()
return []

all_events = []
for capture in _debug_sessions.values():
all_events.extend(capture.get_events())

all_events.sort(key=lambda x: x['timestamp'])
return all_events

def clear_all_debug_events(session_id: Optional[str] = None):
if session_id:
if session_id in _debug_sessions:
_debug_sessions[session_id].clear_events()
else:
for capture in _debug_sessions.values():
capture.clear_events()
15 changes: 8 additions & 7 deletions src/core/llm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

from core import is_debug
class Chat:
def __init__(self, tool_list: List[Tool] = []):
self.chat_client: ChatClient = ChatClient()
def __init__(self, tool_list: List[Tool] = [], session_id: str = "default"):
self.chat_client: ChatClient = ChatClient(session_id=session_id)
self.tool_map = {tool.name: tool for tool in tool_list}
self.tools: List[Tool] = [tool for tool in tool_list]
self.session_id = session_id

def add_tool(self, tool: Tool) -> None:
self.tool_map[tool.name] = tool
Expand Down Expand Up @@ -48,7 +49,7 @@ def _set_tool_state(self, tool_name: str, active = True) -> None:
raise ValueError(f"Tool '{tool_name}' not found in the chat tools.")

@classmethod
def create(cls, tool_list = []) -> 'Chat':
def create(cls, tool_list = [], session_id: str = "default") -> 'Chat':
api_key = os.environ.get(DEFAULT_API_KEY_ENV)
if not api_key:
raise ValueError(f"{DEFAULT_API_KEY_ENV} environment variable is required")
Expand All @@ -58,7 +59,7 @@ def create(cls, tool_list = []) -> 'Chat':
print(colorize_text(f"<Tool Initialized: {colorize_text(tool.name, "yellow")}>", "cyan"))
print("\n")

return cls(tool_list)
return cls(tool_list, session_id)

async def send_messages(
self,
Expand Down Expand Up @@ -98,7 +99,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
except json.JSONDecodeError:
args = {}

debug_capture = get_debug_capture()
debug_capture = get_debug_capture(self.session_id)
if debug_capture:
debug_capture.capture_tool_call(tool_name, args)

Expand All @@ -113,7 +114,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
tools_used.append(tool_name)
if is_debug():
print(colorize_text(f"<Tool Result: {colorize_text(tool_name, "green")}> ", "yellow"), prettify(tool_result))
debug_capture = get_debug_capture()
debug_capture = get_debug_capture(self.session_id)
if debug_capture:
debug_capture.capture_tool_result(tool_name, tool_result)
except Exception as e:
Expand All @@ -122,7 +123,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
}
if is_debug():
print(colorize_text(f"<Tool Exception: {colorize_text(tool_name, "red")}> ", "yellow"), str(e))
debug_capture = get_debug_capture()
debug_capture = get_debug_capture(self.session_id)
if debug_capture:
debug_capture.capture_tool_error(tool_name, str(e))

Expand Down
Loading