|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from collections.abc import AsyncIterator |
| 4 | +from contextlib import asynccontextmanager |
| 5 | +from datetime import datetime |
| 6 | + |
3 | 7 | import pytest
|
4 | 8 | from dirty_equals import IsJson
|
5 | 9 | from inline_snapshot import snapshot
|
|
8 | 12 | ModelMessage,
|
9 | 13 | ModelRequest,
|
10 | 14 | ModelResponse,
|
| 15 | + ModelResponseStreamEvent, |
| 16 | + PartDeltaEvent, |
| 17 | + PartStartEvent, |
11 | 18 | RetryPromptPart,
|
12 | 19 | SystemPromptPart,
|
13 | 20 | TextPart,
|
| 21 | + TextPartDelta, |
14 | 22 | ToolCallPart,
|
15 | 23 | ToolReturnPart,
|
16 | 24 | UserPromptPart,
|
17 | 25 | )
|
18 |
| -from pydantic_ai.models import Model, ModelRequestParameters |
| 26 | +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse |
19 | 27 | from pydantic_ai.models.instrumented import InstrumentedModel
|
20 | 28 | from pydantic_ai.settings import ModelSettings
|
21 | 29 | from pydantic_ai.usage import Usage
|
@@ -62,6 +70,30 @@ async def request(
|
62 | 70 | Usage(request_tokens=100, response_tokens=200),
|
63 | 71 | )
|
64 | 72 |
|
| 73 | + @asynccontextmanager |
| 74 | + async def request_stream( |
| 75 | + self, |
| 76 | + messages: list[ModelMessage], |
| 77 | + model_settings: ModelSettings | None, |
| 78 | + model_request_parameters: ModelRequestParameters, |
| 79 | + ) -> AsyncIterator[StreamedResponse]: |
| 80 | + yield MyResponseStream() |
| 81 | + |
| 82 | + |
class MyResponseStream(StreamedResponse):
    """Canned `StreamedResponse` that emits two text deltas for a single part.

    Sets fixed usage numbers and a fixed timestamp so the instrumentation
    tests can snapshot deterministic span attributes.
    """

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        # Record fixed usage up front, then stream two chunks that both target
        # the same vendor part id (0): the first starts the part, the second
        # extends it with a delta.
        self._usage = Usage(request_tokens=300, response_tokens=400)
        for chunk in ('text1', 'text2'):
            yield self._parts_manager.handle_text_delta(vendor_part_id=0, content=chunk)

    @property
    def model_name(self) -> str:
        """Hard-coded model identifier asserted in the exported spans."""
        return 'my_model_123'

    @property
    def timestamp(self) -> datetime:
        """Fixed timestamp so snapshot comparisons are deterministic."""
        return datetime(2022, 1, 1)
| 96 | + |
65 | 97 |
|
66 | 98 | @pytest.mark.anyio
|
67 | 99 | async def test_instrumented_model(capfire: CaptureLogfire):
|
@@ -322,3 +354,205 @@ async def test_instrumented_model_not_recording(capfire: CaptureLogfire):
|
322 | 354 | )
|
323 | 355 |
|
324 | 356 | assert capfire.exporter.exported_spans_as_dict() == snapshot([])
|
| 357 | + |
| 358 | + |
@pytest.mark.anyio
async def test_instrumented_model_stream(capfire: CaptureLogfire):
    """Fully consuming a stream through `InstrumentedModel` exports the expected spans.

    Drains the stream, then checks that the user-message log, the choice log
    (with the fully accumulated text) and the enclosing chat span — carrying
    the stream's usage and model name — are all exported.
    """
    model = InstrumentedModel(MyModel())

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user_prompt'),
            ]
        ),
    ]
    async with model.request_stream(
        messages,
        model_settings=ModelSettings(temperature=1),
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_result=True,
            result_tools=[],
        ),
    ) as response_stream:
        # MyResponseStream emits two deltas for part 0: the first becomes a
        # PartStartEvent, the second a PartDeltaEvent extending the same part.
        assert [event async for event in response_stream] == snapshot(
            [
                PartStartEvent(index=0, part=TextPart(content='text1')),
                PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')),
            ]
        )

    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'gen_ai.user.message',
                'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'start_time': 2000000000,
                'end_time': 2000000000,
                'attributes': {
                    'logfire.span_type': 'log',
                    'logfire.level_num': 9,
                    'logfire.msg_template': 'gen_ai.user.message',
                    'logfire.msg': 'gen_ai.user.message',
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream',
                    'code.lineno': 123,
                    'gen_ai.system': 'my_system',
                    'content': 'user_prompt',
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
                },
            },
            {
                # Choice log: 'message' holds the concatenation of both deltas.
                'name': 'gen_ai.choice',
                'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'start_time': 3000000000,
                'end_time': 3000000000,
                'attributes': {
                    'logfire.span_type': 'log',
                    'logfire.level_num': 9,
                    'logfire.msg_template': 'gen_ai.choice',
                    'logfire.msg': 'gen_ai.choice',
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream',
                    'code.lineno': 123,
                    'gen_ai.system': 'my_system',
                    'index': 0,
                    'message': '{"content":"text1text2"}',
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
                },
            },
            {
                # Enclosing chat span: usage (300/400) and response model name
                # come from MyResponseStream.
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 4000000000,
                'attributes': {
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream',
                    'code.lineno': 123,
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg_template': 'chat my_model',
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 300,
                    'gen_ai.usage.output_tokens': 400,
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
                },
            },
        ]
    )
