-
Notifications
You must be signed in to change notification settings - Fork 218
feat(llm): switch model profile on user message #2192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
fbec87a
0e254ea
ab6a4e4
eb49ea4
d4beb1b
b088ced
e67d28f
4132679
e9ce849
dfb4179
c8d8105
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
"""Mid-conversation model switching.

Demonstrates saving a named LLM profile, starting a conversation on a
default model, switching to the saved profile mid-conversation, and
printing per-model cost metrics.

Usage:
    uv run examples/01_standalone_sdk/44_model_switching_in_convo.py
"""

import os

from openhands.sdk import LLM, Agent, LocalConversation, Tool
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.tools.terminal import TerminalTool


LLM_API_KEY = os.getenv("LLM_API_KEY")
# Fail fast with a clear message instead of passing api_key=None to LLM()
# and getting an opaque auth error on the first request.
if not LLM_API_KEY:
    raise SystemExit("LLM_API_KEY environment variable is not set.")

store = LLMProfileStore()

# Save a named profile (the API key is persisted with include_secrets=True)
# so the conversation can switch to it later by name.
store.save(
    "gpt",
    LLM(model="openhands/gpt-5.2", api_key=LLM_API_KEY),
    include_secrets=True,
)

try:
    agent = Agent(
        llm=LLM(
            model=os.getenv("LLM_MODEL", "openhands/claude-sonnet-4-5-20250929"),
            api_key=LLM_API_KEY,
        ),
        tools=[Tool(name=TerminalTool.name)],
    )
    conversation = LocalConversation(agent=agent, workspace=os.getcwd())

    # Send a message with the default model
    conversation.send_message("Say hello in one sentence.")
    conversation.run()

    # Switch to a different model and send another message
    conversation.switch_profile("gpt")
    print(f"Switched to: {conversation.agent.llm.model}")

    conversation.send_message("Say goodbye in one sentence.")
    conversation.run()

    # Print metrics per model
    for usage_id, metrics in conversation.state.stats.usage_to_metrics.items():
        print(f"  [{usage_id}] cost=${metrics.accumulated_cost:.6f}")

    combined = conversation.state.stats.get_combined_metrics()
    print(f"Total cost: ${combined.accumulated_cost:.6f}")
    print(f"EXAMPLE_COST: {combined.accumulated_cost}")
finally:
    # Always remove the on-disk profile — it contains the API key — even if
    # the conversation fails part-way through.
    store.delete("gpt")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,7 @@ | |
