
Commit 8d5b47a

Add request_stream to InstrumentedModel (#922)
1 parent 8fcf8c9
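
At a glance: InstrumentedModel previously traced only request(); with this change, streamed requests open the same 'chat {model_name}' span and emit the same events. A minimal usage sketch, where `underlying_model` is a placeholder for any pydantic-ai Model instance (the parameter values mirror the tests below):

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.instrumented import InstrumentedModel


    async def stream_once(underlying_model):
        # underlying_model: placeholder for any pydantic_ai Model instance
        model = InstrumentedModel(underlying_model)
        messages = [ModelRequest(parts=[UserPromptPart('user_prompt')])]
        async with model.request_stream(
            messages,
            model_settings=None,
            model_request_parameters=ModelRequestParameters(
                function_tools=[], allow_text_result=True, result_tools=[]
            ),
        ) as response_stream:
            # the 'chat {model_name}' span stays open while events stream
            async for event in response_stream:
                print(event)
        # on exit, finish() records the final response body and token usage on the span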

File tree

3 files changed: +289 -10 lines

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 39 additions & 7 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Literal
@@ -20,7 +22,7 @@
 )
 from ..settings import ModelSettings
 from ..usage import Usage
-from . import ModelRequestParameters
+from . import ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel

 MODEL_SETTING_ATTRIBUTES: tuple[
@@ -60,6 +62,35 @@ async def request(
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
+        with self._instrument(messages, model_settings) as finish:
+            response, usage = await super().request(messages, model_settings, model_request_parameters)
+            finish(response, usage)
+            return response, usage
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        with self._instrument(messages, model_settings) as finish:
+            response_stream: StreamedResponse | None = None
+            try:
+                async with super().request_stream(
+                    messages, model_settings, model_request_parameters
+                ) as response_stream:
+                    yield response_stream
+            finally:
+                if response_stream:
+                    finish(response_stream.get(), response_stream.usage())
+
+    @contextmanager
+    def _instrument(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+    ):
         operation = 'chat'
         model_name = self.model_name
         span_name = f'{operation} {model_name}'
@@ -95,17 +126,18 @@ async def request(
                 for body in _response_bodies(message):
                     emit_event('gen_ai.assistant.message', body)

-            response, usage = await super().request(messages, model_settings, model_request_parameters)
+            def finish(response: ModelResponse, usage: Usage):
+                if not span.is_recording():
+                    return

-            if span.is_recording():
-                for body in _response_bodies(response):
-                    if body:
+                for response_body in _response_bodies(response):
+                    if response_body:
                         emit_event(
                             'gen_ai.choice',
                             {
                                 # TODO finish_reason
                                 'index': 0,
-                                'message': body,
+                                'message': response_body,
                             },
                         )
                span.set_attributes(
@@ -122,7 +154,7 @@ async def request(
                 }
             )

-        return response, usage
+            yield finish

     def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
         self.logfire_instance.info(event_name, **{'gen_ai.system': system}, **body)
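
The core of the refactor is pulling the span setup out of request() into a _instrument() context manager that yields a finish callback: request() invokes finish immediately, while request_stream() defers it to a finally block so the span records whatever was actually streamed, even if the consumer exits early. A minimal sketch of the pattern (names here are illustrative, not library API):

    from contextlib import contextmanager


    @contextmanager
    def instrument_sketch():
        print('span opened, request events emitted')  # stands in for span/event setup

        def finish(response, usage):
            # stands in for emitting gen_ai.choice and setting usage attributes
            print('recorded:', response, usage)

        yield finish  # the caller runs inside the open span and calls finish()
        print('span closed')  # runs after the caller's block exits


    with instrument_sketch() as finish:
        finish('response', 'usage')

Note the guard in request_stream(): response_stream stays None if entering the wrapped stream raises, so finish() is only called once a stream has actually started.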

pydantic_ai_slim/pydantic_ai/models/wrapper.py

Lines changed: 15 additions & 2 deletions
@@ -1,11 +1,14 @@
 from __future__ import annotations

+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import Any

-from ..messages import ModelResponse
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
 from ..usage import Usage
-from . import Model
+from . import Model, ModelRequestParameters, StreamedResponse


 @dataclass
@@ -17,6 +20,16 @@ class WrapperModel(Model):
     async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
         return await self.wrapped.request(*args, **kwargs)

+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream:
+            yield response_stream
+
     @property
     def model_name(self) -> str:
         return self.wrapped.model_name
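
Because WrapperModel now forwards request_stream() to the wrapped model, wrapper subclasses get streaming pass-through for free and can hook either side of the stream. A hypothetical subclass (not part of this commit) could look like:

    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager

    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models import ModelRequestParameters, StreamedResponse
    from pydantic_ai.models.wrapper import WrapperModel
    from pydantic_ai.settings import ModelSettings


    class LoggingModel(WrapperModel):
        # hypothetical wrapper that logs around the delegated stream
        @asynccontextmanager
        async def request_stream(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
        ) -> AsyncIterator[StreamedResponse]:
            print(f'starting stream with {len(messages)} message(s)')
            async with super().request_stream(messages, model_settings, model_request_parameters) as response_stream:
                yield response_stream
            print('stream finished')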

tests/models/test_instrumented.py

Lines changed: 235 additions & 1 deletion
@@ -1,5 +1,9 @@
 from __future__ import annotations

+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from datetime import datetime
+
 import pytest
 from dirty_equals import IsJson
 from inline_snapshot import snapshot
@@ -8,14 +12,18 @@
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponseStreamEvent,
+    PartDeltaEvent,
+    PartStartEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
+    TextPartDelta,
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,
 )
-from pydantic_ai.models import Model, ModelRequestParameters
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
 from pydantic_ai.models.instrumented import InstrumentedModel
 from pydantic_ai.settings import ModelSettings
 from pydantic_ai.usage import Usage
@@ -62,6 +70,30 @@ async def request(
             Usage(request_tokens=100, response_tokens=200),
         )

+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        yield MyResponseStream()
+
+
+class MyResponseStream(StreamedResponse):
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        self._usage = Usage(request_tokens=300, response_tokens=400)
+        yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
+        yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2')
+
+    @property
+    def model_name(self) -> str:
+        return 'my_model_123'
+
+    @property
+    def timestamp(self) -> datetime:
+        return datetime(2022, 1, 1)
+

 @pytest.mark.anyio
 async def test_instrumented_model(capfire: CaptureLogfire):
@@ -322,3 +354,205 @@ async def test_instrumented_model_not_recording(capfire: CaptureLogfire):
     )

     assert capfire.exporter.exported_spans_as_dict() == snapshot([])
+
+
+@pytest.mark.anyio
+async def test_instrumented_model_stream(capfire: CaptureLogfire):
+    model = InstrumentedModel(MyModel())
+
+    messages: list[ModelMessage] = [
+        ModelRequest(
+            parts=[
+                UserPromptPart('user_prompt'),
+            ]
+        ),
+    ]
+    async with model.request_stream(
+        messages,
+        model_settings=ModelSettings(temperature=1),
+        model_request_parameters=ModelRequestParameters(
+            function_tools=[],
+            allow_text_result=True,
+            result_tools=[],
+        ),
+    ) as response_stream:
+        assert [event async for event in response_stream] == snapshot(
+            [
+                PartStartEvent(index=0, part=TextPart(content='text1')),
+                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')),
+            ]
+        )

+    assert capfire.exporter.exported_spans_as_dict() == snapshot(
+        [
+            {
+                'name': 'gen_ai.user.message',
+                'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
+                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'start_time': 2000000000,
+                'end_time': 2000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 9,
+                    'logfire.msg_template': 'gen_ai.user.message',
+                    'logfire.msg': 'gen_ai.user.message',
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream',
+                    'code.lineno': 123,
+                    'gen_ai.system': 'my_system',
+                    'content': 'user_prompt',
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
+                },
+            },
+            {
+                'name': 'gen_ai.choice',
+                'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
+                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'start_time': 3000000000,
+                'end_time': 3000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 9,
+                    'logfire.msg_template': 'gen_ai.choice',
+                    'logfire.msg': 'gen_ai.choice',
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream',
+                    'code.lineno': 123,
+                    'gen_ai.system': 'my_system',
+                    'index': 0,
+                    'message': '{"content":"text1text2"}',
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
+                },
+            },
+            {
+                'name': 'chat my_model',
+                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'parent': None,
+                'start_time': 1000000000,
+                'end_time': 4000000000,
+                'attributes': {
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream',
+                    'code.lineno': 123,
+                    'gen_ai.operation.name': 'chat',
+                    'gen_ai.system': 'my_system',
+                    'gen_ai.request.model': 'my_model',
+                    'gen_ai.request.temperature': 1,
+                    'logfire.msg_template': 'chat my_model',
+                    'logfire.msg': 'chat my_model',
+                    'logfire.span_type': 'span',
+                    'gen_ai.response.model': 'my_model_123',
+                    'gen_ai.usage.input_tokens': 300,
+                    'gen_ai.usage.output_tokens': 400,
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
+                },
+            },
+        ]
+    )
+
+
+@pytest.mark.anyio
+async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
+    model = InstrumentedModel(MyModel())
+
+    messages: list[ModelMessage] = [
+        ModelRequest(
+            parts=[
+                UserPromptPart('user_prompt'),
+            ]
+        ),
+    ]
+
+    with pytest.raises(RuntimeError):
+        async with model.request_stream(
+            messages,
+            model_settings=ModelSettings(temperature=1),
+            model_request_parameters=ModelRequestParameters(
+                function_tools=[],
+                allow_text_result=True,
+                result_tools=[],
+            ),
+        ) as response_stream:
+            async for event in response_stream:
+                assert event == PartStartEvent(index=0, part=TextPart(content='text1'))
+                raise RuntimeError
+
+    assert capfire.exporter.exported_spans_as_dict() == snapshot(
+        [
+            {
+                'name': 'gen_ai.user.message',
+                'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
+                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'start_time': 2000000000,
+                'end_time': 2000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 9,
+                    'logfire.msg_template': 'gen_ai.user.message',
+                    'logfire.msg': 'gen_ai.user.message',
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream_break',
+                    'code.lineno': 123,
+                    'gen_ai.system': 'my_system',
+                    'content': 'user_prompt',
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
+                },
+            },
+            {
+                'name': 'gen_ai.choice',
+                'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
+                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'start_time': 3000000000,
+                'end_time': 3000000000,
+                'attributes': {
+                    'logfire.span_type': 'log',
+                    'logfire.level_num': 9,
+                    'logfire.msg_template': 'gen_ai.choice',
+                    'logfire.msg': 'gen_ai.choice',
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream_break',
+                    'code.lineno': 123,
+                    'gen_ai.system': 'my_system',
+                    'index': 0,
+                    'message': '{"content":"text1"}',
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
+                },
+            },
+            {
+                'name': 'chat my_model',
+                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
+                'parent': None,
+                'start_time': 1000000000,
+                'end_time': 5000000000,
+                'attributes': {
+                    'code.filepath': 'test_instrumented.py',
+                    'code.function': 'test_instrumented_model_stream_break',
+                    'code.lineno': 123,
+                    'gen_ai.operation.name': 'chat',
+                    'gen_ai.system': 'my_system',
+                    'gen_ai.request.model': 'my_model',
+                    'gen_ai.request.temperature': 1,
+                    'logfire.msg_template': 'chat my_model',
+                    'logfire.msg': 'chat my_model',
+                    'logfire.span_type': 'span',
+                    'gen_ai.response.model': 'my_model_123',
+                    'gen_ai.usage.input_tokens': 300,
+                    'gen_ai.usage.output_tokens': 400,
+                    'logfire.level_num': 17,
+                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
+                },
+                'events': [
+                    {
+                        'name': 'exception',
+                        'timestamp': 4000000000,
+                        'attributes': {
+                            'exception.type': 'RuntimeError',
+                            'exception.message': '',
+                            'exception.stacktrace': 'RuntimeError',
+                            'exception.escaped': 'True',
+                        },
+                    }
+                ],
+            },
+        ]
+    )
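
The two tests capture different gen_ai.choice payloads ('text1text2' vs. 'text1') because finish() runs in request_stream's finally block with response_stream.get(), which returns only the parts accumulated so far. A toy illustration using the MyResponseStream test double defined above (assumes it is importable from this test module):

    import asyncio


    async def main():
        stream = MyResponseStream()
        async for _event in stream:
            break  # abandon the stream after the first event, as in the break test
        print(stream.get())    # a ModelResponse containing only 'text1'
        print(stream.usage())  # Usage(request_tokens=300, response_tokens=400)


    asyncio.run(main())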