| 452 | + |
| 453 | + |
@pytest.mark.anyio
async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
    """Breaking out of a stream mid-way still records spans and the exception.

    Raises RuntimeError after the first event; the choice log should contain
    only the text received so far, and the chat span should carry the usage
    plus an 'exception' event and error level.
    """
    model = InstrumentedModel(MyModel())

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user_prompt'),
            ]
        ),
    ]

    with pytest.raises(RuntimeError):
        async with model.request_stream(
            messages,
            model_settings=ModelSettings(temperature=1),
            model_request_parameters=ModelRequestParameters(
                function_tools=[],
                allow_text_result=True,
                result_tools=[],
            ),
        ) as response_stream:
            async for event in response_stream:
                # Abort after the first event so the second delta is never consumed.
                assert event == PartStartEvent(index=0, part=TextPart(content='text1'))
                raise RuntimeError

    assert capfire.exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'gen_ai.user.message',
                'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'start_time': 2000000000,
                'end_time': 2000000000,
                'attributes': {
                    'logfire.span_type': 'log',
                    'logfire.level_num': 9,
                    'logfire.msg_template': 'gen_ai.user.message',
                    'logfire.msg': 'gen_ai.user.message',
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream_break',
                    'code.lineno': 123,
                    'gen_ai.system': 'my_system',
                    'content': 'user_prompt',
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"content":{}}}',
                },
            },
            {
                # Only 'text1' was consumed before the break, so the choice
                # message contains just the first chunk.
                'name': 'gen_ai.choice',
                'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False},
                'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'start_time': 3000000000,
                'end_time': 3000000000,
                'attributes': {
                    'logfire.span_type': 'log',
                    'logfire.level_num': 9,
                    'logfire.msg_template': 'gen_ai.choice',
                    'logfire.msg': 'gen_ai.choice',
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream_break',
                    'code.lineno': 123,
                    'gen_ai.system': 'my_system',
                    'index': 0,
                    'message': '{"content":"text1"}',
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.system":{},"index":{},"message":{"type":"object"}}}',
                },
            },
            {
                # Chat span is marked as errored (level 17) and records the
                # RuntimeError as an OpenTelemetry 'exception' event.
                'name': 'chat my_model',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 5000000000,
                'attributes': {
                    'code.filepath': 'test_instrumented.py',
                    'code.function': 'test_instrumented_model_stream_break',
                    'code.lineno': 123,
                    'gen_ai.operation.name': 'chat',
                    'gen_ai.system': 'my_system',
                    'gen_ai.request.model': 'my_model',
                    'gen_ai.request.temperature': 1,
                    'logfire.msg_template': 'chat my_model',
                    'logfire.msg': 'chat my_model',
                    'logfire.span_type': 'span',
                    'gen_ai.response.model': 'my_model_123',
                    'gen_ai.usage.input_tokens': 300,
                    'gen_ai.usage.output_tokens': 400,
                    'logfire.level_num': 17,
                    'logfire.json_schema': '{"type":"object","properties":{"gen_ai.operation.name":{},"gen_ai.system":{},"gen_ai.request.model":{},"gen_ai.request.temperature":{},"gen_ai.response.model":{},"gen_ai.usage.input_tokens":{},"gen_ai.usage.output_tokens":{}}}',
                },
                'events': [
                    {
                        'name': 'exception',
                        'timestamp': 4000000000,
                        'attributes': {
                            'exception.type': 'RuntimeError',
                            'exception.message': '',
                            'exception.stacktrace': 'RuntimeError',
                            'exception.escaped': 'True',
                        },
                    }
                ],
            },
        ]
    )
0 commit comments