Skip to content

Commit 74112a9

Browse files
authored
feat(agent-server): expose model switching (switch_profile) in agent-server REST API (#2795)
1 parent df6f5a5 commit 74112a9

File tree

2 files changed

+144
-1
lines changed

2 files changed

+144
-1
lines changed

openhands-agent-server/openhands/agent_server/conversation_router.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,38 @@ async def set_conversation_security_analyzer(
288288
return Success()
289289

290290

291+
@conversation_router.post(
292+
"/{conversation_id}/switch_profile",
293+
responses={
294+
400: {"description": "Invalid or corrupted profile"},
295+
404: {"description": "Conversation or profile not found"},
296+
},
297+
)
298+
async def switch_conversation_profile(
299+
conversation_id: UUID,
300+
profile_name: str = Body(..., embed=True),
301+
conversation_service: ConversationService = Depends(get_conversation_service),
302+
) -> Success:
303+
"""Switch the conversation's LLM profile to a named profile."""
304+
event_service = await conversation_service.get_event_service(conversation_id)
305+
if event_service is None:
306+
raise HTTPException(status.HTTP_404_NOT_FOUND)
307+
conversation = event_service.get_conversation()
308+
try:
309+
conversation.switch_profile(profile_name)
310+
except FileNotFoundError:
311+
raise HTTPException(
312+
status_code=status.HTTP_404_NOT_FOUND,
313+
detail=f"Profile '{profile_name}' not found",
314+
)
315+
except ValueError as e:
316+
raise HTTPException(
317+
status_code=status.HTTP_400_BAD_REQUEST,
318+
detail=str(e),
319+
)
320+
return Success()
321+
322+
291323
@conversation_router.patch(
292324
"/{conversation_id}", responses={404: {"description": "Item not found"}}
293325
)

tests/agent_server/test_conversation_router.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for conversation_router.py endpoints."""
22

3-
from unittest.mock import AsyncMock
3+
from unittest.mock import AsyncMock, MagicMock
44
from uuid import uuid4
55

66
import pytest
@@ -1618,3 +1618,114 @@ def test_update_secrets_with_mixed_formats(
16181618

16191619
finally:
16201620
client.app.dependency_overrides.clear()
1621+
1622+
1623+
# --- switch_profile endpoint tests ---
1624+
1625+
1626+
def test_switch_conversation_profile_success(
1627+
client, mock_conversation_service, mock_event_service, sample_conversation_id
1628+
):
1629+
"""Test switch_conversation_profile endpoint with a valid profile."""
1630+
mock_conversation = MagicMock()
1631+
mock_conversation_service.get_event_service.return_value = mock_event_service
1632+
mock_event_service.get_conversation.return_value = mock_conversation
1633+
1634+
client.app.dependency_overrides[get_conversation_service] = (
1635+
lambda: mock_conversation_service
1636+
)
1637+
1638+
try:
1639+
response = client.post(
1640+
f"/api/conversations/{sample_conversation_id}/switch_profile",
1641+
json={"profile_name": "gpt"},
1642+
)
1643+
1644+
assert response.status_code == 200
1645+
assert response.json()["success"] is True
1646+
1647+
mock_conversation_service.get_event_service.assert_called_once_with(
1648+
sample_conversation_id
1649+
)
1650+
mock_event_service.get_conversation.assert_called_once()
1651+
mock_conversation.switch_profile.assert_called_once_with("gpt")
1652+
finally:
1653+
client.app.dependency_overrides.clear()
1654+
1655+
1656+
def test_switch_conversation_profile_not_found(
1657+
client, mock_conversation_service, sample_conversation_id
1658+
):
1659+
"""Test switch_conversation_profile endpoint when conversation is not found."""
1660+
mock_conversation_service.get_event_service.return_value = None
1661+
1662+
client.app.dependency_overrides[get_conversation_service] = (
1663+
lambda: mock_conversation_service
1664+
)
1665+
1666+
try:
1667+
response = client.post(
1668+
f"/api/conversations/{sample_conversation_id}/switch_profile",
1669+
json={"profile_name": "gpt"},
1670+
)
1671+
1672+
assert response.status_code == 404
1673+
mock_conversation_service.get_event_service.assert_called_once_with(
1674+
sample_conversation_id
1675+
)
1676+
finally:
1677+
client.app.dependency_overrides.clear()
1678+
1679+
1680+
def test_switch_conversation_profile_nonexistent_profile(
1681+
client, mock_conversation_service, mock_event_service, sample_conversation_id
1682+
):
1683+
"""Test switch_conversation_profile when the profile does not exist on disk."""
1684+
mock_conversation = MagicMock()
1685+
mock_conversation.switch_profile.side_effect = FileNotFoundError(
1686+
"Profile 'missing' not found"
1687+
)
1688+
mock_conversation_service.get_event_service.return_value = mock_event_service
1689+
mock_event_service.get_conversation.return_value = mock_conversation
1690+
1691+
client.app.dependency_overrides[get_conversation_service] = (
1692+
lambda: mock_conversation_service
1693+
)
1694+
1695+
try:
1696+
response = client.post(
1697+
f"/api/conversations/{sample_conversation_id}/switch_profile",
1698+
json={"profile_name": "missing"},
1699+
)
1700+
1701+
assert response.status_code == 404
1702+
assert "missing" in response.json()["detail"]
1703+
mock_conversation.switch_profile.assert_called_once_with("missing")
1704+
finally:
1705+
client.app.dependency_overrides.clear()
1706+
1707+
1708+
def test_switch_conversation_profile_corrupted_profile(
1709+
client, mock_conversation_service, mock_event_service, sample_conversation_id
1710+
):
1711+
"""Test switch_conversation_profile when the profile is corrupted or invalid."""
1712+
mock_conversation = MagicMock()
1713+
mock_conversation.switch_profile.side_effect = ValueError("Invalid profile format")
1714+
mock_conversation_service.get_event_service.return_value = mock_event_service
1715+
mock_event_service.get_conversation.return_value = mock_conversation
1716+
1717+
client.app.dependency_overrides[get_conversation_service] = (
1718+
lambda: mock_conversation_service
1719+
)
1720+
1721+
try:
1722+
response = client.post(
1723+
f"/api/conversations/{sample_conversation_id}/switch_profile",
1724+
json={"profile_name": "corrupted"},
1725+
)
1726+
1727+
assert response.status_code == 400
1728+
assert "Invalid profile format" in response.json()["detail"]
1729+
mock_conversation.switch_profile.assert_called_once_with("corrupted")
1730+
finally:
1731+
client.app.dependency_overrides.clear()

0 commit comments

Comments
 (0)