Skip to content

Commit c5055cb

Browse files
committed
πŸ› Use conversation_id and user_id as the primary key of the thread pool
1 parent 9a0072c commit c5055cb

File tree

7 files changed

+74
-40
lines changed

7 files changed

+74
-40
lines changed

β€Žbackend/agents/agent_run_manager.pyβ€Ž

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,46 @@ def __new__(cls):
2121

2222
def __init__(self):
2323
if not self._initialized:
24-
# conversation_id -> agent_run_info
25-
self.agent_runs: Dict[int, AgentRunInfo] = {}
24+
# user_id:conversation_id -> agent_run_info
25+
self.agent_runs: Dict[str, AgentRunInfo] = {}
2626
self._initialized = True
2727

28-
def register_agent_run(self, conversation_id: int, agent_run_info):
28+
def _get_run_key(self, conversation_id: int, user_id: str) -> str:
29+
"""Generate unique key for agent run using user_id and conversation_id"""
30+
return f"{user_id}:{conversation_id}"
31+
32+
def register_agent_run(self, conversation_id: int, agent_run_info, user_id: str):
2933
"""register agent run instance"""
3034
with self._lock:
31-
self.agent_runs[conversation_id] = agent_run_info
35+
run_key = self._get_run_key(conversation_id, user_id)
36+
self.agent_runs[run_key] = agent_run_info
3237
logger.info(
33-
f"register agent run instance, conversation_id: {conversation_id}")
38+
f"register agent run instance, user_id: {user_id}, conversation_id: {conversation_id}")
3439

35-
def unregister_agent_run(self, conversation_id: int):
40+
def unregister_agent_run(self, conversation_id: int, user_id: str):
3641
"""unregister agent run instance"""
3742
with self._lock:
38-
if conversation_id in self.agent_runs:
39-
del self.agent_runs[conversation_id]
43+
run_key = self._get_run_key(conversation_id, user_id)
44+
if run_key in self.agent_runs:
45+
del self.agent_runs[run_key]
4046
logger.info(
41-
f"unregister agent run instance, conversation_id: {conversation_id}")
47+
f"unregister agent run instance, user_id: {user_id}, conversation_id: {conversation_id}")
4248
else:
4349
logger.info(
44-
f"no agent run instance found for conversation_id: {conversation_id}")
50+
f"no agent run instance found for user_id: {user_id}, conversation_id: {conversation_id}")
4551

46-
def get_agent_run_info(self, conversation_id: int):
52+
def get_agent_run_info(self, conversation_id: int, user_id: str):
4753
"""get agent run instance"""
48-
return self.agent_runs.get(conversation_id)
54+
run_key = self._get_run_key(conversation_id, user_id)
55+
return self.agent_runs.get(run_key)
4956

50-
def stop_agent_run(self, conversation_id: int) -> bool:
51-
"""stop agent run for specified conversation_id"""
52-
agent_run_info = self.get_agent_run_info(conversation_id)
57+
def stop_agent_run(self, conversation_id: int, user_id: str) -> bool:
58+
"""stop agent run for specified conversation_id and user_id"""
59+
agent_run_info = self.get_agent_run_info(conversation_id, user_id)
5360
if agent_run_info is not None:
5461
agent_run_info.stop_event.set()
5562
logger.info(
56-
f"agent run stopped, conversation_id: {conversation_id}")
63+
f"agent run stopped, user_id: {user_id}, conversation_id: {conversation_id}")
5764
return True
5865
return False
5966

β€Žbackend/apps/agent_app.pyβ€Ž

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ async def agent_run_api(agent_request: AgentRequest, http_request: Request, auth
4545

4646

4747
@router.get("/stop/{conversation_id}")
48-
async def agent_stop_api(conversation_id: int):
48+
async def agent_stop_api(conversation_id: int, authorization: Optional[str] = Header(None)):
4949
"""
5050
stop agent run and preprocess tasks for specified conversation_id
5151
"""
52-
if stop_agent_tasks(conversation_id).get("status") == "success":
52+
user_id, _ = get_current_user_id(authorization)
53+
if stop_agent_tasks(conversation_id, user_id).get("status") == "success":
5354
return {"status": "success", "message": "agent run and preprocess tasks stopped successfully"}
5455
else:
5556
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,

β€Žbackend/services/agent_service.pyβ€Ž

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def _stream_agent_chunks(
125125
user_id=user_id,
126126
)
127127
# Always unregister the run to release resources
128-
agent_run_manager.unregister_agent_run(agent_request.conversation_id)
128+
agent_run_manager.unregister_agent_run(agent_request.conversation_id, user_id)
129129

130130
# Schedule memory addition in background to avoid blocking SSE termination
131131
async def _add_memory_background():
@@ -681,7 +681,7 @@ async def prepare_agent_run(
681681
allow_memory_search=allow_memory_search,
682682
)
683683
agent_run_manager.register_agent_run(
684-
agent_request.conversation_id, agent_run_info)
684+
agent_request.conversation_id, agent_run_info, user_id)
685685
return agent_run_info, memory_context
686686

