diff --git a/src/agents/models/multi_provider.py b/src/agents/models/multi_provider.py index d075ac9b6..13592d36f 100644 --- a/src/agents/models/multi_provider.py +++ b/src/agents/models/multi_provider.py @@ -21,9 +21,16 @@ def get_mapping(self) -> dict[str, ModelProvider]: """Returns a copy of the current prefix -> ModelProvider mapping.""" return self._mapping.copy() - def set_mapping(self, mapping: dict[str, ModelProvider]): - """Overwrites the current mapping with a new one.""" - self._mapping = mapping + def set_mapping(self, mapping: dict[str, ModelProvider]) -> None: + """Overwrites the current mapping with a new one after validation.""" + if not isinstance(mapping, dict): + TypeError("Mapping must be a dict[str, ModelProvider].") + for k, v in mapping.items(): + if not isinstance(k, str): + raise TypeError(f"Mapping key '{k}' must be a string") + if not isinstance(v, ModelProvider): + raise TypeError(f"Mapping value for '{k}' must be a ModelProvider") + self._mapping = mapping.copy() def get_provider(self, prefix: str) -> ModelProvider | None: """Returns the ModelProvider for the given prefix. @@ -42,12 +49,10 @@ def add_provider(self, prefix: str, provider: ModelProvider): """ self._mapping[prefix] = provider - def remove_provider(self, prefix: str): - """Removes the mapping for the given prefix. - - Args: - prefix: The prefix of the model name e.g. "openai" or "my_prefix". - """ + def remove_provider(self, prefix: str) -> None: + """Removes the mapping for the given prefix.""" + if prefix not in self._mapping: + raise UserError(f"No provider registered for prefix: {prefix}") del self._mapping[prefix] @@ -77,8 +82,8 @@ def __init__( provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided, we will use a default mapping. See the documentation for this class to see the default mapping. - openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use - the default API key. + openai_api_key: The API key to use for the OpenAI provider. If not provided, + we will use the default API key. openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will use the default base URL. openai_client: An optional OpenAI client to use. If not provided, we will create a new @@ -100,18 +105,24 @@ def __init__( self._fallback_providers: dict[str, ModelProvider] = {} def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]: - if model_name is None: + if not model_name: return None, None - elif "/" in model_name: - prefix, model_name = model_name.split("/", 1) - return prefix, model_name - else: - return None, model_name + if "/" in model_name: + prefix, model_part = model_name.split("/", 1) + # normalize empty model part to None so callers can validate consistently + return prefix, model_part if model_part != "" else None + return None, model_name def _create_fallback_provider(self, prefix: str) -> ModelProvider: if prefix == "litellm": - from ..extensions.models.litellm_provider import LitellmProvider - + try: + from ..extensions.models.litellm_provider import LitellmProvider + except ImportError as e: + raise UserError( + "LitellmProvider requires the litellm extension. Install " + "the optional dependency or add a custom provider mapping " + "for the 'litellm' prefix." + ) from e return LitellmProvider() else: raise UserError(f"Unknown prefix: {prefix}") @@ -138,6 +149,9 @@ def get_model(self, model_name: str | None) -> Model: """ prefix, model_name = self._get_prefix_and_model_name(model_name) + # Defensive validation + if model_name is None or model_name == "": + raise UserError("Model name must be provided and non-empty.") if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)): return provider.get_model(model_name) else: