Commit b63a418

Replace model request span with InstrumentedModel (#1012)
1 parent 15c5ef2 commit b63a418

8 files changed: +165 −108 lines changed


pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 16 additions & 22 deletions
@@ -286,21 +286,18 @@ async def _stream(
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                self._did_stream = True
-                ctx.state.usage.incr(_usage.Usage(), requests=1)
-                yield streamed_response
-                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-                # otherwise usage won't be properly counted:
-                async for _ in streamed_response:
-                    pass
-            model_response = streamed_response.get()
-            request_usage = streamed_response.usage()
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
+        async with ctx.deps.model.request_stream(
+            ctx.state.message_history, model_settings, model_request_parameters
+        ) as streamed_response:
+            self._did_stream = True
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            yield streamed_response
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in streamed_response:
+                pass
+        model_response = streamed_response.get()
+        request_usage = streamed_response.usage()
 
         self._finish_handling(ctx, model_response, request_usage)
         assert self._result is not None  # this should be set by the previous line
@@ -312,13 +309,10 @@ async def _make_request(
             return self._result
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
-        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-            model_response, request_usage = await ctx.deps.model.request(
-                ctx.state.message_history, model_settings, model_request_parameters
-            )
-            ctx.state.usage.incr(_usage.Usage(), requests=1)
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
+        model_response, request_usage = await ctx.deps.model.request(
+            ctx.state.message_history, model_settings, model_request_parameters
+        )
+        ctx.state.usage.incr(_usage.Usage(), requests=1)
 
         return self._finish_handling(ctx, model_response, request_usage)
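
Note: after this change the agent graph no longer opens its own 'model request' span or sets the response and usage attributes itself; that now happens inside the model wrapper. Below is a minimal sketch of the wrapper pattern, not the real InstrumentedModel: only the OpenTelemetry tracer API, the request() signature, and usage.opentelemetry_attributes() from this diff are taken as given, and SpanWrappedModel is a hypothetical name.

# Hypothetical wrapper sketch: delegate to an inner model and record a span per request.
from opentelemetry.trace import get_tracer

tracer = get_tracer(__name__)


class SpanWrappedModel:
    def __init__(self, wrapped):
        self.wrapped = wrapped  # the underlying model being instrumented

    async def request(self, messages, model_settings, model_request_parameters):
        with tracer.start_as_current_span('model request') as span:
            response, usage = await self.wrapped.request(messages, model_settings, model_request_parameters)
            # usage.opentelemetry_attributes() is the same helper the diff uses for span attributes
            span.set_attributes(usage.opentelemetry_attributes())
            return response, usage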

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,7 @@
     result,
     usage as _usage,
 )
+from .models.instrumented import InstrumentedModel
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -1115,6 +1116,9 @@ def _get_model(self, model: models.Model | models.KnownModelName | None) -> mode
         else:
            raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
+        if not isinstance(model_, InstrumentedModel):
+            model_ = InstrumentedModel(model_)
+
         return model_
 
     def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
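
With this hunk, every model resolved by _get_model is guaranteed to be an InstrumentedModel, so the spans removed from _agent_graph.py above come from the wrapper instead. The wrap-once idiom in isolation (ensure_instrumented is an illustrative helper name; the isinstance guard is what prevents double wrapping):

from pydantic_ai.models.instrumented import InstrumentedModel


def ensure_instrumented(model):
    # Wrap exactly once: an already-instrumented model is returned unchanged.
    if not isinstance(model, InstrumentedModel):
        model = InstrumentedModel(model)
    return model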

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 53 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 import pydantic
 import pydantic_core
+from opentelemetry._events import Event
 from typing_extensions import TypeAlias
 
 from ._utils import now_utc as _now_utc
@@ -33,6 +34,9 @@ class SystemPromptPart:
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
+    def otel_event(self) -> Event:
+        return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
+
 
 @dataclass
 class AudioUrl:
@@ -138,6 +142,14 @@ class UserPromptPart:
     part_kind: Literal['user-prompt'] = 'user-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
+    def otel_event(self) -> Event:
+        if isinstance(self.content, str):
+            content = self.content
+        else:
+            # TODO figure out what to record for images and audio
+            content = [part if isinstance(part, str) else {'kind': part.kind} for part in self.content]
+        return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
+
 
 tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
 
@@ -176,6 +188,9 @@ def model_response_object(self) -> dict[str, Any]:
         else:
             return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
+    def otel_event(self) -> Event:
+        return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
+
 
 error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
 
@@ -224,6 +239,14 @@ def model_response(self) -> str:
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
+    def otel_event(self) -> Event:
+        if self.tool_name is None:
+            return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
+        else:
+            return Event(
+                'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
+            )
+
 
 ModelRequestPart = Annotated[
     Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
@@ -329,6 +352,36 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""
 
+    def otel_events(self) -> list[Event]:
+        """Return OpenTelemetry events for the response."""
+        result: list[Event] = []
+
+        def new_event_body():
+            new_body: dict[str, Any] = {'role': 'assistant'}
+            ev = Event('gen_ai.assistant.message', body=new_body)
+            result.append(ev)
+            return new_body
+
+        body = new_event_body()
+        for part in self.parts:
+            if isinstance(part, ToolCallPart):
+                body.setdefault('tool_calls', []).append(
+                    {
+                        'id': part.tool_call_id,
+                        'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
+                        'function': {
+                            'name': part.tool_name,
+                            'arguments': part.args,
+                        },
+                    }
+                )
+            elif isinstance(part, TextPart):
+                if body.get('content'):
+                    body = new_event_body()
+                body['content'] = part.content
+
+        return result
+
 
 ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
 """Any message sent to or returned by a model."""

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 57 additions & 84 deletions
@@ -1,28 +1,21 @@
 from __future__ import annotations
 
 import json
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Callable, Literal
 
 import logfire_api
 from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
-from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
+from pydantic import TypeAdapter
 
 from ..messages import (
     ModelMessage,
     ModelRequest,
-    ModelRequestPart,
     ModelResponse,
-    RetryPromptPart,
-    SystemPromptPart,
-    TextPart,
-    ToolCallPart,
-    ToolReturnPart,
-    UserPromptPart,
 )
 from ..settings import ModelSettings
 from ..usage import Usage
@@ -48,6 +41,8 @@
     'frequency_penalty',
 )
 
+ANY_ADAPTER = TypeAdapter[Any](Any)
+
 
 @dataclass
 class InstrumentedModel(WrapperModel):
@@ -115,7 +110,7 @@ async def request_stream(
                     finish(response_stream.get(), response_stream.usage())
 
     @contextmanager
-    def _instrument(  # noqa: C901
+    def _instrument(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
@@ -141,35 +136,24 @@ def _instrument(  # noqa: C901
                 if isinstance(value := model_settings.get(key), (float, int)):
                     attributes[f'gen_ai.request.{key}'] = value
 
-        events_list = []
-        emit_event = partial(self._emit_event, system, events_list)
-
         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
-            if span.is_recording():
-                for message in messages:
-                    if isinstance(message, ModelRequest):
-                        for part in message.parts:
-                            event_name, body = _request_part_body(part)
-                            if event_name:
-                                emit_event(event_name, body)
-                    elif isinstance(message, ModelResponse):
-                        for body in _response_bodies(message):
-                            emit_event('gen_ai.assistant.message', body)
 
             def finish(response: ModelResponse, usage: Usage):
                 if not span.is_recording():
                     return
 
-                for response_body in _response_bodies(response):
-                    if response_body:
-                        emit_event(
+                events = self.messages_to_otel_events(messages)
+                for event in self.messages_to_otel_events([response]):
+                    events.append(
+                        Event(
                             'gen_ai.choice',
-                            {
+                            body={
                                 # TODO finish_reason
                                 'index': 0,
-                                'message': response_body,
+                                'message': event.body,
                             },
                         )
+                    )
                 span.set_attributes(
                     {
                         # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
@@ -178,67 +162,56 @@ def finish(response: ModelResponse, usage: Usage):
                         **usage.opentelemetry_attributes(),
                     }
                 )
-                if events_list:
-                    attr_name = 'events'
-                    span.set_attributes(
-                        {
-                            attr_name: json.dumps(events_list),
-                            'logfire.json_schema': json.dumps(
-                                {
-                                    'type': 'object',
-                                    'properties': {attr_name: {'type': 'array'}},
-                                }
-                            ),
-                        }
-                    )
+                self._emit_events(system, span, events)
 
             yield finish
 
-    def _emit_event(
-        self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
-    ) -> None:
-        attributes = {'gen_ai.system': system}
+    def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
+        for event in events:
+            event.attributes = {'gen_ai.system': system, **(event.attributes or {})}
         if self.event_mode == 'logs':
-            self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
-        else:
-            events_list.append({'event.name': event_name, **body, **attributes})
-
-
-def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
-    if isinstance(part, SystemPromptPart):
-        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
-    elif isinstance(part, UserPromptPart):
-        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
-    elif isinstance(part, ToolReturnPart):
-        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
-    elif isinstance(part, RetryPromptPart):
-        if part.tool_name is None:
-            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
+            for event in events:
+                self.event_logger.emit(event)
         else:
-            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
-    else:
-        return '', {}
-
-
-def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
-    body: dict[str, Any] = {'role': 'assistant'}
-    result = [body]
-    for part in message.parts:
-        if isinstance(part, ToolCallPart):
-            body.setdefault('tool_calls', []).append(
+            attr_name = 'events'
+            span.set_attributes(
                 {
-                    'id': part.tool_call_id,
-                    'type': 'function',  # TODO https://github.com/pydantic/pydantic-ai/issues/888
-                    'function': {
-                        'name': part.tool_name,
-                        'arguments': part.args,
-                    },
+                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
+                    'logfire.json_schema': json.dumps(
+                        {
+                            'type': 'object',
+                            'properties': {attr_name: {'type': 'array'}},
+                        }
+                    ),
                 }
             )
-        elif isinstance(part, TextPart):
-            if body.get('content'):
-                body = {'role': 'assistant'}
-                result.append(body)
-            body['content'] = part.content
 
-    return result
+    @staticmethod
+    def event_to_dict(event: Event) -> dict[str, Any]:
+        if not event.body:
+            body = {}
+        elif isinstance(event.body, Mapping):
+            body = event.body  # type: ignore
+        else:
+            body = {'body': event.body}
+        return {**body, **(event.attributes or {})}
+
+    @staticmethod
+    def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
+        result: list[Event] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if hasattr(part, 'otel_event'):
+                        result.append(part.otel_event())
+            elif isinstance(message, ModelResponse):
+                result.extend(message.otel_events())
+        for event in result:
+            try:
+                event.body = ANY_ADAPTER.dump_python(event.body, mode='json')
+            except Exception:
+                try:
+                    event.body = str(event.body)
+                except Exception:
+                    event.body = 'Unable to serialize event body'
+        return result
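
InstrumentedModel now delegates event construction to the message classes via messages_to_otel_events, stamps gen_ai.system onto each event, and then either emits the events through the OpenTelemetry event logger (event_mode='logs') or serializes them into a single 'events' span attribute. A standalone sketch of that attribute-mode packing, mirroring event_to_dict above (events_to_span_attribute is an illustrative helper, not part of the library):

import json
from collections.abc import Mapping

from opentelemetry._events import Event


def events_to_span_attribute(events: list[Event]) -> str:
    def to_dict(event: Event) -> dict:
        # Flatten the event body and attributes into one JSON-friendly dict, as event_to_dict does.
        if not event.body:
            body = {}
        elif isinstance(event.body, Mapping):
            body = dict(event.body)
        else:
            body = {'body': event.body}
        return {**body, **(event.attributes or {})}

    return json.dumps([to_dict(e) for e in events])


attr_value = events_to_span_attribute([Event('gen_ai.user.message', body={'content': 'hi', 'role': 'user'})])
# attr_value can then be stored as the 'events' attribute on the request span.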

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ dependencies = [
     "pydantic>=2.10",
     "pydantic-graph==0.0.30",
     "exceptiongroup; python_version < '3.11'",
+    "opentelemetry-api>=1.28.0",
 ]
 
 [project.optional-dependencies]
