1
1
from __future__ import annotations
2
2
3
- from collections .abc import AsyncIterator
3
+ from collections .abc import AsyncIterator , Iterator
4
4
from contextlib import asynccontextmanager , contextmanager
5
- from dataclasses import dataclass
5
+ from dataclasses import dataclass , field
6
6
from functools import partial
7
- from typing import Any , Literal
7
+ from typing import Any , Callable , Literal
8
8
9
9
import logfire_api
10
+ from opentelemetry ._events import Event , EventLogger , EventLoggerProvider , get_event_logger_provider
11
+ from opentelemetry .trace import Tracer , TracerProvider , get_tracer_provider
10
12
11
13
from ..messages import (
12
14
ModelMessage ,
22
24
)
23
25
from ..settings import ModelSettings
24
26
from ..usage import Usage
25
- from . import ModelRequestParameters , StreamedResponse
27
+ from . import KnownModelName , Model , ModelRequestParameters , StreamedResponse
26
28
from .wrapper import WrapperModel
27
29
28
30
MODEL_SETTING_ATTRIBUTES : tuple [
51
53
class InstrumentedModel (WrapperModel ):
52
54
"""Model which is instrumented with logfire."""
53
55
54
- logfire_instance : logfire_api .Logfire = logfire_api .DEFAULT_LOGFIRE_INSTANCE
56
+ tracer : Tracer = field (repr = False )
57
+ event_logger : EventLogger = field (repr = False )
55
58
56
- def __post_init__ (self ):
57
- self .logfire_instance = self .logfire_instance .with_settings (custom_scope_suffix = 'pydantic_ai' )
59
+ def __init__ (
60
+ self ,
61
+ wrapped : Model | KnownModelName ,
62
+ tracer_provider : TracerProvider | None = None ,
63
+ event_logger_provider : EventLoggerProvider | None = None ,
64
+ ):
65
+ super ().__init__ (wrapped )
66
+ tracer_provider = tracer_provider or get_tracer_provider ()
67
+ event_logger_provider = event_logger_provider or get_event_logger_provider ()
68
+ self .tracer = tracer_provider .get_tracer ('pydantic-ai' )
69
+ self .event_logger = event_logger_provider .get_event_logger ('pydantic-ai' )
70
+
71
+ @classmethod
72
+ def from_logfire (
73
+ cls ,
74
+ wrapped : Model | KnownModelName ,
75
+ logfire_instance : logfire_api .Logfire = logfire_api .DEFAULT_LOGFIRE_INSTANCE ,
76
+ ) -> InstrumentedModel :
77
+ if hasattr (logfire_instance .config , 'get_event_logger_provider' ):
78
+ event_provider = logfire_instance .config .get_event_logger_provider ()
79
+ else :
80
+ event_provider = None
81
+ tracer_provider = logfire_instance .config .get_tracer_provider ()
82
+ return cls (wrapped , tracer_provider , event_provider )
58
83
59
84
async def request (
60
85
self ,
@@ -90,7 +115,7 @@ def _instrument(
90
115
self ,
91
116
messages : list [ModelMessage ],
92
117
model_settings : ModelSettings | None ,
93
- ):
118
+ ) -> Iterator [ Callable [[ ModelResponse , Usage ], None ]] :
94
119
operation = 'chat'
95
120
model_name = self .model_name
96
121
span_name = f'{ operation } { model_name } '
@@ -114,7 +139,7 @@ def _instrument(
114
139
115
140
emit_event = partial (self ._emit_event , system )
116
141
117
- with self .logfire_instance . span (span_name , ** attributes ) as span :
142
+ with self .tracer . start_as_current_span (span_name , attributes = attributes ) as span :
118
143
if span .is_recording ():
119
144
for message in messages :
120
145
if isinstance (message , ModelRequest ):
@@ -157,27 +182,27 @@ def finish(response: ModelResponse, usage: Usage):
157
182
yield finish
158
183
159
184
def _emit_event (self , system : str , event_name : str , body : dict [str , Any ]) -> None :
160
- self .logfire_instance . info ( event_name , ** {'gen_ai.system' : system }, ** body )
185
+ self .event_logger . emit ( Event ( event_name , body = body , attributes = {'gen_ai.system' : system }) )
161
186
162
187
163
188
def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
    """Map a model-request part to a GenAI event name and event body.

    Args:
        part: One part of a ``ModelRequest``.

    Returns:
        A ``(event_name, body)`` tuple; ``('', {})`` for unrecognized part types.
    """
    if isinstance(part, SystemPromptPart):
        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
    elif isinstance(part, UserPromptPart):
        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
    elif isinstance(part, ToolReturnPart):
        # Fix: the key must be 'id' — the original had a stray leading space (' id').
        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
    elif isinstance(part, RetryPromptPart):
        # A retry with no tool name came from plain user input; otherwise it is a tool response.
        if part.tool_name is None:
            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
        else:
            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
    else:
        return '', {}
177
202
178
203
179
204
def _response_bodies (message : ModelResponse ) -> list [dict [str , Any ]]:
180
- body : dict [str , Any ] = {}
205
+ body : dict [str , Any ] = {'role' : 'assistant' }
181
206
result = [body ]
182
207
for part in message .parts :
183
208
if isinstance (part , ToolCallPart ):
@@ -193,7 +218,7 @@ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
193
218
)
194
219
elif isinstance (part , TextPart ):
195
220
if body .get ('content' ):
196
- body = {}
221
+ body = {'role' : 'assistant' }
197
222
result .append (body )
198
223
body ['content' ] = part .content
199
224
0 commit comments