1 | 1 | from __future__ import annotations
2 | 2 |
3 | 3 | import re
  | 4 | +from collections import defaultdict
4 | 5 | from dataclasses import dataclass, replace
5 | 6 | from typing import TypeVar
6 | 7 | from unittest.mock import AsyncMock
10 | 11 |
11 | 12 | from pydantic_ai._run_context import RunContext
12 | 13 | from pydantic_ai._tool_manager import ToolManager
13 |    | -from pydantic_ai.exceptions import UserError
   | 14 | +from pydantic_ai.exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UserError
14 | 15 | from pydantic_ai.messages import ToolCallPart
15 | 16 | from pydantic_ai.models.test import TestModel
16 | 17 | from pydantic_ai.tools import ToolDefinition
@@ -494,3 +495,134 @@ async def test_context_manager_failed_initialization():
494 | 495 |         pass
495 | 496 |
496 | 497 |     assert server1.is_running is False
| 498 | +
| 499 | +
| 500 | +async def test_tool_manager_retry_logic():
| 501 | +    """Test the retry logic with failed_tools and for_run_step method."""
| 502 | +
| 503 | +    @dataclass
| 504 | +    class TestDeps:
| 505 | +        pass
| 506 | +
| 507 | +    # Create a toolset with tools that can fail
| 508 | +    toolset = FunctionToolset[TestDeps](max_retries=2)
| 509 | +    call_count: defaultdict[str, int] = defaultdict(int)
| 510 | +
| 511 | +    @toolset.tool
| 512 | +    def failing_tool(x: int) -> int:
| 513 | +        """A tool that always fails"""
| 514 | +        call_count['failing_tool'] += 1
| 515 | +        raise ModelRetry('This tool always fails')
| 516 | +
| 517 | +    @toolset.tool
| 518 | +    def other_tool(x: int) -> int:
| 519 | +        """A tool that works"""
| 520 | +        call_count['other_tool'] += 1
| 521 | +        return x * 2
| 522 | +
| 523 | +    # Create initial context and tool manager
| 524 | +    initial_context = build_run_context(TestDeps())
| 525 | +    tool_manager = await ToolManager[TestDeps].build(toolset, initial_context)
| 526 | +
| 527 | +    # Initially no failed tools
| 528 | +    assert tool_manager.failed_tools == set()
| 529 | +    assert initial_context.retries == {}
| 530 | +
| 531 | +    # Call the failing tool - should add to failed_tools
| 532 | +    with pytest.raises(ToolRetryError):
| 533 | +        await tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
| 534 | +
| 535 | +    assert tool_manager.failed_tools == {'failing_tool'}
| 536 | +    assert call_count['failing_tool'] == 1
| 537 | +
| 538 | +    # Call the working tool - should not add to failed_tools
| 539 | +    result = await tool_manager.handle_call(ToolCallPart(tool_name='other_tool', args={'x': 3}))
| 540 | +    assert result == 6
| 541 | +    assert tool_manager.failed_tools == {'failing_tool'}  # unchanged
| 542 | +    assert call_count['other_tool'] == 1
| 543 | +
| 544 | +    # Test for_run_step - should create new tool manager with updated retry counts
| 545 | +    new_context = build_run_context(TestDeps())
| 546 | +    new_tool_manager = await tool_manager.for_run_step(new_context)
| 547 | +
| 548 | +    # The new tool manager should have retry count for the failed tool
| 549 | +    assert new_tool_manager.ctx.retries == {'failing_tool': 1}
| 550 | +    assert new_tool_manager.failed_tools == set()  # reset for new run step
| 551 | +
| 552 | +    # Call the failing tool again in the new manager - should have retry=1
| 553 | +    with pytest.raises(ToolRetryError):
| 554 | +        await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
| 555 | +
| 556 | +    # Call the failing tool another time in the new manager
| 557 | +    with pytest.raises(ToolRetryError):
| 558 | +        await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
| 559 | +
| 560 | +    # Call the failing tool a third time in the new manager
| 561 | +    with pytest.raises(ToolRetryError):
| 562 | +        await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
| 563 | +
| 564 | +    assert new_tool_manager.failed_tools == {'failing_tool'}
| 565 | +    assert call_count['failing_tool'] == 4
| 566 | +
| 567 | +    # Create another run step
| 568 | +    another_context = build_run_context(TestDeps())
| 569 | +    another_tool_manager = await new_tool_manager.for_run_step(another_context)
| 570 | +
| 571 | +    # Should now have retry count of 2 for failing_tool
| 572 | +    assert another_tool_manager.ctx.retries == {'failing_tool': 2}
| 573 | +    assert another_tool_manager.failed_tools == set()
| 574 | +
| 575 | +    # Call the failing tool _again_, now we should finally hit the limit
| 576 | +    with pytest.raises(UnexpectedModelBehavior, match="Tool 'failing_tool' exceeded max retries count of 2"):
| 577 | +        await another_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1}))
| 578 | +
| 579 | +
| 580 | +async def test_tool_manager_multiple_failed_tools():
| 581 | +    """Test retry logic when multiple tools fail in the same run step."""
| 582 | +
| 583 | +    @dataclass
| 584 | +    class TestDeps:
| 585 | +        pass
| 586 | +
| 587 | +    toolset = FunctionToolset[TestDeps]()
| 588 | +
| 589 | +    @toolset.tool
| 590 | +    def tool_a(x: int) -> int:
| 591 | +        """Tool A that fails"""
| 592 | +        raise ModelRetry('Tool A fails')
| 593 | +
| 594 | +    @toolset.tool
| 595 | +    def tool_b(x: int) -> int:
| 596 | +        """Tool B that fails"""
| 597 | +        raise ModelRetry('Tool B fails')
| 598 | +
| 599 | +    @toolset.tool
| 600 | +    def tool_c(x: int) -> int:
| 601 | +        """Tool C that works"""
| 602 | +        return x * 3
| 603 | +
| 604 | +    # Create tool manager
| 605 | +    context = build_run_context(TestDeps())
| 606 | +    tool_manager = await ToolManager[TestDeps].build(toolset, context)
| 607 | +
| 608 | +    # Call tool_a - should fail and be added to failed_tools
| 609 | +    with pytest.raises(ToolRetryError):
| 610 | +        await tool_manager.handle_call(ToolCallPart(tool_name='tool_a', args={'x': 1}))
| 611 | +    assert tool_manager.failed_tools == {'tool_a'}
| 612 | +
| 613 | +    # Call tool_b - should also fail and be added to failed_tools
| 614 | +    with pytest.raises(ToolRetryError):
| 615 | +        await tool_manager.handle_call(ToolCallPart(tool_name='tool_b', args={'x': 1}))
| 616 | +    assert tool_manager.failed_tools == {'tool_a', 'tool_b'}
| 617 | +
| 618 | +    # Call tool_c - should succeed and not be added to failed_tools
| 619 | +    result = await tool_manager.handle_call(ToolCallPart(tool_name='tool_c', args={'x': 2}))
| 620 | +    assert result == 6
| 621 | +    assert tool_manager.failed_tools == {'tool_a', 'tool_b'}  # unchanged
| 622 | +
| 623 | +    # Create next run step - should have retry counts for both failed tools
| 624 | +    new_context = build_run_context(TestDeps())
| 625 | +    new_tool_manager = await tool_manager.for_run_step(new_context)
| 626 | +
| 627 | +    assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1}
| 628 | +    assert new_tool_manager.failed_tools == set()  # reset for new run step