Skip to content

Commit 4027bd0

Browse files
authored
Merge pull request #23 from nullchimp/session_management
Session management
2 parents ae5994d + c878142 commit 4027bd0

28 files changed

+3302
-489
lines changed

README.md

Lines changed: 704 additions & 94 deletions
Large diffs are not rendered by default.

docs/chat.png

881 KB
Loading

src/agent.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
load_dotenv(override=True)
33

44
import asyncio
5-
from functools import wraps
65

76
from datetime import date
87

@@ -18,24 +17,6 @@
1817

1918
_agent_sessions = {}
2019

21-
def persist_session(func):
22-
if asyncio.iscoroutinefunction(func):
23-
@wraps(func)
24-
async def async_wrapper(self, *args, **kwargs):
25-
result = await func(self, *args, **kwargs)
26-
if hasattr(self, 'session_id'):
27-
_agent_sessions[self.session_id] = self
28-
return result
29-
return async_wrapper
30-
31-
@wraps(func)
32-
def sync_wrapper(self, *args, **kwargs):
33-
result = func(self, *args, **kwargs)
34-
if hasattr(self, 'session_id'):
35-
_agent_sessions[self.session_id] = self
36-
return result
37-
return sync_wrapper
38-
3920
class Agent:
4021
def __init__(self, session_id: str = "cli-session"):
4122
self.tools = {
@@ -45,7 +26,7 @@ def __init__(self, session_id: str = "cli-session"):
4526
WriteFile(),
4627
ListFiles(),
4728
}
48-
self.chat = Chat.create(self.tools)
29+
self.chat = Chat.create(self.tools, session_id)
4930

5031
# Enable all tools by default
5132
for tool in self.chat.tools:
@@ -56,14 +37,12 @@ def __init__(self, session_id: str = "cli-session"):
5637
self.history = []
5738
self._update_system_prompt()
5839

59-
@persist_session
6040
def _get_available_tools_text(self) -> str:
6141
enabled_tools = [tool.name for tool in self.chat.tools if tool.enabled]
6242
if not enabled_tools:
6343
return "No tools are currently available."
6444
return f"Currently available tools: {', '.join(enabled_tools)}"
6545

66-
@persist_session
6746
def _update_system_prompt(self) -> None:
6847
available_tools_text = self._get_available_tools_text()
6948

@@ -114,12 +93,10 @@ def _update_system_prompt(self) -> None:
11493
else:
11594
self.history.insert(0, {"role": "system", "content": system_role})
11695

117-
@persist_session
11896
def add_tool(self, tool: Tool) -> None:
11997
self.chat.add_tool(tool)
12098
self._update_system_prompt()
12199

122-
@persist_session
123100
def enable_tool(self, tool_name: str) -> bool:
124101
try:
125102
self.chat.enable_tool(tool_name)
@@ -128,7 +105,6 @@ def enable_tool(self, tool_name: str) -> bool:
128105
return False
129106
return True
130107

131-
@persist_session
132108
def disable_tool(self, tool_name: str) -> bool:
133109
try:
134110
self.chat.disable_tool(tool_name)
@@ -137,11 +113,9 @@ def disable_tool(self, tool_name: str) -> bool:
137113
return False
138114
return True
139115

140-
@persist_session
141116
def get_tools(self) -> list:
142117
return self.chat.get_tools()
143118

144-
@persist_session
145119
async def process_query(self, user_prompt: str) -> str:
146120
user_role = {"role": "user", "content": user_prompt}
147121

@@ -189,8 +163,11 @@ def get_agent_instance(session_id: str = None) -> Agent:
189163

190164

191165
def delete_agent_instance(session_id: str) -> bool:
166+
from core.debug_capture import delete_debug_capture_instance
167+
192168
if session_id in _agent_sessions:
193169
del _agent_sessions[session_id]
170+
delete_debug_capture_instance(session_id)
194171
return True
195172
return False
196173

