Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,38 @@ async def set_conversation_security_analyzer(
return Success()


@conversation_router.post(
"/{conversation_id}/switch_profile",
responses={
400: {"description": "Invalid or corrupted profile"},
404: {"description": "Conversation or profile not found"},
},
)
async def switch_conversation_profile(
conversation_id: UUID,
profile_name: str = Body(..., embed=True),
conversation_service: ConversationService = Depends(get_conversation_service),
) -> Success:
"""Switch the conversation's LLM profile to a named profile."""
event_service = await conversation_service.get_event_service(conversation_id)
if event_service is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
conversation = event_service.get_conversation()
try:
conversation.switch_profile(profile_name)
except FileNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Profile '{profile_name}' not found",
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
return Success()


@conversation_router.patch(
"/{conversation_id}", responses={404: {"description": "Item not found"}}
)
Expand Down
113 changes: 112 additions & 1 deletion tests/agent_server/test_conversation_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for conversation_router.py endpoints."""

from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -1618,3 +1618,114 @@ def test_update_secrets_with_mixed_formats(

finally:
client.app.dependency_overrides.clear()


# --- switch_profile endpoint tests ---


def test_switch_conversation_profile_success(
client, mock_conversation_service, mock_event_service, sample_conversation_id
):
"""Test switch_conversation_profile endpoint with a valid profile."""
mock_conversation = MagicMock()
mock_conversation_service.get_event_service.return_value = mock_event_service
mock_event_service.get_conversation.return_value = mock_conversation

client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
response = client.post(
f"/api/conversations/{sample_conversation_id}/switch_profile",
json={"profile_name": "gpt"},
)

assert response.status_code == 200
assert response.json()["success"] is True

mock_conversation_service.get_event_service.assert_called_once_with(
sample_conversation_id
)
mock_event_service.get_conversation.assert_called_once()
mock_conversation.switch_profile.assert_called_once_with("gpt")
finally:
client.app.dependency_overrides.clear()


def test_switch_conversation_profile_not_found(
client, mock_conversation_service, sample_conversation_id
):
"""Test switch_conversation_profile endpoint when conversation is not found."""
mock_conversation_service.get_event_service.return_value = None

client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
response = client.post(
f"/api/conversations/{sample_conversation_id}/switch_profile",
json={"profile_name": "gpt"},
)

assert response.status_code == 404
mock_conversation_service.get_event_service.assert_called_once_with(
sample_conversation_id
)
finally:
client.app.dependency_overrides.clear()


def test_switch_conversation_profile_nonexistent_profile(
client, mock_conversation_service, mock_event_service, sample_conversation_id
):
"""Test switch_conversation_profile when the profile does not exist on disk."""
mock_conversation = MagicMock()
mock_conversation.switch_profile.side_effect = FileNotFoundError(
"Profile 'missing' not found"
)
mock_conversation_service.get_event_service.return_value = mock_event_service
mock_event_service.get_conversation.return_value = mock_conversation

client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
response = client.post(
f"/api/conversations/{sample_conversation_id}/switch_profile",
json={"profile_name": "missing"},
)

assert response.status_code == 404
assert "missing" in response.json()["detail"]
mock_conversation.switch_profile.assert_called_once_with("missing")
finally:
client.app.dependency_overrides.clear()


def test_switch_conversation_profile_corrupted_profile(
client, mock_conversation_service, mock_event_service, sample_conversation_id
):
"""Test switch_conversation_profile when the profile is corrupted or invalid."""
mock_conversation = MagicMock()
mock_conversation.switch_profile.side_effect = ValueError("Invalid profile format")
mock_conversation_service.get_event_service.return_value = mock_event_service
mock_event_service.get_conversation.return_value = mock_conversation

client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
response = client.post(
f"/api/conversations/{sample_conversation_id}/switch_profile",
json={"profile_name": "corrupted"},
)

assert response.status_code == 400
assert "Invalid profile format" in response.json()["detail"]
mock_conversation.switch_profile.assert_called_once_with("corrupted")
finally:
client.app.dependency_overrides.clear()
Loading