Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
ConversationInfo,
ConversationPage,
ConversationSortOrder,
GenerateTitleRequest,
GenerateTitleResponse,
SendMessageRequest,
SetConfirmationPolicyRequest,
SetSecurityAnalyzerRequest,
Expand Down Expand Up @@ -272,24 +270,6 @@ async def update_conversation(
return Success()


@conversation_router.post(
    "/{conversation_id}/generate_title",
    responses={404: {"description": "Item not found"}},
)
async def generate_conversation_title(
    conversation_id: UUID,
    request: GenerateTitleRequest,
    conversation_service: ConversationService = Depends(get_conversation_service),
) -> GenerateTitleResponse:
    """Generate a title for the conversation using LLM.

    Delegates to the conversation service; a None result (unknown
    conversation or generation failure) is surfaced as HTTP 500.
    """
    generated = await conversation_service.generate_conversation_title(
        conversation_id, request.max_length, request.llm
    )
    if generated is None:
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR)
    return GenerateTitleResponse(title=generated)


@conversation_router.post(
"/{conversation_id}/ask_agent",
responses={404: {"description": "Item not found"}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from openhands.agent_server.pub_sub import Subscriber
from openhands.agent_server.server_details_router import update_last_execution_time
from openhands.agent_server.utils import safe_rmtree, utc_now
from openhands.sdk import LLM, Event, Message
from openhands.sdk import Event, Message
from openhands.sdk.conversation.state import (
ConversationExecutionStatus,
ConversationState,
)
from openhands.sdk.event import MessageEvent
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
from openhands.sdk.utils.cipher import Cipher

Expand Down Expand Up @@ -371,20 +372,6 @@ async def get_event_service(self, conversation_id: UUID) -> EventService | None:
raise ValueError("inactive_service")
return self._event_services.get(conversation_id)

async def generate_conversation_title(
    self, conversation_id: UUID, max_length: int = 50, llm: LLM | None = None
) -> str | None:
    """Generate a title for the conversation using LLM.

    Raises ValueError when this service has not been started; returns
    None when the conversation id is unknown.
    """
    if self._event_services is None:
        raise ValueError("inactive_service")
    target = self._event_services.get(conversation_id)
    if target is None:
        return None
    # Delegate to EventService to avoid accessing private conversation
    # internals.
    return await target.generate_title(llm=llm, max_length=max_length)

async def ask_agent(self, conversation_id: UUID, question: str) -> str | None:
"""Ask the agent a simple question without affecting conversation state."""
if self._event_services is None:
Expand Down Expand Up @@ -505,6 +492,10 @@ async def _start_event_service(self, stored: StoredConversation) -> EventService
)
# Create subscribers...
await event_service.subscribe_to_events(_EventSubscriber(service=event_service))
if stored.autotitle and stored.title is None:
await event_service.subscribe_to_events(
AutoTitleSubscriber(service=event_service)
)
asyncio.gather(
*[
event_service.subscribe_to_events(
Expand Down Expand Up @@ -548,6 +539,35 @@ async def __call__(self, _event: Event):
update_last_execution_time()


@dataclass
class AutoTitleSubscriber(Subscriber):
    """Generates a conversation title from the first user message.

    Subscribed only for conversations started with autotitle enabled and
    no title yet. Title generation runs in a background task so event
    dispatch never blocks on the LLM call.
    """

    service: EventService
    # Strong reference to the in-flight background task. asyncio keeps only
    # weak references to tasks, so without this the task could be garbage
    # collected before it finishes. It doubles as an "in progress" guard so
    # rapid successive user messages don't spawn duplicate LLM calls.
    _title_task: "asyncio.Task | None" = None

    async def __call__(self, event: Event) -> None:
        # Only act on incoming user messages
        if not isinstance(event, MessageEvent) or event.source != "user":
            return
        # Guard: skip if a title was already set (e.g. by a concurrent task)
        # or a generation attempt is already running.
        if self.service.stored.title is not None or self._title_task is not None:
            return

        async def _generate_and_save() -> None:
            try:
                title = await self.service.generate_title()
                # Re-check before writing: another writer may have set the
                # title while the LLM call was in flight.
                if title and self.service.stored.title is None:
                    self.service.stored.title = title
                    self.service.stored.updated_at = utc_now()
                    await self.service.save_meta()
            except Exception:
                logger.warning(
                    f"Auto-title generation failed for "
                    f"conversation {self.service.stored.id}",
                    exc_info=True,
                )

        task = asyncio.create_task(_generate_and_save())
        self._title_task = task
        # Drop the reference once finished so a failed attempt can be
        # retried on the next user message.
        task.add_done_callback(lambda _t: setattr(self, "_title_task", None))


@dataclass
class WebhookSubscriber(Subscriber):
conversation_id: UUID
Expand Down
26 changes: 8 additions & 18 deletions openhands-agent-server/openhands/agent_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, Field, field_validator

from openhands.agent_server.utils import OpenHandsUUID, utc_now
from openhands.sdk import LLM, AgentBase, Event, ImageContent, Message, TextContent
from openhands.sdk import AgentBase, Event, ImageContent, Message, TextContent
from openhands.sdk.conversation.state import (
ConversationExecutionStatus,
ConversationState,
Expand Down Expand Up @@ -126,6 +126,13 @@ class StartConversationRequest(BaseModel):
"hooks."
),
)
autotitle: bool = Field(
default=True,
description=(
"If true, automatically generate a title for the conversation from "
"the first user message using the conversation's LLM."
),
)


class StoredConversation(StartConversationRequest):
Expand Down Expand Up @@ -250,23 +257,6 @@ class UpdateConversationRequest(BaseModel):
)


class GenerateTitleRequest(BaseModel):
    """Parameters for conversation title generation."""

    max_length: int = Field(
        default=50, ge=1, le=200, description="Maximum length of the generated title"
    )
    llm: LLM | None = Field(
        default=None, description="Optional LLM to use for title generation"
    )


class GenerateTitleResponse(BaseModel):
    """Result of a successful title-generation request."""

    title: str = Field(description="The generated title for the conversation")


class AskAgentRequest(BaseModel):
"""Payload to ask the agent a simple question."""

Expand Down
205 changes: 69 additions & 136 deletions tests/agent_server/test_conversation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,142 +1042,6 @@ def test_update_conversation_invalid_title(
client.app.dependency_overrides.clear()


def test_generate_conversation_title_success(
    client, mock_conversation_service, sample_conversation_id
):
    """Test generate_conversation_title endpoint with successful generation."""
    mock_conversation_service.generate_conversation_title.return_value = (
        "Generated Title"
    )
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        response = client.post(
            f"/api/conversations/{sample_conversation_id}/generate_title",
            json={"max_length": 30},
        )

        assert response.status_code == 200
        assert response.json()["title"] == "Generated Title"

        # The service must be called exactly once with
        # (conversation_id, max_length, llm=None).
        mock_conversation_service.generate_conversation_title.assert_called_once()
        positional = (
            mock_conversation_service.generate_conversation_title.call_args.args
        )
        assert positional == (sample_conversation_id, 30, None)
    finally:
        client.app.dependency_overrides.clear()


def test_generate_conversation_title_with_llm(
    client, mock_conversation_service, sample_conversation_id
):
    """Test generate_conversation_title endpoint with custom LLM."""
    mock_conversation_service.generate_conversation_title.return_value = (
        "Custom LLM Title"
    )
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        payload = {
            "max_length": 40,
            "llm": {
                "model": "gpt-3.5-turbo",
                "api_key": "custom-key",
                "usage_id": "custom-llm",
            },
        }
        response = client.post(
            f"/api/conversations/{sample_conversation_id}/generate_title",
            json=payload,
        )

        assert response.status_code == 200
        assert response.json()["title"] == "Custom LLM Title"

        # Service receives the id, forwarded max_length, and a parsed LLM.
        mock_conversation_service.generate_conversation_title.assert_called_once()
        args = mock_conversation_service.generate_conversation_title.call_args.args
        assert args[0] == sample_conversation_id
        assert args[1] == 40  # max_length
        assert args[2] is not None  # llm provided
    finally:
        client.app.dependency_overrides.clear()


def test_generate_conversation_title_failure(
    client, mock_conversation_service, sample_conversation_id
):
    """Test generate_conversation_title endpoint with generation failure."""
    # None from the service means title generation failed.
    mock_conversation_service.generate_conversation_title.return_value = None
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        response = client.post(
            f"/api/conversations/{sample_conversation_id}/generate_title",
            json={"max_length": 50},
        )

        # Failure surfaces as an internal server error.
        assert response.status_code == 500
        mock_conversation_service.generate_conversation_title.assert_called_once()
    finally:
        client.app.dependency_overrides.clear()


def test_generate_conversation_title_invalid_params(
    client, mock_conversation_service, sample_conversation_id
):
    """Test generate_conversation_title endpoint with invalid parameters."""
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        # max_length is constrained to 1..200; out-of-range values must be
        # rejected by request validation before the service is reached.
        for invalid_length in (0, 201):
            response = client.post(
                f"/api/conversations/{sample_conversation_id}/generate_title",
                json={"max_length": invalid_length},
            )
            assert response.status_code == 422  # Validation error
    finally:
        client.app.dependency_overrides.clear()


def test_start_conversation_with_tool_module_qualnames(
client, mock_conversation_service, sample_conversation_info
):
Expand Down Expand Up @@ -1284,6 +1148,75 @@ def test_start_conversation_without_tool_module_qualnames(
client.app.dependency_overrides.clear()


def test_start_conversation_autotitle_defaults_to_true(
    client, mock_conversation_service, sample_conversation_info
):
    """autotitle defaults to True when not supplied in the request."""
    mock_conversation_service.start_conversation.return_value = (
        sample_conversation_info,
        True,
    )
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        # No "autotitle" key in the body: the model default must apply.
        body = {
            "agent": {
                "llm": {
                    "model": "gpt-4o",
                    "api_key": "test-key",
                    "usage_id": "test-llm",
                },
                "tools": [{"name": "TerminalTool"}],
            },
            "workspace": {"working_dir": "/tmp/test"},
        }
        response = client.post("/api/conversations", json=body)

        assert response.status_code == 201
        forwarded = mock_conversation_service.start_conversation.call_args.args[0]
        assert forwarded.autotitle is True
    finally:
        client.app.dependency_overrides.clear()


def test_start_conversation_autotitle_false(
    client, mock_conversation_service, sample_conversation_info
):
    """autotitle=False is forwarded correctly to the service."""
    mock_conversation_service.start_conversation.return_value = (
        sample_conversation_info,
        True,
    )
    client.app.dependency_overrides[get_conversation_service] = (
        lambda: mock_conversation_service
    )

    try:
        body = {
            "agent": {
                "llm": {
                    "model": "gpt-4o",
                    "api_key": "test-key",
                    "usage_id": "test-llm",
                },
                "tools": [{"name": "TerminalTool"}],
            },
            "workspace": {"working_dir": "/tmp/test"},
            "autotitle": False,
        }
        response = client.post("/api/conversations", json=body)

        assert response.status_code == 201
        forwarded = mock_conversation_service.start_conversation.call_args.args[0]
        assert forwarded.autotitle is False
    finally:
        client.app.dependency_overrides.clear()


def test_set_conversation_security_analyzer_success(
client,
sample_conversation_id,
Expand Down
Loading
Loading