| from openhands.sdk.event.conversation_error import ConversationErrorEvent | ||
| from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback | ||
| from openhands.sdk.llm import LLM, Message, TextContent | ||
| from openhands.sdk.llm.llm_profile_store import LLMProfileStore | ||
| from openhands.sdk.llm.llm_registry import LLMRegistry | ||
| from openhands.sdk.logger import get_logger | ||
| from openhands.sdk.observability.laminar import observe | ||
|
|
@@ -250,6 +251,7 @@ def _default_callback(e): | |
| # Agent initialization is deferred to _ensure_agent_ready() for lazy loading | ||
| # This ensures plugins are loaded before agent initialization | ||
| self.llm_registry = LLMRegistry() | ||
| self._profile_store = LLMProfileStore() | ||
|
|
||
| # Initialize secrets if provided | ||
| if secrets: | ||
|
|
@@ -464,11 +466,37 @@ def _ensure_agent_ready(self) -> None: | |
|
|
||
| # Register LLMs in the registry (still holding lock) | ||
| self.llm_registry.subscribe(self._state.stats.register_llm) | ||
| registered = set(self.llm_registry.list_usage_ids()) | ||
| for llm in list(self.agent.get_all_llms()): | ||
| self.llm_registry.add(llm) | ||
| if llm.usage_id not in registered: | ||
| self.llm_registry.add(llm) | ||
|
|
||
| self._agent_ready = True | ||
|
|
||
| def switch_profile(self, profile_name: str) -> None: | ||
| """Switch the agent's LLM to a named profile. | ||
|
|
||
VascoSch92 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Loads the profile from the LLMProfileStore (cached in the registry | ||
| after the first load) and updates the agent and conversation state. | ||
|
|
||
| Args: | ||
| profile_name: Name of a profile previously saved via LLMProfileStore. | ||
|
|
||
| Raises: | ||
| FileNotFoundError: If the profile does not exist. | ||
| ValueError: If the profile is corrupted or invalid. | ||
| """ | ||
| usage_id = f"profile:{profile_name}" | ||
| try: | ||
| new_llm = self.llm_registry.get(usage_id) | ||
| except KeyError: | ||
| new_llm = self._profile_store.load(profile_name) | ||
| new_llm = new_llm.model_copy(update={"usage_id": usage_id}) | ||
| self.llm_registry.add(new_llm) | ||
VascoSch92 marked this conversation as resolved.
Show resolved
Hide resolved
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Really only a tiny thought, this code could be out of here and in the registry I think Not for this PR, just thinking of responsibilities; now the Conversation knows both |
||
| with self._state: | ||
| self.agent = self.agent.model_copy(update={"llm": new_llm}) | ||
| self._state.agent = self.agent | ||
VascoSch92 marked this conversation as resolved.
Show resolved
Hide resolved
VascoSch92 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @observe(name="conversation.send_message") | ||
| def send_message(self, message: str | Message, sender: str | None = None) -> None: | ||
| """Send a message to the agent. | ||
|
|
@@ -484,7 +512,6 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non | |
| # Ensure agent is fully initialized (loads plugins and initializes agent) | ||
| self._ensure_agent_ready() | ||
|
|
||
| # Convert string to Message if needed | ||
| if isinstance(message, str): | ||
| message = Message(role="user", content=[TextContent(text=message)]) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
from pathlib import Path

import pytest

from openhands.sdk import LLM, LocalConversation
from openhands.sdk.agent import Agent
from openhands.sdk.llm import llm_profile_store
from openhands.sdk.llm.llm_profile_store import LLMProfileStore
from openhands.sdk.testing import TestLLM


def _make_llm(model: str, usage_id: str) -> LLM:
    """Build a canned TestLLM with the given model name and usage id."""
    return TestLLM.from_messages([], model=model, usage_id=usage_id)


@pytest.fixture()
def profile_store(tmp_path, monkeypatch):
    """Temp profile store pre-populated with 'fast' and 'slow' profiles."""
    profile_dir = tmp_path / "profiles"
    profile_dir.mkdir()
    # Redirect the default profile directory so tests never touch the real one.
    monkeypatch.setattr(llm_profile_store, "_DEFAULT_PROFILE_DIR", profile_dir)

    store = LLMProfileStore(base_dir=profile_dir)
    for name in ("fast", "slow"):
        store.save(name, _make_llm(f"{name}-model", name))
    return store


def _make_conversation() -> LocalConversation:
    """Conversation whose agent starts on the 'default-model' LLM."""
    agent = Agent(llm=_make_llm("default-model", "test-llm"), tools=[])
    return LocalConversation(agent=agent, workspace=Path.cwd())


def test_switch_profile(profile_store):
    """switch_profile switches the agent's LLM."""
    convo = _make_conversation()
    convo.switch_profile("fast")
    assert convo.agent.llm.model == "fast-model"
    convo.switch_profile("slow")
    assert convo.agent.llm.model == "slow-model"


def test_switch_profile_updates_state(profile_store):
    """switch_profile updates the agent stored in conversation state."""
    convo = _make_conversation()
    convo.switch_profile("fast")
    assert convo.state.agent.llm.model == "fast-model"


def test_switch_between_profiles(profile_store):
    """Round-trip fast -> slow -> fast changes the model each time."""
    convo = _make_conversation()
    for name in ("fast", "slow", "fast"):
        convo.switch_profile(name)
        assert convo.agent.llm.model == f"{name}-model"


def test_switch_reuses_registry_entry(profile_store):
    """Switching back to a profile reuses the same registry LLM object."""
    convo = _make_conversation()
    convo.switch_profile("fast")
    first = convo.llm_registry.get("profile:fast")

    convo.switch_profile("slow")
    convo.switch_profile("fast")
    assert convo.llm_registry.get("profile:fast") is first


def test_switch_nonexistent_raises(profile_store):
    """Unknown profile raises FileNotFoundError and leaves the agent untouched."""
    convo = _make_conversation()
    with pytest.raises(FileNotFoundError):
        convo.switch_profile("nonexistent")
    assert convo.agent.llm.model == "default-model"
    assert convo.state.agent.llm.model == "default-model"


def test_switch_then_send_message(profile_store):
    """switch_profile followed by send_message doesn't crash on registry collision."""
    convo = _make_conversation()
    convo.switch_profile("fast")
    # send_message triggers _ensure_agent_ready which re-registers agent LLMs;
    # the switched LLM must not cause a duplicate registration error.
    convo.send_message("hello")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, not about this PR, but this looks curious actually. Are we saying that the agent may have gotten an LLM instance that wasn't registered? Do we know that?