Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions src/agents/models/multi_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@ def get_mapping(self) -> dict[str, ModelProvider]:
"""Returns a copy of the current prefix -> ModelProvider mapping."""
return self._mapping.copy()

def set_mapping(self, mapping: dict[str, ModelProvider]):
"""Overwrites the current mapping with a new one."""
self._mapping = mapping
def set_mapping(self, mapping: dict[str, ModelProvider]) -> None:
"""Overwrites the current mapping with a new one after validation."""
if not isinstance(mapping, dict):
TypeError("Mapping must be a dict[str, ModelProvider].")
for k, v in mapping.items():
if not isinstance(k, str):
raise TypeError(f"Mapping key '{k}' must be a string")
if not isinstance(v, ModelProvider):
raise TypeError(f"Mapping value for '{k}' must be a ModelProvider")
self._mapping = mapping.copy()

def get_provider(self, prefix: str) -> ModelProvider | None:
"""Returns the ModelProvider for the given prefix.
Expand All @@ -42,12 +49,10 @@ def add_provider(self, prefix: str, provider: ModelProvider):
"""
self._mapping[prefix] = provider

def remove_provider(self, prefix: str):
"""Removes the mapping for the given prefix.

Args:
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
"""
def remove_provider(self, prefix: str) -> None:
"""Removes the mapping for the given prefix."""
if prefix not in self._mapping:
raise UserError(f"No provider registered for prefix: {prefix}")
del self._mapping[prefix]


Expand Down Expand Up @@ -77,8 +82,8 @@ def __init__(
provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
we will use a default mapping. See the documentation for this class to see the
default mapping.
openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
the default API key.
openai_api_key: The API key to use for the OpenAI provider. If not provided,
we will use the default API key.
openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
use the default base URL.
openai_client: An optional OpenAI client to use. If not provided, we will create a new
Expand All @@ -100,18 +105,24 @@ def __init__(
self._fallback_providers: dict[str, ModelProvider] = {}

def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
if model_name is None:
if not model_name:
return None, None
elif "/" in model_name:
prefix, model_name = model_name.split("/", 1)
return prefix, model_name
else:
return None, model_name
if "/" in model_name:
prefix, model_part = model_name.split("/", 1)
# normalize empty model part to None so callers can validate consistently
return prefix, model_part if model_part != "" else None
return None, model_name

def _create_fallback_provider(self, prefix: str) -> ModelProvider:
if prefix == "litellm":
from ..extensions.models.litellm_provider import LitellmProvider

try:
from ..extensions.models.litellm_provider import LitellmProvider
except ImportError as e:
raise UserError(
"LitellmProvider requires the litellm extension. Install "
"the optional dependency or add a custom provider mapping "
"for the 'litellm' prefix."
) from e
return LitellmProvider()
else:
raise UserError(f"Unknown prefix: {prefix}")
Expand All @@ -138,6 +149,9 @@ def get_model(self, model_name: str | None) -> Model:
"""
prefix, model_name = self._get_prefix_and_model_name(model_name)

# Defensive validation
if model_name is None or model_name == "":
raise UserError("Model name must be provided and non-empty.")
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
return provider.get_model(model_name)
else:
Expand Down