Skip to content

Commit 28cc2af

Browse files
authored
Add attributes mode to InstrumentedModel (#1010)
1 parent 2972680 commit 28cc2af

File tree

3 files changed

+189
-22
lines changed

3 files changed

+189
-22
lines changed

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from collections.abc import AsyncIterator, Iterator
45
from contextlib import asynccontextmanager, contextmanager
56
from dataclasses import dataclass, field
@@ -9,6 +10,7 @@
910
import logfire_api
1011
from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
1112
from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
13+
from opentelemetry.util.types import AttributeValue
1214

1315
from ..messages import (
1416
ModelMessage,
@@ -46,40 +48,42 @@
4648
'frequency_penalty',
4749
)
4850

49-
NOT_GIVEN = object()
50-
5151

5252
@dataclass
5353
class InstrumentedModel(WrapperModel):
54-
"""Model which is instrumented with logfire."""
54+
"""Model which is instrumented with OpenTelemetry."""
5555

5656
tracer: Tracer = field(repr=False)
5757
event_logger: EventLogger = field(repr=False)
58+
event_mode: Literal['attributes', 'logs'] = 'attributes'
5859

5960
def __init__(
    self,
    wrapped: Model | KnownModelName,
    tracer_provider: TracerProvider | None = None,
    event_logger_provider: EventLoggerProvider | None = None,
    event_mode: Literal['attributes', 'logs'] = 'attributes',
):
    """Wrap a model and set up the tracer and event logger used to instrument it.

    Args:
        wrapped: The model (or known model name) to wrap; forwarded to the parent initializer.
        tracer_provider: Provider the tracer is obtained from. When falsy, the result of
            `get_tracer_provider()` is used instead.
        event_logger_provider: Provider the event logger is obtained from. When falsy, the
            result of `get_event_logger_provider()` is used instead.
        event_mode: Either `'attributes'` (the default) or `'logs'`; stored on the instance.
    """
    super().__init__(wrapped)

    # Resolve both providers first, falling back to the module-level getters.
    resolved_tracer_provider = tracer_provider or get_tracer_provider()
    resolved_event_provider = event_logger_provider or get_event_logger_provider()

    # The tracer and the event logger are both requested under the 'pydantic-ai' name.
    self.tracer = resolved_tracer_provider.get_tracer('pydantic-ai')
    self.event_logger = resolved_event_provider.get_event_logger('pydantic-ai')
    self.event_mode = event_mode
7073

7174
@classmethod
def from_logfire(
    cls,
    wrapped: Model | KnownModelName,
    logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
    event_mode: Literal['attributes', 'logs'] = 'attributes',
) -> InstrumentedModel:
    """Build an instance whose providers are taken from a logfire instance's config.

    Args:
        wrapped: The model (or known model name) to wrap.
        logfire_instance: Logfire instance whose `config` supplies the tracer provider and,
            when available, the event logger provider.
        event_mode: Passed through unchanged to the constructor.

    Returns:
        A new instance created via `cls(...)`.
    """
    config = logfire_instance.config
    # The config may not expose `get_event_logger_provider`; in that case None is passed on
    # and the constructor applies its own fallback.
    event_provider = config.get_event_logger_provider() if hasattr(config, 'get_event_logger_provider') else None
    tracer_provider = config.get_tracer_provider()
    return cls(wrapped, tracer_provider, event_provider, event_mode)
8387

8488
async def request(
8589
self,
@@ -111,7 +115,7 @@ async def request_stream(
111115
finish(response_stream.get(), response_stream.usage())
112116

113117
@contextmanager
114-
def _instrument(
118+
def _instrument( # noqa: C901
115119
self,
116120
messages: list[ModelMessage],
117121
model_settings: ModelSettings | None,
@@ -126,18 +130,19 @@ def _instrument(
126130
# - server.port: to parse from the base_url
127131
# - error.type: unclear if we should do something here or just always rely on span exceptions
128132
# - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
129-
attributes: dict[str, Any] = {
133+
attributes: dict[str, AttributeValue] = {
130134
'gen_ai.operation.name': operation,
131135
'gen_ai.system': system,
132136
'gen_ai.request.model': model_name,
133137
}
134138

135139
if model_settings:
136140
for key in MODEL_SETTING_ATTRIBUTES:
137-
if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
141+
if isinstance(value := model_settings.get(key), (float, int)):
138142
attributes[f'gen_ai.request.{key}'] = value
139143

140-
emit_event = partial(self._emit_event, system)
144+
events_list = []
145+
emit_event = partial(self._emit_event, system, events_list)
141146

142147
with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
143148
if span.is_recording():
@@ -167,22 +172,36 @@ def finish(response: ModelResponse, usage: Usage):
167172
)
168173
span.set_attributes(
169174
{
170-
k: v
171-
for k, v in {
172-
# TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
173-
# https://github.com/pydantic/pydantic-ai/issues/886
174-
'gen_ai.response.model': response.model_name or model_name,
175-
'gen_ai.usage.input_tokens': usage.request_tokens,
176-
'gen_ai.usage.output_tokens': usage.response_tokens,
177-
}.items()
178-
if v is not None
175+
# TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
176+
# https://github.com/pydantic/pydantic-ai/issues/886
177+
'gen_ai.response.model': response.model_name or model_name,
178+
**usage.opentelemetry_attributes(),
179179
}
180180
)
181+
if events_list:
182+
attr_name = 'events'
183+
span.set_attributes(
184+
{
185+
attr_name: json.dumps(events_list),
186+
'logfire.json_schema': json.dumps(
187+
{
188+
'type': 'object',
189+
'properties': {attr_name: {'type': 'array'}},
190+
}
191+
),
192+
}
193+
)
181194

182195
yield finish
183196

184-
def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
185-
self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
197+
def _emit_event(
198+
self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
199+
) -> None:
200+
attributes = {'gen_ai.system': system}
201+
if self.event_mode == 'logs':
202+
self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
203+
else:
204+
events_list.append({'event.name': event_name, **body, **attributes})
186205

187206

188207
def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ def __add__(self, other: Usage) -> Usage:
5656
new_usage.incr(other)
5757
return new_usage
5858

59+
def opentelemetry_attributes(self) -> dict[str, int]:
    """Get the recorded token usage as OpenTelemetry attributes.

    Returns:
        A dict containing `gen_ai.usage.input_tokens` (from `request_tokens`),
        `gen_ai.usage.output_tokens` (from `response_tokens`) and one
        `gen_ai.usage.details.<key>` entry per item in `details`. Entries whose
        value is `None` are left out; zero counts are kept.
    """
    candidates: dict[str, int | None] = {
        'gen_ai.usage.input_tokens': self.request_tokens,
        'gen_ai.usage.output_tokens': self.response_tokens,
    }
    # `details` may be None or empty, in which case no detail attributes are added.
    for key, value in (self.details or {}).items():
        candidates[f'gen_ai.usage.details.{key}'] = value
    # Only `None` is filtered, so a count of 0 is still reported.
    return {name: count for name, count in candidates.items() if count is not None}
68+
5969

6070
@dataclass
6171
class UsageLimits:

tests/models/test_instrumented.py

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime
66

77
import pytest
8+
from dirty_equals import IsJson
89
from inline_snapshot import snapshot
910
from logfire_api import DEFAULT_LOGFIRE_INSTANCE
1011

@@ -105,7 +106,7 @@ def timestamp(self) -> datetime:
105106
@pytest.mark.anyio
106107
@requires_logfire_events
107108
async def test_instrumented_model(capfire: CaptureLogfire):
108-
model = InstrumentedModel.from_logfire(MyModel())
109+
model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
109110
assert model.system == 'my_system'
110111
assert model.model_name == 'my_model'
111112

@@ -323,7 +324,7 @@ async def test_instrumented_model_not_recording():
323324
@pytest.mark.anyio
324325
@requires_logfire_events
325326
async def test_instrumented_model_stream(capfire: CaptureLogfire):
326-
model = InstrumentedModel.from_logfire(MyModel())
327+
model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
327328

328329
messages: list[ModelMessage] = [
329330
ModelRequest(
@@ -405,7 +406,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire):
405406
@pytest.mark.anyio
406407
@requires_logfire_events
407408
async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
408-
model = InstrumentedModel.from_logfire(MyModel())
409+
model = InstrumentedModel.from_logfire(MyModel(), event_mode='logs')
409410

410411
messages: list[ModelMessage] = [
411412
ModelRequest(
@@ -494,3 +495,140 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
494495
},
495496
]
496497
)
498+
499+
500+
@pytest.mark.anyio
501+
async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
# Runs one request through an InstrumentedModel built with event_mode='attributes' and
# compares the exported spans against an inline snapshot.
# NOTE(review): MyModel is defined outside this view; the 'text1'/'text2'/tool-call values in
# the expected 'gen_ai.choice' entries presumably come from its canned response - confirm.
502+
model = InstrumentedModel(MyModel(), event_mode='attributes')
503+
assert model.system == 'my_system'
504+
assert model.model_name == 'my_model'
505+
506+
messages = [
507+
ModelRequest(
508+
parts=[
509+
SystemPromptPart('system_prompt'),
510+
UserPromptPart('user_prompt'),
511+
ToolReturnPart('tool3', 'tool_return_content', 'tool_call_3'),
512+
RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'),
513+
RetryPromptPart('retry_prompt2'),
514+
{}, # test unexpected parts # type: ignore
515+
]
516+
),
517+
ModelResponse(
518+
parts=[
519+
TextPart('text3'),
520+
]
521+
),
522+
]
523+
await model.request(
524+
messages,
525+
model_settings=ModelSettings(temperature=1),
526+
model_request_parameters=ModelRequestParameters(
527+
function_tools=[],
528+
allow_text_result=True,
529+
result_tools=[],
530+
),
531+
)
532+
# A single span is expected. Its 'events' attribute is matched as JSON (IsJson) against the
# list of event dicts below, and 'logfire.json_schema' declares 'events' as an array.
533+
assert capfire.exporter.exported_spans_as_dict() == snapshot(
534+
[
535+
{
536+
'name': 'chat my_model',
537+
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
538+
'parent': None,
539+
'start_time': 1000000000,
540+
'end_time': 2000000000,
541+
'attributes': {
542+
'gen_ai.operation.name': 'chat',
543+
'gen_ai.system': 'my_system',
544+
'gen_ai.request.model': 'my_model',
545+
'gen_ai.request.temperature': 1,
546+
'logfire.msg': 'chat my_model',
547+
'logfire.span_type': 'span',
548+
'gen_ai.response.model': 'my_model_123',
549+
'gen_ai.usage.input_tokens': 100,
550+
'gen_ai.usage.output_tokens': 200,
551+
'events': IsJson(
552+
snapshot(
553+
[
554+
{
555+
'event.name': 'gen_ai.system.message',
556+
'content': 'system_prompt',
557+
'role': 'system',
558+
'gen_ai.system': 'my_system',
559+
},
560+
{
561+
'event.name': 'gen_ai.user.message',
562+
'content': 'user_prompt',
563+
'role': 'user',
564+
'gen_ai.system': 'my_system',
565+
},
566+
{
567+
'event.name': 'gen_ai.tool.message',
568+
'content': 'tool_return_content',
569+
'role': 'tool',
570+
'id': 'tool_call_3',
571+
'gen_ai.system': 'my_system',
572+
},
573+
{
574+
'event.name': 'gen_ai.tool.message',
575+
'content': """\
576+
retry_prompt1
577+
578+
Fix the errors and try again.\
579+
""",
580+
'role': 'tool',
581+
'id': 'tool_call_4',
582+
'gen_ai.system': 'my_system',
583+
},
584+
{
585+
'event.name': 'gen_ai.user.message',
586+
'content': """\
587+
retry_prompt2
588+
589+
Fix the errors and try again.\
590+
""",
591+
'role': 'user',
592+
'gen_ai.system': 'my_system',
593+
},
594+
{
595+
'event.name': 'gen_ai.assistant.message',
596+
'role': 'assistant',
597+
'content': 'text3',
598+
'gen_ai.system': 'my_system',
599+
},
600+
{
601+
'event.name': 'gen_ai.choice',
602+
'index': 0,
603+
'message': {
604+
'role': 'assistant',
605+
'content': 'text1',
606+
'tool_calls': [
607+
{
608+
'id': 'tool_call_1',
609+
'type': 'function',
610+
'function': {'name': 'tool1', 'arguments': 'args1'},
611+
},
612+
{
613+
'id': 'tool_call_2',
614+
'type': 'function',
615+
'function': {'name': 'tool2', 'arguments': {'args2': 3}},
616+
},
617+
],
618+
},
619+
'gen_ai.system': 'my_system',
620+
},
621+
{
622+
'event.name': 'gen_ai.choice',
623+
'index': 0,
624+
'message': {'role': 'assistant', 'content': 'text2'},
625+
'gen_ai.system': 'my_system',
626+
},
627+
]
628+
)
629+
),
630+
'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}}}',
631+
},
632+
},
633+
]
634+
)

0 commit comments

Comments
 (0)