687687

@@ -881,13 +881,13 @@ async def run_agent_stream(
881881
)
882882

883883

884-
def stop_agent_tasks(conversation_id: int):
884+
def stop_agent_tasks(conversation_id: int, user_id: str):
885885
"""
886886
Stop agent run and preprocess tasks for the specified conversation_id.
887887
Matches the behavior of agent_app.agent_stop_api.
888888
"""
889889
# Stop agent run
890-
agent_stopped = agent_run_manager.stop_agent_run(conversation_id)
890+
agent_stopped = agent_run_manager.stop_agent_run(conversation_id, user_id)
891891

892892
# Stop preprocess tasks
893893
preprocess_stopped = preprocess_manager.stop_preprocess_tasks(
@@ -900,11 +900,11 @@ def stop_agent_tasks(conversation_id: int):
900900
if preprocess_stopped:
901901
message_parts.append("preprocess tasks")
902902

903-
message = f"successfully stopped {' and '.join(message_parts)} for conversation_id {conversation_id}"
903+
message = f"successfully stopped {' and '.join(message_parts)} for user_id {user_id}, conversation_id {conversation_id}"
904904
logging.info(message)
905905
return {"status": "success", "message": message}
906906
else:
907-
message = f"no running agent or preprocess tasks found for conversation_id {conversation_id}"
907+
message = f"no running agent or preprocess tasks found for user_id {user_id}, conversation_id {conversation_id}"
908908
logging.error(message)
909909
return {"status": "error", "message": message}
910910

β€Žbackend/services/northbound_service.pyβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ async def stop_chat(ctx: NorthboundContext, external_conversation_id: str) -> Di
208208
try:
209209
internal_id = await to_internal_conversation_id(external_conversation_id)
210210

211-
stop_result = stop_agent_tasks(internal_id)
211+
stop_result = stop_agent_tasks(internal_id, ctx.user_id)
212212
return {"message": stop_result.get("message", "success"), "data": external_conversation_id, "requestId": ctx.request_id}
213213
except Exception as e:
214214
raise Exception(f"Failed to stop chat for external conversation id {external_conversation_id}: {str(e)}")

β€Žtest/backend/app/test_agent_app.pyβ€Ž

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,26 +105,42 @@ async def mock_stream():
105105

106106
def test_agent_stop_api_success(mocker, mock_conversation_id):
107107
"""Test agent_stop_api success case."""
108+
# Mock the authentication function to return user_id
109+
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
110+
mock_get_user_id.return_value = ("test_user_id", "test_tenant_id")
111+
108112
mock_stop_tasks = mocker.patch("apps.agent_app.stop_agent_tasks")
109113
mock_stop_tasks.return_value = {"status": "success"}
110114

111-
response = client.get(f"/agent/stop/{mock_conversation_id}")
115+
response = client.get(
116+
f"/agent/stop/{mock_conversation_id}",
117+
headers={"Authorization": "Bearer test_token"}
118+
)
112119

113120
assert response.status_code == 200
114-
mock_stop_tasks.assert_called_once_with(mock_conversation_id)
121+
mock_get_user_id.assert_called_once_with("Bearer test_token")
122+
mock_stop_tasks.assert_called_once_with(mock_conversation_id, "test_user_id")
115123
assert response.json()["status"] == "success"
116124

117125

118126
def test_agent_stop_api_not_found(mocker, mock_conversation_id):
119127
"""Test agent_stop_api not found case."""
128+
# Mock the authentication function to return user_id
129+
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
130+
mock_get_user_id.return_value = ("test_user_id", "test_tenant_id")
131+
120132
mock_stop_tasks = mocker.patch("apps.agent_app.stop_agent_tasks")
121133
mock_stop_tasks.return_value = {"status": "error"} # Simulate not found
122134

123-
response = client.get(f"/agent/stop/{mock_conversation_id}")
135+
response = client.get(
136+
f"/agent/stop/{mock_conversation_id}",
137+
headers={"Authorization": "Bearer test_token"}
138+
)
124139

125140
# The app should raise HTTPException for non-success status
126141
assert response.status_code == 400
127-
mock_stop_tasks.assert_called_once_with(mock_conversation_id)
142+
mock_get_user_id.assert_called_once_with("Bearer test_token")
143+
mock_stop_tasks.assert_called_once_with(mock_conversation_id, "test_user_id")
128144
assert "no running agent or preprocess tasks found" in response.json()[
129145
"detail"]
130146

β€Žtest/backend/services/test_agent_service.pyβ€Ž

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,7 +1838,7 @@ async def test_prepare_agent_run(
18381838
"test_user", "test_tenant", 1)
18391839
mock_create_run_info.assert_called_once()
18401840
mock_agent_run_manager.register_agent_run.assert_called_once_with(
1841-
123, mock_run_info)
1841+
123, mock_run_info, "test_user")
18421842

18431843

