Skip to content

Commit f96849c

Browse files
authored
Fix instrumentation of FallbackModel (#1076)
1 parent bca218e commit f96849c

File tree

4 files changed

+149
-29
lines changed

4 files changed

+149
-29
lines changed

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,9 @@ async def request(
6161

6262
for model in self.models:
6363
try:
64-
return await model.request(messages, model_settings, model_request_parameters)
64+
response, usage = await model.request(messages, model_settings, model_request_parameters)
65+
response.model_used = model # type: ignore
66+
return response, usage
6567
except Exception as exc:
6668
if self._fallback_on(exc):
6769
exceptions.append(exc)

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 43 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,10 @@ def __init__(
8888
self.event_mode = event_mode
8989

9090

91+
GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
92+
GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
93+
94+
9195
@dataclass
9296
class InstrumentedModel(WrapperModel):
9397
"""Model which is instrumented with OpenTelemetry."""
@@ -138,27 +142,14 @@ def _instrument(
138142
model_settings: ModelSettings | None,
139143
) -> Iterator[Callable[[ModelResponse, Usage], None]]:
140144
operation = 'chat'
141-
model_name = self.model_name
142-
span_name = f'{operation} {model_name}'
143-
system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
144-
system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
145+
span_name = f'{operation} {self.model_name}'
145146
# TODO Missing attributes:
146147
# - error.type: unclear if we should do something here or just always rely on span exceptions
147148
# - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
148149
attributes: dict[str, AttributeValue] = {
149150
'gen_ai.operation.name': operation,
150-
'gen_ai.system': system,
151-
'gen_ai.request.model': model_name,
151+
**self.model_attributes(self.wrapped),
152152
}
153-
if base_url := self.wrapped.base_url:
154-
try:
155-
parsed = urlparse(base_url)
156-
if parsed.hostname:
157-
attributes['server.address'] = parsed.hostname
158-
if parsed.port:
159-
attributes['server.port'] = parsed.port
160-
except Exception: # pragma: no cover
161-
pass
162153

163154
if model_settings:
164155
for key in MODEL_SETTING_ATTRIBUTES:
@@ -183,21 +174,26 @@ def finish(response: ModelResponse, usage: Usage):
183174
},
184175
)
185176
)
186-
span.set_attributes(
187-
{
188-
# TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
189-
# https://github.com/pydantic/pydantic-ai/issues/886
190-
'gen_ai.response.model': response.model_name or model_name,
191-
**usage.opentelemetry_attributes(),
177+
new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes() # type: ignore
178+
if model_used := getattr(response, 'model_used', None):
179+
# FallbackModel sets model_used on the response so that we can report the attributes
180+
# of the model that was actually used.
181+
new_attributes.update(self.model_attributes(model_used))
182+
attributes.update(new_attributes)
183+
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
184+
new_attributes['gen_ai.response.model'] = response.model_name or request_model
185+
span.set_attributes(new_attributes)
186+
span.update_name(f'{operation} {request_model}')
187+
for event in events:
188+
event.attributes = {
189+
GEN_AI_SYSTEM_ATTRIBUTE: attributes[GEN_AI_SYSTEM_ATTRIBUTE],
190+
**(event.attributes or {}),
192191
}
193-
)
194-
self._emit_events(system, span, events)
192+
self._emit_events(span, events)
195193

196194
yield finish
197195

198-
def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
199-
for event in events:
200-
event.attributes = {'gen_ai.system': system, **(event.attributes or {})}
196+
def _emit_events(self, span: Span, events: list[Event]) -> None:
201197
if self.options.event_mode == 'logs':
202198
for event in events:
203199
self.options.event_logger.emit(event)
@@ -215,6 +211,27 @@ def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
215211
}
216212
)
217213

214+
@staticmethod
215+
def model_attributes(model: Model):
216+
system = getattr(model, 'system', '') or model.__class__.__name__.removesuffix('Model').lower()
217+
system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
218+
attributes: dict[str, AttributeValue] = {
219+
GEN_AI_SYSTEM_ATTRIBUTE: system,
220+
GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
221+
}
222+
if base_url := model.base_url:
223+
try:
224+
parsed = urlparse(base_url)
225+
except Exception: # pragma: no cover
226+
pass
227+
else:
228+
if parsed.hostname:
229+
attributes['server.address'] = parsed.hostname
230+
if parsed.port:
231+
attributes['server.port'] = parsed.port
232+
233+
return attributes
234+
218235
@staticmethod
219236
def event_to_dict(event: Event) -> dict[str, Any]:
220237
if not event.body:

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def opentelemetry_attributes(self) -> dict[str, int]:
6464
}
6565
for key, value in (self.details or {}).items():
6666
result[f'gen_ai.usage.details.{key}'] = value
67-
return {k: v for k, v in result.items() if v is not None}
67+
return {k: v for k, v in result.items() if v}
6868