src/api/routes.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
QueryRequest, QueryResponse, ToolsListResponse, ToolToggleRequest,
99
ToolToggleResponse, ToolInfo, DebugResponse, DebugRequest, NewSessionResponse
1010
)
11-
from core.debug_capture import debug_capture
11+
from core.debug_capture import get_debug_capture_instance, get_all_debug_events, clear_all_debug_events, delete_debug_capture_instance
1212
from core.mcp.sessions_manager import MCPSessionManager
1313

1414
session_router = APIRouter(prefix="/api")
@@ -18,6 +18,8 @@
1818
async def new_session():
1919
session_id = str(uuid.uuid4())
2020
agent_instance = get_agent_instance(session_id)
21+
22+
get_debug_capture_instance(session_id)
2123

2224
try:
2325
config_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config', 'mcp.json')
@@ -33,6 +35,8 @@ async def new_session():
3335
@session_router.delete("/session/{session_id}")
3436
async def delete_session(session_id: str):
3537
if delete_agent_instance(session_id):
38+
# Also clean up the debug capture instance for this session
39+
delete_debug_capture_instance(session_id)
3640
return Response(status_code=204)
3741
else:
3842
raise HTTPException(status_code=404, detail="Session not found")
@@ -41,7 +45,7 @@ async def delete_session(session_id: str):
4145
@api_router.post("/ask", response_model=QueryResponse)
4246
async def ask_agent(session_id: str, request: QueryRequest, agent_instance: Agent = Depends(get_agent_instance)) -> QueryResponse:
4347
try:
44-
debug_capture.set_session_id(session_id)
48+
# Debug capture is now per-session, no need to set session_id
4549
response, used_tools = await agent_instance.process_query(request.query)
4650
return QueryResponse(
4751
response=response,
@@ -99,7 +103,7 @@ async def toggle_tool(request: ToolToggleRequest, agent_instance: Agent = Depend
99103
@api_router.get("/debug", response_model=DebugResponse)
100104
async def get_debug_info(session_id: str) -> DebugResponse:
101105
try:
102-
events = debug_capture.get_events(session_id)
106+
events = get_all_debug_events(session_id)
103107
debug_events = [
104108
{
105109
"event_type": event["event_type"],
@@ -110,28 +114,31 @@ async def get_debug_info(session_id: str) -> DebugResponse:
110114
}
111115
for event in events
112116
]
113-
return DebugResponse(events=debug_events, enabled=debug_capture.is_enabled())
117+
# For checking if enabled, use the specific session
118+
capture = get_debug_capture_instance(session_id)
119+
return DebugResponse(events=debug_events, enabled=capture.is_enabled())
114120
except Exception as e:
115121
raise HTTPException(status_code=500, detail=f"Error retrieving debug info: {str(e)}")
116122

117123

118124
@api_router.post("/debug/toggle", response_model=DebugResponse)
119-
async def toggle_debug(request: DebugRequest) -> DebugResponse:
125+
async def toggle_debug(session_id: str, request: DebugRequest) -> DebugResponse:
120126
try:
127+
capture = get_debug_capture_instance(session_id)
121128
if request.enabled:
122-
debug_capture.enable()
129+
capture.enable()
123130
else:
124-
debug_capture.disable()
131+
capture.disable()
125132

126-
return DebugResponse(events=[], enabled=debug_capture.is_enabled())
133+
return DebugResponse(events=[], enabled=capture.is_enabled())
127134
except Exception as e:
128135
raise HTTPException(status_code=500, detail=f"Error toggling debug: {str(e)}")
129136

130137

131138
@api_router.delete("/debug")
132139
async def clear_debug_events(session_id: str) -> Response:
133140
try:
134-
debug_capture.clear_events(session_id)
141+
clear_all_debug_events(session_id)
135142
return Response(status_code=204)
136143
except Exception as e:
137144
raise HTTPException(status_code=500, detail=f"Error clearing debug events: {str(e)}")

src/core/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def is_debug(self) -> bool:
2323
set_debug = debugger.set_debug
2424

2525
# Import debug_capture after the debugger is set up to avoid circular imports
26-
def get_debug_capture():
26+
def get_debug_capture(session_id: str = "default"):
2727
try:
28-
from core.debug_capture import debug_capture
29-
return debug_capture
28+
from core.debug_capture import get_debug_capture_instance
29+
return get_debug_capture_instance(session_id)
3030
except ImportError:
3131
return None
3232

src/core/debug_capture.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,15 @@ def __init__(
100100
self,
101101
event_type: DebugEventType,
102102
message: str,
103+
session_id: str,
103104
data: Optional[Dict[str, Any]] = None,
104105
timestamp: Optional[datetime] = None
105106
):
106107
self.event_type = event_type
107108
self.message = message
108109
self.data = data or {}
109110
self.timestamp = timestamp or datetime.now()
110-
self.session_id = DebugCapture.get_current_session_id()
111+
self.session_id = session_id
111112

112113
def to_dict(self) -> Dict[str, Any]:
113114
try:
@@ -129,19 +130,13 @@ def to_dict(self) -> Dict[str, Any]:
129130
}
130131

131132
class DebugCapture:
132-
_instance = None
133133
_lock = threading.Lock()
134-
_session_id = None
135-
136-
def __new__(cls):
137-
if cls._instance is None:
138-
with cls._lock:
139-
if cls._instance is None:
140-
cls._instance = super(DebugCapture, cls).__new__(cls)
141-
cls._instance._events = []
142-
cls._instance._max_events = 1000
143-
cls._instance._enabled = False
144-
return cls._instance
134+
135+
def __init__(self, session_id: str):
136+
self.session_id = session_id
137+
self._events = []
138+
self._max_events = 1000
139+
self._enabled = False
145140

146141
def enable(self):
147142
self._enabled = True
@@ -152,12 +147,8 @@ def disable(self):
152147
def is_enabled(self) -> bool:
153148
return self._enabled
154149

155-
def set_session_id(self, session_id: str):
156-
DebugCapture._session_id = session_id
157-
158-
@classmethod
159-
def get_current_session_id(cls) -> Optional[str]:
160-
return cls._session_id
150+
def get_current_session_id(self) -> str:
151+
return self.session_id
161152

162153
def capture_event(
163154
self,
@@ -168,30 +159,22 @@ def capture_event(
168159
if not self._enabled:
169160
return
170161

171-
# Safely serialize the data to prevent encoding issues
172162
safe_data = safe_serialize(data) if data else {}
173163

174-
event = DebugEvent(event_type, message, safe_data)
164+
event = DebugEvent(event_type, message, self.session_id, safe_data)
175165

176166
with self._lock:
177167
self._events.append(event)
178-
# Keep only the last max_events
179168
if len(self._events) > self._max_events:
180169
self._events = self._events[-self._max_events:]
181170

182-
def get_events(self, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
171+
def get_events(self) -> List[Dict[str, Any]]:
183172
with self._lock:
184-
events = self._events
185-
if session_id:
186-
events = [e for e in events if e.session_id == session_id]
187-
return [event.to_dict() for event in events]
173+
return [event.to_dict() for event in self._events]
188174

189-
def clear_events(self, session_id: Optional[str] = None):
175+
def clear_events(self):
190176
with self._lock:
191-
if session_id:
192-
self._events = [e for e in self._events if e.session_id != session_id]
193-
else:
194-
self._events = []
177+
self._events = []
195178

196179
def capture_llm_request(self, payload: Dict[str, Any]):
197180
self.capture_event(
@@ -242,4 +225,42 @@ def capture_mcp_result(self, tool_name: str, result: Any):
242225
{"tool_name": tool_name, "result": result}
243226
)
244227

245-
debug_capture = DebugCapture()
228+
229+
# Global session management
230+
_debug_sessions = {}
231+
232+
def get_debug_capture_instance(session_id: str) -> DebugCapture:
233+
if not session_id:
234+
raise ValueError("Session ID must be provided to get debug capture instance.")
235+
236+
if session_id not in _debug_sessions:
237+
_debug_sessions[session_id] = DebugCapture(session_id)
238+
239+
return _debug_sessions[session_id]
240+
241+
def delete_debug_capture_instance(session_id: str) -> bool:
242+
if session_id in _debug_sessions:
243+
del _debug_sessions[session_id]
244+
return True
245+
return False
246+
247+
def get_all_debug_events(session_id: Optional[str] = None) -> List[Dict[str, Any]]:
248+
if session_id:
249+
if session_id in _debug_sessions:
250+
return _debug_sessions[session_id].get_events()
251+
return []
252+
253+
all_events = []
254+
for capture in _debug_sessions.values():
255+
all_events.extend(capture.get_events())
256+
257+
all_events.sort(key=lambda x: x['timestamp'])
258+
return all_events
259+
260+
def clear_all_debug_events(session_id: Optional[str] = None):
261+
if session_id:
262+
if session_id in _debug_sessions:
263+
_debug_sessions[session_id].clear_events()
264+
else:
265+
for capture in _debug_sessions.values():
266+
capture.clear_events()

src/core/llm/chat.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
from core import is_debug
1717
class Chat:
18-
def __init__(self, tool_list: List[Tool] = []):
19-
self.chat_client: ChatClient = ChatClient()
18+
def __init__(self, tool_list: List[Tool] = [], session_id: str = "default"):
19+
self.chat_client: ChatClient = ChatClient(session_id=session_id)
2020
self.tool_map = {tool.name: tool for tool in tool_list}
2121
self.tools: List[Tool] = [tool for tool in tool_list]
22+
self.session_id = session_id
2223

2324
def add_tool(self, tool: Tool) -> None:
2425
self.tool_map[tool.name] = tool
@@ -48,7 +49,7 @@ def _set_tool_state(self, tool_name: str, active = True) -> None:
4849
raise ValueError(f"Tool '{tool_name}' not found in the chat tools.")
4950

5051
@classmethod
51-
def create(cls, tool_list = []) -> 'Chat':
52+
def create(cls, tool_list = [], session_id: str = "default") -> 'Chat':
5253
api_key = os.environ.get(DEFAULT_API_KEY_ENV)
5354
if not api_key:
5455
raise ValueError(f"{DEFAULT_API_KEY_ENV} environment variable is required")
@@ -58,7 +59,7 @@ def create(cls, tool_list = []) -> 'Chat':
5859
print(colorize_text(f"<Tool Initialized: {colorize_text(tool.name, "yellow")}>", "cyan"))
5960
print("\n")
6061

61-
return cls(tool_list)
62+
return cls(tool_list, session_id)
6263

6364
async def send_messages(
6465
self,
@@ -98,7 +99,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
9899
except json.JSONDecodeError:
99100
args = {}
100101

101-
debug_capture = get_debug_capture()
102+
debug_capture = get_debug_capture(self.session_id)
102103
if debug_capture:
103104
debug_capture.capture_tool_call(tool_name, args)
104105

@@ -113,7 +114,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
113114
tools_used.append(tool_name)
114115
if is_debug():
115116
print(colorize_text(f"<Tool Result: {colorize_text(tool_name, "green")}> ", "yellow"), prettify(tool_result))
116-
debug_capture = get_debug_capture()
117+
debug_capture = get_debug_capture(self.session_id)
117118
if debug_capture:
118119
debug_capture.capture_tool_result(tool_name, tool_result)
119120
except Exception as e:
@@ -122,7 +123,7 @@ async def process_tool_calls(self, response: Dict[str, Any], call_back) -> None:
122123
}
123124
if is_debug():
124125
print(colorize_text(f"<Tool Exception: {colorize_text(tool_name, "red")}> ", "yellow"), str(e))
125-
debug_capture = get_debug_capture()
126+
debug_capture = get_debug_capture(self.session_id)
126127
if debug_capture:
127128
debug_capture.capture_tool_error(tool_name, str(e))
128129

0 commit comments

Comments
 (0)