1
1
from __future__ import annotations
2
2
3
3
import json
4
- from collections .abc import AsyncIterator , Iterator
4
+ from collections .abc import AsyncIterator , Iterator , Mapping
5
5
from contextlib import asynccontextmanager , contextmanager
6
6
from dataclasses import dataclass , field
7
- from functools import partial
8
7
from typing import Any , Callable , Literal
9
8
10
9
import logfire_api
11
10
from opentelemetry ._events import Event , EventLogger , EventLoggerProvider , get_event_logger_provider
12
- from opentelemetry .trace import Tracer , TracerProvider , get_tracer_provider
11
+ from opentelemetry .trace import Span , Tracer , TracerProvider , get_tracer_provider
13
12
from opentelemetry .util .types import AttributeValue
13
+ from pydantic import TypeAdapter
14
14
15
15
from ..messages import (
16
16
ModelMessage ,
17
17
ModelRequest ,
18
- ModelRequestPart ,
19
18
ModelResponse ,
20
- RetryPromptPart ,
21
- SystemPromptPart ,
22
- TextPart ,
23
- ToolCallPart ,
24
- ToolReturnPart ,
25
- UserPromptPart ,
26
19
)
27
20
from ..settings import ModelSettings
28
21
from ..usage import Usage
48
41
'frequency_penalty' ,
49
42
)
50
43
44
+ ANY_ADAPTER = TypeAdapter [Any ](Any )
45
+
51
46
52
47
@dataclass
53
48
class InstrumentedModel (WrapperModel ):
@@ -115,7 +110,7 @@ async def request_stream(
115
110
finish (response_stream .get (), response_stream .usage ())
116
111
117
112
@contextmanager
118
- def _instrument ( # noqa: C901
113
+ def _instrument (
119
114
self ,
120
115
messages : list [ModelMessage ],
121
116
model_settings : ModelSettings | None ,
@@ -141,35 +136,24 @@ def _instrument( # noqa: C901
141
136
if isinstance (value := model_settings .get (key ), (float , int )):
142
137
attributes [f'gen_ai.request.{ key } ' ] = value
143
138
144
- events_list = []
145
- emit_event = partial (self ._emit_event , system , events_list )
146
-
147
139
with self .tracer .start_as_current_span (span_name , attributes = attributes ) as span :
148
- if span .is_recording ():
149
- for message in messages :
150
- if isinstance (message , ModelRequest ):
151
- for part in message .parts :
152
- event_name , body = _request_part_body (part )
153
- if event_name :
154
- emit_event (event_name , body )
155
- elif isinstance (message , ModelResponse ):
156
- for body in _response_bodies (message ):
157
- emit_event ('gen_ai.assistant.message' , body )
158
140
159
141
def finish (response : ModelResponse , usage : Usage ):
160
142
if not span .is_recording ():
161
143
return
162
144
163
- for response_body in _response_bodies (response ):
164
- if response_body :
165
- emit_event (
145
+ events = self .messages_to_otel_events (messages )
146
+ for event in self .messages_to_otel_events ([response ]):
147
+ events .append (
148
+ Event (
166
149
'gen_ai.choice' ,
167
- {
150
+ body = {
168
151
# TODO finish_reason
169
152
'index' : 0 ,
170
- 'message' : response_body ,
153
+ 'message' : event . body ,
171
154
},
172
155
)
156
+ )
173
157
span .set_attributes (
174
158
{
175
159
# TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
@@ -178,67 +162,56 @@ def finish(response: ModelResponse, usage: Usage):
178
162
** usage .opentelemetry_attributes (),
179
163
}
180
164
)
181
- if events_list :
182
- attr_name = 'events'
183
- span .set_attributes (
184
- {
185
- attr_name : json .dumps (events_list ),
186
- 'logfire.json_schema' : json .dumps (
187
- {
188
- 'type' : 'object' ,
189
- 'properties' : {attr_name : {'type' : 'array' }},
190
- }
191
- ),
192
- }
193
- )
165
+ self ._emit_events (system , span , events )
194
166
195
167
yield finish
196
168
197
- def _emit_event (
198
- self , system : str , events_list : list [dict [str , Any ]], event_name : str , body : dict [str , Any ]
199
- ) -> None :
200
- attributes = {'gen_ai.system' : system }
169
+ def _emit_events (self , system : str , span : Span , events : list [Event ]) -> None :
170
+ for event in events :
171
+ event .attributes = {'gen_ai.system' : system , ** (event .attributes or {})}
201
172
if self .event_mode == 'logs' :
202
- self .event_logger .emit (Event (event_name , body = body , attributes = attributes ))
203
- else :
204
- events_list .append ({'event.name' : event_name , ** body , ** attributes })
205
-
206
-
207
- def _request_part_body (part : ModelRequestPart ) -> tuple [str , dict [str , Any ]]:
208
- if isinstance (part , SystemPromptPart ):
209
- return 'gen_ai.system.message' , {'content' : part .content , 'role' : 'system' }
210
- elif isinstance (part , UserPromptPart ):
211
- return 'gen_ai.user.message' , {'content' : part .content , 'role' : 'user' }
212
- elif isinstance (part , ToolReturnPart ):
213
- return 'gen_ai.tool.message' , {'content' : part .content , 'role' : 'tool' , 'id' : part .tool_call_id }
214
- elif isinstance (part , RetryPromptPart ):
215
- if part .tool_name is None :
216
- return 'gen_ai.user.message' , {'content' : part .model_response (), 'role' : 'user' }
173
+ for event in events :
174
+ self .event_logger .emit (event )
217
175
else :
218
- return 'gen_ai.tool.message' , {'content' : part .model_response (), 'role' : 'tool' , 'id' : part .tool_call_id }
219
- else :
220
- return '' , {}
221
-
222
-
223
- def _response_bodies (message : ModelResponse ) -> list [dict [str , Any ]]:
224
- body : dict [str , Any ] = {'role' : 'assistant' }
225
- result = [body ]
226
- for part in message .parts :
227
- if isinstance (part , ToolCallPart ):
228
- body .setdefault ('tool_calls' , []).append (
176
+ attr_name = 'events'
177
+ span .set_attributes (
229
178
{
230
- 'id' : part .tool_call_id ,
231
- 'type' : 'function' , # TODO https://github.com/pydantic/pydantic-ai/issues/888
232
- 'function' : {
233
- 'name' : part .tool_name ,
234
- 'arguments' : part .args ,
235
- },
179
+ attr_name : json .dumps ([self .event_to_dict (event ) for event in events ]),
180
+ 'logfire.json_schema' : json .dumps (
181
+ {
182
+ 'type' : 'object' ,
183
+ 'properties' : {attr_name : {'type' : 'array' }},
184
+ }
185
+ ),
236
186
}
237
187
)
238
- elif isinstance (part , TextPart ):
239
- if body .get ('content' ):
240
- body = {'role' : 'assistant' }
241
- result .append (body )
242
- body ['content' ] = part .content
243
188
244
- return result
189
+ @staticmethod
190
+ def event_to_dict (event : Event ) -> dict [str , Any ]:
191
+ if not event .body :
192
+ body = {}
193
+ elif isinstance (event .body , Mapping ):
194
+ body = event .body # type: ignore
195
+ else :
196
+ body = {'body' : event .body }
197
+ return {** body , ** (event .attributes or {})}
198
+
199
+ @staticmethod
200
+ def messages_to_otel_events (messages : list [ModelMessage ]) -> list [Event ]:
201
+ result : list [Event ] = []
202
+ for message in messages :
203
+ if isinstance (message , ModelRequest ):
204
+ for part in message .parts :
205
+ if hasattr (part , 'otel_event' ):
206
+ result .append (part .otel_event ())
207
+ elif isinstance (message , ModelResponse ):
208
+ result .extend (message .otel_events ())
209
+ for event in result :
210
+ try :
211
+ event .body = ANY_ADAPTER .dump_python (event .body , mode = 'json' )
212
+ except Exception :
213
+ try :
214
+ event .body = str (event .body )
215
+ except Exception :
216
+ event .body = 'Unable to serialize event body'
217
+ return result
0 commit comments