6969

7070
@dataclass

tests/models/test_fallback.py

Lines changed: 102 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,29 @@
1+
from __future__ import annotations
2+
13
import sys
24
from collections.abc import AsyncIterator
35
from datetime import timezone
46

57
import pytest
8+
from dirty_equals import IsJson
69
from inline_snapshot import snapshot
710

811
from pydantic_ai import Agent, ModelHTTPError
912
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart
1013
from pydantic_ai.models.fallback import FallbackModel
1114
from pydantic_ai.models.function import AgentInfo, FunctionModel
1215

13-
from ..conftest import IsNow
16+
from ..conftest import IsNow, try_import
1417

1518
if sys.version_info < (3, 11):
1619
from exceptiongroup import ExceptionGroup as ExceptionGroup
1720
else:
1821
ExceptionGroup = ExceptionGroup
1922

23+
with try_import() as logfire_imports_successful:
24+
from logfire.testing import CaptureLogfire
25+
26+
2027
pytestmark = pytest.mark.anyio
2128

2229

@@ -86,6 +93,100 @@ def test_first_failed() -> None:
8693
)
8794

8895

96+
@pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed')
97+
def test_first_failed_instrumented(capfire: CaptureLogfire) -> None:
98+
fallback_model = FallbackModel(failure_model, success_model)
99+
agent = Agent(model=fallback_model, instrument=True)
100+
result = agent.run_sync('hello')
101+
assert result.data == snapshot('success')
102+
assert result.all_messages() == snapshot(
103+
[
104+
ModelRequest(
105+
parts=[
106+
UserPromptPart(
107+
content='hello',
108+
timestamp=IsNow(tz=timezone.utc),
109+
)
110+
]
111+
),
112+
ModelResponse(
113+
parts=[TextPart(content='success')],
114+
model_name='function:success_response:',
115+
timestamp=IsNow(tz=timezone.utc),
116+
),
117+
]
118+
)
119+
assert capfire.exporter.exported_spans_as_dict() == snapshot(
120+
[
121+
{
122+
'name': 'preparing model request params',
123+
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
124+
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
125+
'start_time': 2000000000,
126+
'end_time': 3000000000,
127+
'attributes': {
128+
'run_step': 1,
129+
'logfire.span_type': 'span',
130+
'logfire.msg': 'preparing model request params',
131+
},
132+
},
133+
{
134+
'name': 'chat function:success_response:',
135+
'context': {'trace_id': 1, 'span_id': 5, 'is_remote': False},
136+
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
137+
'start_time': 4000000000,
138+
'end_time': 5000000000,
139+
'attributes': {
140+
'gen_ai.operation.name': 'chat',
141+
'logfire.span_type': 'span',
142+
'logfire.msg': 'chat FallBackModel[function:failure_response:, function:success_response:]',
143+
'gen_ai.usage.input_tokens': 51,
144+
'gen_ai.usage.output_tokens': 1,
145+
'gen_ai.system': 'function',
146+
'gen_ai.request.model': 'function:success_response:',
147+
'gen_ai.response.model': 'function:success_response:',
148+
'events': IsJson(
149+
[
150+
{
151+
'content': 'hello',
152+
'role': 'user',
153+
'gen_ai.system': 'function',
154+
'gen_ai.message.index': 0,
155+
'event.name': 'gen_ai.user.message',
156+
},
157+
{
158+
'index': 0,
159+
'message': {'role': 'assistant', 'content': 'success'},
160+
'gen_ai.system': 'function',
161+
'event.name': 'gen_ai.choice',
162+
},
163+
]
164+
),
165+
'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}}}',
166+
},
167+
},
168+
{
169+
'name': 'agent run',
170+
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
171+
'parent': None,
172+
'start_time': 1000000000,
173+
'end_time': 6000000000,
174+
'attributes': {
175+
'model_name': 'FallBackModel[function:failure_response:, function:success_response:]',
176+
'agent_name': 'agent',
177+
'logfire.msg': 'agent run',
178+
'logfire.span_type': 'span',
179+
'gen_ai.usage.input_tokens': 51,
180+
'gen_ai.usage.output_tokens': 1,
181+
'all_messages_events': '[{"content": "hello", "role": "user", "gen_ai.message.index": 0, "event.name": "gen_ai.user.message"}, {"role": "assistant", "content": "success", "gen_ai.message.index": 1, "event.name": "gen_ai.assistant.message"}]',
182+
'final_result': 'success',
183+
'logfire.json_schema': '{"type": "object", "properties": {"all_messages_events": {"type": "array"}, "final_result": {"type": "object"}}}',
184+
},
185+
},
186+
]
187+
)
188+
189+
89190
def test_all_failed() -> None:
90191
fallback_model = FallbackModel(failure_model, failure_model)
91192
agent = Agent(model=fallback_model)

0 commit comments

Comments
 (0)