Skip to content

Commit 87fb615

Browse files
feat(llm): switch model profile on user message (#2192)
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
1 parent 4c1e1a1 commit 87fb615

File tree

3 files changed

+182
-2
lines changed

3 files changed

+182
-2
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Mid-conversation model switching.

Saves a temporary "gpt" LLM profile, runs one turn on the default model,
switches to the saved profile mid-conversation, and prints per-model costs.

Usage:
    uv run examples/01_standalone_sdk/44_model_switching_in_convo.py
"""

import os

from openhands.sdk import LLM, Agent, LocalConversation, Tool
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.tools.terminal import TerminalTool


LLM_API_KEY = os.getenv("LLM_API_KEY")
store = LLMProfileStore()

# include_secrets=True persists the API key to disk alongside the profile,
# so the profile MUST be removed even if the example fails partway through.
store.save(
    "gpt",
    LLM(model="openhands/gpt-5.2", api_key=LLM_API_KEY),
    include_secrets=True,
)

try:
    agent = Agent(
        llm=LLM(
            model=os.getenv("LLM_MODEL", "openhands/claude-sonnet-4-5-20250929"),
            api_key=LLM_API_KEY,
        ),
        tools=[Tool(name=TerminalTool.name)],
    )
    conversation = LocalConversation(agent=agent, workspace=os.getcwd())

    # Send a message with the default model
    conversation.send_message("Say hello in one sentence.")
    conversation.run()

    # Switch to a different model and send another message
    conversation.switch_profile("gpt")
    print(f"Switched to: {conversation.agent.llm.model}")

    conversation.send_message("Say goodbye in one sentence.")
    conversation.run()

    # Print metrics per model
    for usage_id, metrics in conversation.state.stats.usage_to_metrics.items():
        print(f" [{usage_id}] cost=${metrics.accumulated_cost:.6f}")

    combined = conversation.state.stats.get_combined_metrics()
    print(f"Total cost: ${combined.accumulated_cost:.6f}")
    print(f"EXAMPLE_COST: {combined.accumulated_cost}")
finally:
    # The saved profile contains the API key; always clean it up.
    store.delete("gpt")

openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from openhands.sdk.event.conversation_error import ConversationErrorEvent
3434
from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback
3535
from openhands.sdk.llm import LLM, Message, TextContent
36+
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
3637
from openhands.sdk.llm.llm_registry import LLMRegistry
3738
from openhands.sdk.logger import get_logger
3839
from openhands.sdk.observability.laminar import observe
@@ -250,6 +251,7 @@ def _default_callback(e):
250251
# Agent initialization is deferred to _ensure_agent_ready() for lazy loading
251252
# This ensures plugins are loaded before agent initialization
252253
self.llm_registry = LLMRegistry()
254+
self._profile_store = LLMProfileStore()
253255

254256
# Initialize secrets if provided
255257
if secrets:
@@ -464,11 +466,37 @@ def _ensure_agent_ready(self) -> None:
464466

465467
# Register LLMs in the registry (still holding lock)
466468
self.llm_registry.subscribe(self._state.stats.register_llm)
469+
registered = set(self.llm_registry.list_usage_ids())
467470
for llm in list(self.agent.get_all_llms()):
468-
self.llm_registry.add(llm)
471+
if llm.usage_id not in registered:
472+
self.llm_registry.add(llm)
469473

470474
self._agent_ready = True
471475

476+
def switch_profile(self, profile_name: str) -> None:
    """Switch the agent's LLM to a named profile.

    The profile is loaded from the LLMProfileStore on first use and cached
    in the LLM registry under ``profile:<name>``; subsequent switches reuse
    the registered LLM. Both the agent and the conversation state are
    updated under the state lock.

    Args:
        profile_name: Name of a profile previously saved via LLMProfileStore.

    Raises:
        FileNotFoundError: If the profile does not exist.
        ValueError: If the profile is corrupted or invalid.
    """
    usage_id = f"profile:{profile_name}"
    try:
        # Fast path: the profile was already loaded and registered.
        llm = self.llm_registry.get(usage_id)
    except KeyError:
        # First use: load from disk, rebrand with the profile usage id,
        # and register so future switches hit the fast path.
        loaded = self._profile_store.load(profile_name)
        llm = loaded.model_copy(update={"usage_id": usage_id})
        self.llm_registry.add(llm)
    with self._state:
        updated_agent = self.agent.model_copy(update={"llm": llm})
        self.agent = updated_agent
        self._state.agent = updated_agent
472500
@observe(name="conversation.send_message")
473501
def send_message(self, message: str | Message, sender: str | None = None) -> None:
474502
"""Send a message to the agent.
@@ -484,7 +512,6 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non
484512
# Ensure agent is fully initialized (loads plugins and initializes agent)
485513
self._ensure_agent_ready()
486514

487-
# Convert string to Message if needed
488515
if isinstance(message, str):
489516
message = Message(role="user", content=[TextContent(text=message)])
490517

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from openhands.sdk import LLM, LocalConversation
6+
from openhands.sdk.agent import Agent
7+
from openhands.sdk.llm import llm_profile_store
8+
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
9+
from openhands.sdk.testing import TestLLM
10+
11+
12+
def _make_llm(model: str, usage_id: str) -> LLM:
    """Build a TestLLM with no scripted messages for the given model/usage id."""
    llm = TestLLM.from_messages([], model=model, usage_id=usage_id)
    return llm
14+
15+
16+
@pytest.fixture()
def profile_store(tmp_path, monkeypatch):
    """Temp profile store pre-populated with 'fast' and 'slow' profiles."""
    profile_dir = tmp_path / "profiles"
    profile_dir.mkdir()
    # Point the module-level default at the temp dir so any store created
    # without an explicit base_dir also lands here.
    monkeypatch.setattr(llm_profile_store, "_DEFAULT_PROFILE_DIR", profile_dir)

    store = LLMProfileStore(base_dir=profile_dir)
    for name in ("fast", "slow"):
        store.save(name, _make_llm(f"{name}-model", name))
    return store
31+
32+
33+
def _make_conversation() -> LocalConversation:
    """Build a conversation whose agent starts on 'default-model'."""
    agent = Agent(llm=_make_llm("default-model", "test-llm"), tools=[])
    return LocalConversation(agent=agent, workspace=Path.cwd())
41+
42+
43+
def test_switch_profile(profile_store):
    """switch_profile switches the agent's LLM."""
    conv = _make_conversation()
    for name in ("fast", "slow"):
        conv.switch_profile(name)
        assert conv.agent.llm.model == f"{name}-model"
50+
51+
52+
def test_switch_profile_updates_state(profile_store):
    """switch_profile updates conversation state agent."""
    conv = _make_conversation()
    conv.switch_profile("fast")
    # The state's agent copy must reflect the switched model too.
    state_model = conv.state.agent.llm.model
    assert state_model == "fast-model"
57+
58+
59+
def test_switch_between_profiles(profile_store):
    """Switch fast -> slow -> fast, verify model changes each time."""
    conv = _make_conversation()
    for profile in ("fast", "slow", "fast"):
        conv.switch_profile(profile)
        assert conv.agent.llm.model == f"{profile}-model"
71+
72+
73+
def test_switch_reuses_registry_entry(profile_store):
    """Switching back to a profile reuses the same registry LLM object."""
    conv = _make_conversation()

    conv.switch_profile("fast")
    first = conv.llm_registry.get("profile:fast")

    # Bounce through another profile and back again.
    conv.switch_profile("slow")
    conv.switch_profile("fast")
    second = conv.llm_registry.get("profile:fast")

    # Identity, not equality: the cached registry entry must be reused.
    assert first is second
85+
86+
87+
def test_switch_nonexistent_raises(profile_store):
    """Switching to a nonexistent profile raises FileNotFoundError."""
    conv = _make_conversation()
    with pytest.raises(FileNotFoundError):
        conv.switch_profile("nonexistent")
    # The failed switch must leave both the agent and the state untouched.
    for agent in (conv.agent, conv.state.agent):
        assert agent.llm.model == "default-model"
94+
95+
96+
def test_switch_then_send_message(profile_store):
    """switch_profile followed by send_message doesn't crash on registry collision."""
    convo = _make_conversation()
    convo.switch_profile("fast")
    # _ensure_agent_ready (run inside send_message) re-registers the agent's
    # LLMs; the already-registered profile LLM must be skipped rather than
    # raising a duplicate-registration error.
    convo.send_message("hello")

0 commit comments

Comments
 (0)