Skip to content

Commit 42dac78

Browse files
authored
InstrumentedModel and FallbackModel fixes (#1121)
1 parent 416a0d1 commit 42dac78

File tree

18 files changed

+57
-62
lines changed

18 files changed

+57
-62
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,14 @@ def model_name(self) -> str:
262262

263263
@property
264264
@abstractmethod
265-
def system(self) -> str | None:
266-
"""The system / model provider, ex: openai."""
265+
def system(self) -> str:
266+
"""The system / model provider, ex: openai.
267+
268+
Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
269+
so should use well-known values listed in
270+
https://opentelemetry.io/docs/specs/semconv/attributes-registry/gen-ai/#gen-ai-system
271+
when applicable.
272+
"""
267273
raise NotImplementedError()
268274

269275
@property

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class AnthropicModel(Model):
115115
client: AsyncAnthropic = field(repr=False)
116116

117117
_model_name: AnthropicModelName = field(repr=False)
118-
_system: str | None = field(default='anthropic', repr=False)
118+
_system: str = field(default='anthropic', repr=False)
119119

120120
def __init__(
121121
self,
@@ -183,7 +183,7 @@ def model_name(self) -> AnthropicModelName:
183183
return self._model_name
184184

185185
@property
186-
def system(self) -> str | None:
186+
def system(self) -> str:
187187
"""The system / model provider."""
188188
return self._system
189189

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,15 @@ class BedrockConverseModel(Model):
119119
client: BedrockRuntimeClient
120120

121121
_model_name: BedrockModelName = field(repr=False)
122-
_system: str | None = field(default='bedrock', repr=False)
122+
_system: str = field(default='bedrock', repr=False)
123123

124124
@property
125125
def model_name(self) -> str:
126126
"""The model name."""
127127
return self._model_name
128128

129129
@property
130-
def system(self) -> str | None:
130+
def system(self) -> str:
131131
"""The system / model provider, ex: openai."""
132132
return self._system
133133

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class CohereModel(Model):
9898
client: AsyncClientV2 = field(repr=False)
9999

100100
_model_name: CohereModelName = field(repr=False)
101-
_system: str | None = field(default='cohere', repr=False)
101+
_system: str = field(default='cohere', repr=False)
102102

103103
def __init__(
104104
self,
@@ -148,7 +148,7 @@ def model_name(self) -> CohereModelName:
148148
return self._model_name
149149

150150
@property
151-
def system(self) -> str | None:
151+
def system(self) -> str:
152152
"""The system / model provider."""
153153
return self._system
154154

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import AsyncIterator
4-
from contextlib import AsyncExitStack, asynccontextmanager
4+
from contextlib import AsyncExitStack, asynccontextmanager, suppress
55
from dataclasses import dataclass, field
66
from typing import TYPE_CHECKING, Callable
77

8+
from opentelemetry.trace import get_current_span
9+
10+
from pydantic_ai.models.instrumented import InstrumentedModel
11+
812
from ..exceptions import FallbackExceptionGroup, ModelHTTPError
913
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
1014

@@ -40,7 +44,6 @@ def __init__(
4044
fallback_on: A callable or tuple of exceptions that should trigger a fallback.
4145
"""
4246
self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
43-
self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'
4447

4548
if isinstance(fallback_on, tuple):
4649
self._fallback_on = _default_fallback_condition_factory(fallback_on)
@@ -62,13 +65,19 @@ async def request(
6265
for model in self.models:
6366
try:
6467
response, usage = await model.request(messages, model_settings, model_request_parameters)
65-
response.model_used = model # type: ignore
66-
return response, usage
6768
except Exception as exc:
6869
if self._fallback_on(exc):
6970
exceptions.append(exc)
7071
continue
7172
raise exc
73+
else:
74+
with suppress(Exception):
75+
span = get_current_span()
76+
if span.is_recording():
77+
attributes = getattr(span, 'attributes', {})
78+
if attributes.get('gen_ai.request.model') == self.model_name:
79+
span.set_attributes(InstrumentedModel.model_attributes(model))
80+
return response, usage
7281

7382
raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
7483

@@ -101,12 +110,11 @@ async def request_stream(
101110
@property
102111
def model_name(self) -> str:
103112
"""The model name."""
104-
return self._model_name
113+
return f'fallback:{",".join(model.model_name for model in self.models)}'
105114

106115
@property
107-
def system(self) -> str | None:
108-
"""The system / model provider, n/a for fallback models."""
109-
return None
116+
def system(self) -> str:
117+
return f'fallback:{",".join(model.system for model in self.models)}'
110118

111119
@property
112120
def base_url(self) -> str | None:

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class FunctionModel(Model):
4545
stream_function: StreamFunctionDef | None = None
4646

4747
_model_name: str = field(repr=False)
48-
_system: str | None = field(default=None, repr=False)
48+
_system: str = field(default='function', repr=False)
4949

5050
@overload
5151
def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
@@ -140,7 +140,7 @@ def model_name(self) -> str:
140140
return self._model_name
141141

142142
@property
143-
def system(self) -> str | None:
143+
def system(self) -> str:
144144
"""The system / model provider."""
145145
return self._system
146146

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class GeminiModel(Model):
9191
_provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
9292
_auth: AuthProtocol | None = field(repr=False)
9393
_url: str | None = field(repr=False)
94-
_system: str | None = field(default='google-gla', repr=False)
94+
_system: str = field(default='gemini', repr=False)
9595

9696
@overload
9797
def __init__(
@@ -197,7 +197,7 @@ def model_name(self) -> GeminiModelName:
197197
return self._model_name
198198

199199
@property
200-
def system(self) -> str | None:
200+
def system(self) -> str:
201201
"""The system / model provider."""
202202
return self._system
203203

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class GroqModel(Model):
8888
client: AsyncGroq = field(repr=False)
8989

9090
_model_name: GroqModelName = field(repr=False)
91-
_system: str | None = field(default='groq', repr=False)
91+
_system: str = field(default='groq', repr=False)
9292

9393
@overload
9494
def __init__(
@@ -186,7 +186,7 @@ def model_name(self) -> GroqModelName:
186186
return self._model_name
187187

188188
@property
189-
def system(self) -> str | None:
189+
def system(self) -> str:
190190
"""The system / model provider."""
191191
return self._system
192192

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,7 @@ def finish(response: ModelResponse, usage: Usage):
175175
)
176176
)
177177
new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes() # type: ignore
178-
if model_used := getattr(response, 'model_used', None):
179-
# FallbackModel sets model_used on the response so that we can report the attributes
180-
# of the model that was actually used.
181-
new_attributes.update(self.model_attributes(model_used))
182-
attributes.update(new_attributes)
178+
attributes.update(getattr(span, 'attributes', {}))
183179
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
184180
new_attributes['gen_ai.response.model'] = response.model_name or request_model
185181
span.set_attributes(new_attributes)
@@ -213,10 +209,8 @@ def _emit_events(self, span: Span, events: list[Event]) -> None:
213209

214210
@staticmethod
215211
def model_attributes(model: Model):
216-
system = getattr(model, 'system', '') or model.__class__.__name__.removesuffix('Model').lower()
217-
system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
218212
attributes: dict[str, AttributeValue] = {
219-
GEN_AI_SYSTEM_ATTRIBUTE: system,
213+
GEN_AI_SYSTEM_ATTRIBUTE: model.system,
220214
GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
221215
}
222216
if base_url := model.base_url:

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class MistralModel(Model):
110110
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
111111

112112
_model_name: MistralModelName = field(repr=False)
113-
_system: str | None = field(default='mistral', repr=False)
113+
_system: str = field(default='mistral_ai', repr=False)
114114

115115
def __init__(
116116
self,
@@ -179,7 +179,7 @@ def model_name(self) -> MistralModelName:
179179
return self._model_name
180180

181181
@property
182-
def system(self) -> str | None:
182+
def system(self) -> str:
183183
"""The system / model provider."""
184184
return self._system
185185

0 commit comments

Comments
 (0)