18441844
@patch('backend.services.agent_service.submit')
@@ -1933,21 +1933,24 @@ def test_stop_agent_tasks(mock_preprocess_manager, mock_agent_run_manager):
19331933
# Test both stopped
19341934
mock_agent_run_manager.stop_agent_run.return_value = True
19351935
mock_preprocess_manager.stop_preprocess_tasks.return_value = True
1936-
result = stop_agent_tasks(123)
1936+
1937+
result = stop_agent_tasks(123, "test_user")
19371938
assert result["status"] == "success"
19381939
assert "successfully stopped agent run and preprocess tasks" in result["message"]
19391940

1941+
mock_agent_run_manager.stop_agent_run.assert_called_once_with(123, "test_user")
1942+
19401943
# Test only agent stopped
19411944
mock_agent_run_manager.stop_agent_run.return_value = True
19421945
mock_preprocess_manager.stop_preprocess_tasks.return_value = False
1943-
result = stop_agent_tasks(123)
1946+
result = stop_agent_tasks(123, "test_user")
19441947
assert result["status"] == "success"
19451948
assert "successfully stopped agent run" in result["message"]
19461949

19471950
# Test neither stopped
19481951
mock_agent_run_manager.stop_agent_run.return_value = False
19491952
mock_preprocess_manager.stop_preprocess_tasks.return_value = False
1950-
result = stop_agent_tasks(123)
1953+
result = stop_agent_tasks(123, "test_user")
19511954
assert result["status"] == "error"
19521955
assert "no running agent or preprocess tasks found" in result["message"]
19531956

@@ -2340,11 +2343,11 @@ def fake_save_messages(*args, **kwargs):
23402343
raising=False,
23412344
)
23422345

2343-
# Mock unregister
23442346
unregister_called = {}
23452347

2346-
def fake_unregister(conv_id):
2348+
def fake_unregister(conv_id, user_id):
23472349
unregister_called["conv_id"] = conv_id
2350+
unregister_called["user_id"] = user_id
23482351

23492352
monkeypatch.setattr(
23502353
"backend.services.agent_service.agent_run_manager.unregister_agent_run",
@@ -2365,6 +2368,7 @@ def fake_unregister(conv_id):
23652368
] # Prefix added in helper
23662369
assert save_calls, "save_messages should have been called for assistant messages"
23672370
assert unregister_called.get("conv_id") == 999
2371+
assert unregister_called.get("user_id") == "u"
23682372

23692373

23702374
@pytest.mark.asyncio
@@ -2386,10 +2390,11 @@ async def failing_agent_run(*_, **__):
23862390
"backend.services.agent_service.agent_run", failing_agent_run, raising=False
23872391
)
23882392

2389-
called = {"unregistered": None}
2393+
called = {"unregistered": None, "user_id": None}
23902394

2391-
def fake_unregister(conv_id):
2395+
def fake_unregister(conv_id, user_id):
23922396
called["unregistered"] = conv_id
2397+
called["user_id"] = user_id
23932398

23942399
monkeypatch.setattr(
23952400
"backend.services.agent_service.agent_run_manager.unregister_agent_run",
@@ -2408,6 +2413,7 @@ def fake_unregister(conv_id):
24082413
assert collected and collected[0].startswith(
24092414
"data: {") and "\"type\": \"error\"" in collected[0]
24102415
assert called["unregistered"] == 1001
2416+
assert called["user_id"] == "u"
24112417

24122418

24132419
@pytest.mark.asyncio
@@ -2692,13 +2698,13 @@ async def test_generate_stream_no_memory_registers_and_streams(monkeypatch):
26922698
AsyncMock(return_value=MagicMock()),
26932699
raising=False,
26942700
)
2695-
2696-
# Capture register
2701+
26972702
registered = {}
26982703

2699-
def fake_register(conv_id, run_info):
2704+
def fake_register(conv_id, run_info, user_id):
27002705
registered["conv_id"] = conv_id
27012706
registered["run_info"] = run_info
2707+
registered["user_id"] = user_id
27022708

27032709
monkeypatch.setattr(
27042710
"backend.services.agent_service.agent_run_manager.register_agent_run",
@@ -2725,6 +2731,8 @@ async def fake_stream_chunks(*_, **__):
27252731
collected.append(d)
27262732

27272733
assert registered.get("conv_id") == 555
2734+
assert registered.get("user_id") == "u"
2735+
assert registered.get("run_info") is not None
27282736
assert collected == ["data: body1\n\n", "data: body2\n\n"]
27292737

27302738

β€Žtest/backend/services/test_northbound_service.pyβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ async def test_stop_chat_success(ctx):
263263
assert result["data"] == "ext-777"
264264
assert result["requestId"] == "req-1"
265265

266+
agent_service_mod.stop_agent_tasks.assert_called_once_with(777, "user-1")
267+
266268

267269
@pytest.mark.asyncio
268270
async def test_list_conversations_maps_ids(ctx):

0 commit comments

Comments
Β (0)