
Commit 4a472bd

Use raw OTel and actual event loggers in InstrumentedModel (#945)
1 parent dfc919c commit 4a472bd

4 files changed: +291 -322 lines changed

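In short, this commit swaps the Logfire-specific span/log helpers for the raw OpenTelemetry tracer and event-logger APIs. For orientation only (not code from the commit; the scope name, span name, and attribute values are illustrative), the OTel calls involved look roughly like:

from opentelemetry._events import Event, get_event_logger_provider
from opentelemetry.trace import get_tracer_provider

# Tracer and event logger come from (global or explicitly passed) OTel providers.
tracer = get_tracer_provider().get_tracer('pydantic-ai')
event_logger = get_event_logger_provider().get_event_logger('pydantic-ai')

# One span per model request; message events are emitted while it is current.
with tracer.start_as_current_span('chat gpt-4o', attributes={'gen_ai.system': 'openai'}):
    event_logger.emit(
        Event(
            'gen_ai.user.message',
            body={'content': 'Hello', 'role': 'user'},
            attributes={'gen_ai.system': 'openai'},
        )
    )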

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 42 additions & 17 deletions
@@ -1,12 +1,14 @@
 from __future__ import annotations

-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Literal
+from typing import Any, Callable, Literal

 import logfire_api
+from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider

 from ..messages import (
     ModelMessage,
@@ -22,7 +24,7 @@
 )
 from ..settings import ModelSettings
 from ..usage import Usage
-from . import ModelRequestParameters, StreamedResponse
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel

 MODEL_SETTING_ATTRIBUTES: tuple[
@@ -51,10 +53,33 @@
 class InstrumentedModel(WrapperModel):
     """Model which is instrumented with logfire."""

-    logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE
+    tracer: Tracer = field(repr=False)
+    event_logger: EventLogger = field(repr=False)

-    def __post_init__(self):
-        self.logfire_instance = self.logfire_instance.with_settings(custom_scope_suffix='pydantic_ai')
+    def __init__(
+        self,
+        wrapped: Model | KnownModelName,
+        tracer_provider: TracerProvider | None = None,
+        event_logger_provider: EventLoggerProvider | None = None,
+    ):
+        super().__init__(wrapped)
+        tracer_provider = tracer_provider or get_tracer_provider()
+        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        self.tracer = tracer_provider.get_tracer('pydantic-ai')
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+    @classmethod
+    def from_logfire(
+        cls,
+        wrapped: Model | KnownModelName,
+        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+    ) -> InstrumentedModel:
+        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+            event_provider = logfire_instance.config.get_event_logger_provider()
+        else:
+            event_provider = None
+        tracer_provider = logfire_instance.config.get_tracer_provider()
+        return cls(wrapped, tracer_provider, event_provider)

     async def request(
         self,
@@ -90,7 +115,7 @@ def _instrument(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
-    ):
+    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
         operation = 'chat'
         model_name = self.model_name
         span_name = f'{operation} {model_name}'
@@ -114,7 +139,7 @@ def _instrument(

         emit_event = partial(self._emit_event, system)

-        with self.logfire_instance.span(span_name, **attributes) as span:
+        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
             if span.is_recording():
                 for message in messages:
                     if isinstance(message, ModelRequest):
@@ -157,27 +182,27 @@ def finish(response: ModelResponse, usage: Usage):
         yield finish

     def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
-        self.logfire_instance.info(event_name, **{'gen_ai.system': system}, **body)
+        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))


 def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
     if isinstance(part, SystemPromptPart):
-        return 'gen_ai.system.message', {'content': part.content}
+        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
     elif isinstance(part, UserPromptPart):
-        return 'gen_ai.user.message', {'content': part.content}
+        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
     elif isinstance(part, ToolReturnPart):
-        return 'gen_ai.tool.message', {'content': part.content, 'id': part.tool_call_id}
+        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
     elif isinstance(part, RetryPromptPart):
         if part.tool_name is None:
-            return 'gen_ai.user.message', {'content': part.model_response()}
+            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
         else:
-            return 'gen_ai.tool.message', {'content': part.model_response(), 'id': part.tool_call_id}
+            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
     else:
         return '', {}


 def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
-    body: dict[str, Any] = {}
+    body: dict[str, Any] = {'role': 'assistant'}
     result = [body]
     for part in message.parts:
         if isinstance(part, ToolCallPart):
@@ -193,7 +218,7 @@ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
             )
         elif isinstance(part, TextPart):
             if body.get('content'):
-                body = {}
+                body = {'role': 'assistant'}
                 result.append(body)
             body['content'] = part.content

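For orientation only (not part of the commit): a minimal sketch of how the new constructor and the from_logfire classmethod might be used, assuming 'openai:gpt-4o' as an illustrative KnownModelName and that the relevant provider extra and credentials are configured:

from opentelemetry._events import get_event_logger_provider
from opentelemetry.trace import get_tracer_provider

from pydantic_ai.models.instrumented import InstrumentedModel

# With no providers passed, the globally registered OTel providers are used.
model = InstrumentedModel('openai:gpt-4o')

# Providers can also be passed explicitly...
model = InstrumentedModel(
    'openai:gpt-4o',
    tracer_provider=get_tracer_provider(),
    event_logger_provider=get_event_logger_provider(),
)

# ...or derived from the default Logfire instance via the new classmethod.
model = InstrumentedModel.from_logfire('openai:gpt-4o')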

pydantic_ai_slim/pydantic_ai/models/wrapper.py

Lines changed: 5 additions & 2 deletions
@@ -8,15 +8,18 @@
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from ..usage import Usage
-from . import Model, ModelRequestParameters, StreamedResponse
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model


-@dataclass
+@dataclass(init=False)
 class WrapperModel(Model):
     """Model which wraps another model."""

     wrapped: Model

+    def __init__(self, wrapped: Model | KnownModelName):
+        self.wrapped = infer_model(wrapped)
+
     async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
         return await self.wrapped.request(*args, **kwargs)

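With the new WrapperModel.__init__, any subclass can be built from either a Model instance or a KnownModelName string, which infer_model resolves to a concrete model. A sketch for orientation only (the subclass is illustrative; 'test' is assumed to be a valid KnownModelName):

from pydantic_ai.models import infer_model
from pydantic_ai.models.wrapper import WrapperModel


class LoggingModel(WrapperModel):
    """Illustrative subclass; request delegation comes from WrapperModel."""

    async def request(self, *args, **kwargs):
        print(f'requesting via {type(self.wrapped).__name__}')
        return await super().request(*args, **kwargs)


# Equivalent thanks to the new __init__ calling infer_model() on strings:
model = LoggingModel('test')
model = LoggingModel(infer_model('test'))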
