Skip to content
84 changes: 84 additions & 0 deletions examples/01_standalone_sdk/41_model_switching_in_convo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Interactive chat with mid-conversation model switching.

Usage:
uv run examples/01_standalone_sdk/41_model_switching_in_convo.py
"""

import os

from openhands.sdk import LLM, Agent, LocalConversation, Tool
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.tools.terminal import TerminalTool


LLM_API_KEY = os.getenv("LLM_API_KEY")
store = LLMProfileStore()

profiles: dict[str, str] = {
"kimi": "openhands/kimi-k2-0711-preview",
"deepseek": "openhands/deepseek-chat",
"gpt": "openhands/gpt-5.2",
}
for profile_name, model in profiles.items():
store.save(
profile_name,
LLM(model=model, api_key=LLM_API_KEY),
include_secrets=True,
)

llm = LLM(
model=os.getenv("LLM_MODEL", "openhands/claude-sonnet-4-5-20250929"),
api_key=LLM_API_KEY,
)

agent = Agent(llm=llm, tools=[Tool(name=TerminalTool.name)])

conversation = LocalConversation(
agent=agent,
workspace=os.getcwd(),
allow_model_switching=True,
)

print(
"Chat with the agent. Commands:\n"
" /model — show current model and available profiles\n" # noqa: E501
" /model <model_profile_name> — switch to a different model profile\n"
" /model <model_profile_name> [prompt] — switch and send a message in one step\n" # noqa: E501
" /exit — quit\n"
)

try:
while True:
try:
user_input = input("You: ").strip()
except (EOFError, KeyboardInterrupt):
print()
break

if not user_input:
continue
if user_input.lower() == "/exit":
break
conversation.send_message(user_input)
conversation.run()
except Exception:
raise
finally:
# Clean up the profiles we created
for name in profiles.keys():
store.delete(name)

# Inspect metrics across all LLMs (original + switched profiles)
stats = conversation.state.stats
for usage_id, metrics in stats.usage_to_metrics.items():
print(f" [{usage_id}] cost=${metrics.accumulated_cost:.6f}")
for usage in metrics.token_usages:
print(
f" model={usage.model}"
f" prompt={usage.prompt_tokens}"
f" completion={usage.completion_tokens}"
)

combined = stats.get_combined_metrics()
print(f"\nTotal cost (all models): ${combined.accumulated_cost:.6f}")
print(f"EXAMPLE_COST: {combined.accumulated_cost}")
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ConversationState,
)
from openhands.sdk.conversation.stuck_detector import StuckDetector
from openhands.sdk.conversation.switch_model_handler import SwitchModelHandler
from openhands.sdk.conversation.title_utils import generate_conversation_title
from openhands.sdk.conversation.types import (
ConversationCallbackType,
Expand All @@ -33,6 +34,7 @@
from openhands.sdk.event.conversation_error import ConversationErrorEvent
from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback
from openhands.sdk.llm import LLM, Message, TextContent
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.sdk.llm.llm_registry import LLMRegistry
from openhands.sdk.logger import get_logger
from openhands.sdk.observability.laminar import observe
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(
secrets: Mapping[str, SecretValue] | None = None,
delete_on_close: bool = True,
cipher: Cipher | None = None,
allow_model_switching: bool = False,
**_: object,
):
"""Initialize the conversation.
Expand Down Expand Up @@ -133,6 +136,8 @@ def __init__(
state. If provided, secrets are encrypted when saving and
decrypted when loading. If not provided, secrets are redacted
(lost) on serialization.
allow_model_switching: Whether to allow switching between persisted
models using the `/model <MODEL_NAME> [prompt]` command.
"""
super().__init__() # Initialize with span tracking
# Mark cleanup as initiated as early as possible to avoid races or partially
Expand Down Expand Up @@ -245,6 +250,13 @@ def _default_callback(e):
# Agent initialization is deferred to _ensure_agent_ready() for lazy loading
# This ensures plugins are loaded before agent initialization
self.llm_registry = LLMRegistry()
self._model_handler = (
SwitchModelHandler(
llm_registry=self.llm_registry, profile_store=LLMProfileStore()
)
if allow_model_switching
else None
)

# Initialize secrets if provided
if secrets:
Expand Down Expand Up @@ -429,6 +441,42 @@ def _ensure_agent_ready(self) -> None:

self._agent_ready = True

def _handle_model_command(
    self,
    profile_name: str,
    remaining_text: str | None,
    sender: str | None,
) -> None:
    """Execute a parsed /model command.

    Args:
        profile_name: Profile to switch to; empty string means "show info".
        remaining_text: Optional follow-up message to send after switching.
        sender: Optional sender identifier forwarded to send_message.
    """
    handler = self._model_handler
    assert handler is not None

    # Bare "/model": report the current model and known profiles, nothing else.
    if not profile_name:
        logger.info(handler.get_profiles_info_message(self.agent.llm))
        return

    try:
        new_agent = handler.switch(self.agent, profile_name)
    except (FileNotFoundError, ValueError) as e:
        # Unknown or invalid profile: log and keep the current agent untouched.
        logger.warning(f"Failed to switch {profile_name}: {e}")
        return
    self.agent = new_agent

    # Persist the swapped agent into the conversation state under its lock.
    with self._state:
        self._state.agent = self.agent

    logger.info(
        f"Switched to model profile `{profile_name}` "
        f"(model: {self.agent.llm.model})"
    )

    # Optional one-step "switch and prompt": forward the rest of the input.
    if remaining_text:
        self.send_message(remaining_text, sender=sender)

@observe(name="conversation.send_message")
def send_message(self, message: str | Message, sender: str | None = None) -> None:
"""Send a message to the agent.
Expand All @@ -444,8 +492,15 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non
# Ensure agent is fully initialized (loads plugins and initializes agent)
self._ensure_agent_ready()

# Convert string to Message if needed
if isinstance(message, str):
# Intercept /model command before converting to Message
if self._model_handler is not None:
parsed = self._model_handler.parse(message)
if parsed is not None:
profile_name, remaining_text = parsed
self._handle_model_command(profile_name, remaining_text, sender)
return

message = Message(role="user", content=[TextContent(text=message)])

assert message.role == "user", (
Expand Down
86 changes: 86 additions & 0 deletions openhands-sdk/openhands/sdk/conversation/switch_model_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.sdk.llm.llm_registry import LLMRegistry
from openhands.sdk.logger import get_logger


if TYPE_CHECKING:
from openhands.sdk.agent.base import AgentBase
from openhands.sdk.llm import LLM

logger = get_logger(__name__)


@dataclass
class SwitchModelHandler:
    """Standalone handler for /model command parsing, switching, and info."""

    profile_store: LLMProfileStore
    llm_registry: LLMRegistry

    @staticmethod
    def parse(text: str) -> tuple[str, str | None] | None:
        """Parse a /model command from user text.

        Returns:
            None if the text is not a /model command.
            ("", None) for bare "/model" (info request).
            ("profile_name", remaining_text_or_None) otherwise.
        """
        command = text.strip()
        if command == "/model":
            return "", None
        prefix = "/model "
        if not command.startswith(prefix):
            return None
        args = command[len(prefix):].strip()
        if not args:
            return "", None
        pieces = args.split(maxsplit=1)
        remaining = pieces[1] if len(pieces) > 1 else None
        return pieces[0], remaining

    def switch(self, agent: "AgentBase", profile_name: str) -> "AgentBase":
        """Load a model profile and return a new agent with the swapped LLM.

        The caller is responsible for storing the returned agent and updating
        any external state (e.g. ConversationState).

        Args:
            agent: Current agent instance.
            profile_name: Name of the profile to load from LLMProfileStore.

        Returns:
            A new AgentBase instance with the switched LLM.

        Raises:
            FileNotFoundError: If the profile does not exist.
            ValueError: If the profile is corrupted or invalid.
        """
        # Reuse an already-registered LLM; otherwise load it from disk and
        # register it under the profile name for subsequent switches.
        if profile_name in self.llm_registry.list_usage_ids():
            new_llm = self.llm_registry.get(profile_name)
        else:
            loaded = self.profile_store.load(profile_name)
            new_llm = loaded.model_copy(update={"usage_id": profile_name})
            self.llm_registry.add(new_llm)

        return agent.model_copy(update={"llm": new_llm})

    def get_profiles_info_message(self, llm: "LLM") -> str:
        """Return a string with current model and available profiles.

        The caller is responsible for emitting the string as an event.
        """
        # Union of on-disk profiles and registry usage IDs, deduplicated.
        names = set(self.profile_store.list()) | set(
            self.llm_registry.list_usage_ids()
        )
        if names:
            listing = ", ".join(sorted(Path(n).stem for n in names))
        else:
            listing = "[]"
        return f"Current model: {llm.model}\nAvailable profiles: {listing}"
Loading
Loading