Skip to content

Commit 02c5da8

Browse files
authored
multi_provider: validation and error handling improvements
This PR adds defensive checks and clearer error messages to the multi provider implementation. ### Summary of changes 1. Validate model name in get_model and reject None or empty values. 2. Normalize empty model parts in _get_prefix_and_model_name so inputs like "openai/" are handled as missing model name. 3. Replace raw delete on provider mapping with a check and UserError to avoid KeyError. 4. Catch ImportError for optional litellm provider import and raise a helpful UserError that explains how to resolve the issue. 5. Validate set_mapping inputs and shallow copy the provided mapping to prevent external mutation.
1 parent 73e7843 commit 02c5da8

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

src/agents/models/multi_provider.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@ def get_mapping(self) -> dict[str, ModelProvider]:
2121
"""Returns a copy of the current prefix -> ModelProvider mapping."""
2222
return self._mapping.copy()
2323

24-
def set_mapping(self, mapping: dict[str, ModelProvider]):
25-
"""Overwrites the current mapping with a new one."""
26-
self._mapping = mapping
24+
def set_mapping(self, mapping: dict[str, ModelProvider]) -> None:
25+
"""Overwrites the current mapping with a new one after validation."""
26+
if not isinstance(mapping, dict):
27+
TypeError("Mapping must be a dict[str, ModelProvider].")
28+
for k, v in mapping.items():
29+
if not isinstance(k, str):
30+
raise TypeError(f"Mapping key '{k}' must be a string")
31+
if not isinstance(v, ModelProvider):
32+
raise TypeError(f"Mapping value for '{k}' must be a ModelProvider")
33+
self._mapping = mapping.copy()
2734

2835
def get_provider(self, prefix: str) -> ModelProvider | None:
2936
"""Returns the ModelProvider for the given prefix.
@@ -42,12 +49,10 @@ def add_provider(self, prefix: str, provider: ModelProvider):
4249
"""
4350
self._mapping[prefix] = provider
4451

45-
def remove_provider(self, prefix: str):
46-
"""Removes the mapping for the given prefix.
47-
48-
Args:
49-
prefix: The prefix of the model name e.g. "openai" or "my_prefix".
50-
"""
52+
def remove_provider(self, prefix: str) -> None:
53+
"""Removes the mapping for the given prefix."""
54+
if prefix not in self._mapping:
55+
raise UserError(f"No provider registered for prefix: {prefix}")
5156
del self._mapping[prefix]
5257

5358

@@ -100,18 +105,23 @@ def __init__(
100105
self._fallback_providers: dict[str, ModelProvider] = {}
101106

102107
def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
103-
if model_name is None:
108+
if not model_name:
104109
return None, None
105-
elif "/" in model_name:
106-
prefix, model_name = model_name.split("/", 1)
107-
return prefix, model_name
108-
else:
109-
return None, model_name
110+
if "/" in model_name:
111+
prefix, model_part = model_name.split("/", 1)
112+
# normalize empty model part to None so callers can validate consistently
113+
return prefix, model_part if model_part != "" else None
114+
return None, model_name
110115

111116
def _create_fallback_provider(self, prefix: str) -> ModelProvider:
112117
if prefix == "litellm":
113-
from ..extensions.models.litellm_provider import LitellmProvider
114-
118+
try:
119+
from ..extensions.models.litellm_provider import LitellmProvider
120+
except ImportError as e:
121+
raise UserError(
122+
"LitellmProvider requires the litellm extension. Install the optional dependency "
123+
"or add a custom provider mapping for the 'litellm' prefix."
124+
) from e
115125
return LitellmProvider()
116126
else:
117127
raise UserError(f"Unknown prefix: {prefix}")
@@ -138,6 +148,10 @@ def get_model(self, model_name: str | None) -> Model:
138148
"""
139149
prefix, model_name = self._get_prefix_and_model_name(model_name)
140150

151+
# Defensive validation
152+
if model_name is None or model_name == "":
153+
raise UserError("Model name must be provided and non-empty.")
154+
141155
if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
142156
return provider.get_model(model_name)
143157
else:

0 commit comments

Comments
 (0)