Skip to content
51 changes: 51 additions & 0 deletions examples/01_standalone_sdk/44_model_switching_in_convo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Mid-conversation model switching.

Saves a named LLM profile, runs a conversation on the default model,
switches to the saved profile mid-conversation, and reports per-model cost.

Usage:
uv run examples/01_standalone_sdk/44_model_switching_in_convo.py
"""

import os

from openhands.sdk import LLM, Agent, LocalConversation, Tool
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.tools.terminal import TerminalTool


LLM_API_KEY = os.getenv("LLM_API_KEY")
store = LLMProfileStore()

# Persist a named profile we can switch to later; include the API key so the
# saved profile is usable without extra configuration.
store.save(
    "gpt",
    LLM(model="openhands/gpt-5.2", api_key=LLM_API_KEY),
    include_secrets=True,
)

# try/finally guarantees the temporary profile is deleted even if any step
# below raises (previously a failure would leak the saved profile on disk).
try:
    agent = Agent(
        llm=LLM(
            model=os.getenv("LLM_MODEL", "openhands/claude-sonnet-4-5-20250929"),
            api_key=LLM_API_KEY,
        ),
        tools=[Tool(name=TerminalTool.name)],
    )
    conversation = LocalConversation(agent=agent, workspace=os.getcwd())

    # Send a message with the default model
    conversation.send_message("Say hello in one sentence.")
    conversation.run()

    # Switch to a different model and send another message
    conversation.switch_profile("gpt")
    print(f"Switched to: {conversation.agent.llm.model}")

    conversation.send_message("Say goodbye in one sentence.")
    conversation.run()

    # Print metrics per model; usage_id distinguishes the two LLMs
    for usage_id, metrics in conversation.state.stats.usage_to_metrics.items():
        print(f" [{usage_id}] cost=${metrics.accumulated_cost:.6f}")

    combined = conversation.state.stats.get_combined_metrics()
    print(f"Total cost: ${combined.accumulated_cost:.6f}")
    print(f"EXAMPLE_COST: {combined.accumulated_cost}")
finally:
    # Clean up the temporary profile created for this example.
    store.delete("gpt")
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from openhands.sdk.event.conversation_error import ConversationErrorEvent
from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback
from openhands.sdk.llm import LLM, Message, TextContent
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.sdk.llm.llm_registry import LLMRegistry
from openhands.sdk.logger import get_logger
from openhands.sdk.observability.laminar import observe
Expand Down Expand Up @@ -250,6 +251,7 @@ def _default_callback(e):
# Agent initialization is deferred to _ensure_agent_ready() for lazy loading
# This ensures plugins are loaded before agent initialization
self.llm_registry = LLMRegistry()
self._profile_store = LLMProfileStore()

# Initialize secrets if provided
if secrets:
Expand Down Expand Up @@ -464,11 +466,37 @@ def _ensure_agent_ready(self) -> None:

# Register LLMs in the registry (still holding lock)
self.llm_registry.subscribe(self._state.stats.register_llm)
registered = set(self.llm_registry.list_usage_ids())
for llm in list(self.agent.get_all_llms()):
self.llm_registry.add(llm)
if llm.usage_id not in registered:
self.llm_registry.add(llm)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, not about this PR, but this looks curious actually. Are we saying that the agent may have gotten an LLM instance that wasn't registered? Do we know that?


self._agent_ready = True

def switch_profile(self, profile_name: str) -> None:
    """Switch the agent's LLM to a named profile.

    Loads the profile from the LLMProfileStore (cached in the registry
    under a ``profile:<name>`` usage id after the first load) and updates
    both the live agent and the persisted conversation state.

    Args:
        profile_name: Name of a profile previously saved via LLMProfileStore.

    Raises:
        FileNotFoundError: If the profile does not exist.
        ValueError: If the profile is corrupted or invalid.
    """
    usage_id = f"profile:{profile_name}"
    try:
        # Fast path: the profile was activated before and is cached
        # in the registry.
        new_llm = self.llm_registry.get(usage_id)
    except KeyError:
        # First activation: load from disk, tag with the profile usage id
        # so metrics are tracked per profile, and cache in the registry.
        new_llm = self._profile_store.load(profile_name)
        new_llm = new_llm.model_copy(update={"usage_id": usage_id})
        self.llm_registry.add(new_llm)
    with self._state:
        # Swap the LLM on an agent copy (agents are immutable models) and
        # persist the new agent in conversation state.
        self.agent = self.agent.model_copy(update={"llm": new_llm})
        self._state.agent = self.agent

@observe(name="conversation.send_message")
def send_message(self, message: str | Message, sender: str | None = None) -> None:
"""Send a message to the agent.
Expand All @@ -484,7 +512,6 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non
# Ensure agent is fully initialized (loads plugins and initializes agent)
self._ensure_agent_ready()

# Convert string to Message if needed
if isinstance(message, str):
message = Message(role="user", content=[TextContent(text=message)])

Expand Down
102 changes: 102 additions & 0 deletions tests/sdk/conversation/test_switch_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from pathlib import Path

import pytest

from openhands.sdk import LLM, LocalConversation
from openhands.sdk.agent import Agent
from openhands.sdk.llm import llm_profile_store
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.sdk.testing import TestLLM


def _make_llm(model: str, usage_id: str) -> LLM:
    """Build a stub LLM with no scripted responses."""
    empty_script: list = []
    return TestLLM.from_messages(empty_script, model=model, usage_id=usage_id)


@pytest.fixture()
def profile_store(tmp_path, monkeypatch):
    """Temporary profile store pre-populated with 'fast' and 'slow' profiles."""
    base = tmp_path / "profiles"
    base.mkdir()
    # Redirect the module-level default directory so LLMProfileStore()
    # instances created elsewhere in the code under test resolve here too.
    monkeypatch.setattr(llm_profile_store, "_DEFAULT_PROFILE_DIR", base)

    store = LLMProfileStore(base_dir=base)
    for name in ("fast", "slow"):
        store.save(name, _make_llm(f"{name}-model", name))
    return store


def _make_conversation() -> LocalConversation:
    """Build a conversation around a stub agent using the default model."""
    default_llm = _make_llm("default-model", "test-llm")
    agent = Agent(llm=default_llm, tools=[])
    return LocalConversation(agent=agent, workspace=Path.cwd())


def test_switch_profile(profile_store):
    """Switching profiles replaces the agent's LLM model."""
    conversation = _make_conversation()

    conversation.switch_profile("fast")
    assert conversation.agent.llm.model == "fast-model"

    conversation.switch_profile("slow")
    assert conversation.agent.llm.model == "slow-model"


def test_switch_profile_updates_state(profile_store):
    """The conversation state's persisted agent reflects the switched LLM."""
    conversation = _make_conversation()
    conversation.switch_profile("fast")
    persisted_model = conversation.state.agent.llm.model
    assert persisted_model == "fast-model"


def test_switch_between_profiles(profile_store):
    """Switch fast -> slow -> fast; the model tracks each switch."""
    conversation = _make_conversation()
    # Table-driven: (profile to activate, model expected afterwards).
    transitions = (
        ("fast", "fast-model"),
        ("slow", "slow-model"),
        ("fast", "fast-model"),
    )
    for profile_name, expected_model in transitions:
        conversation.switch_profile(profile_name)
        assert conversation.agent.llm.model == expected_model


def test_switch_reuses_registry_entry(profile_store):
    """Re-activating a profile yields the identical registry LLM object."""
    conversation = _make_conversation()

    conversation.switch_profile("fast")
    first_instance = conversation.llm_registry.get("profile:fast")

    # Bounce through another profile and back again.
    conversation.switch_profile("slow")
    conversation.switch_profile("fast")
    second_instance = conversation.llm_registry.get("profile:fast")

    assert first_instance is second_instance


def test_switch_nonexistent_raises(profile_store):
    """An unknown profile name raises and leaves the agent untouched."""
    conversation = _make_conversation()

    with pytest.raises(FileNotFoundError):
        conversation.switch_profile("nonexistent")

    # Both the live agent and the persisted state keep the original model.
    for agent in (conversation.agent, conversation.state.agent):
        assert agent.llm.model == "default-model"


def test_switch_then_send_message(profile_store):
    """A switched LLM survives lazy agent init without a registry collision."""
    conversation = _make_conversation()
    conversation.switch_profile("fast")
    # send_message triggers _ensure_agent_ready, which re-registers the
    # agent's LLMs; the already-registered profile LLM must be tolerated.
    conversation.send_message("hello")
Loading