|
8 | 8 | from unittest.mock import patch |
9 | 9 |
|
10 | 10 | import pytest |
| 11 | +from openai.types.responses import ResponseFunctionToolCall |
11 | 12 | from typing_extensions import TypedDict |
12 | 13 |
|
13 | 14 | from agents import ( |
|
29 | 30 | handoff, |
30 | 31 | ) |
31 | 32 | from agents.agent import ToolsToFinalOutputResult |
32 | | -from agents.tool import FunctionToolResult, function_tool |
| 33 | +from agents.computer import Computer |
| 34 | +from agents.items import RunItem, ToolApprovalItem, ToolCallOutputItem |
| 35 | +from agents.lifecycle import RunHooks |
| 36 | +from agents.run import AgentRunner |
| 37 | +from agents.run_state import RunState |
| 38 | +from agents.tool import ComputerTool, FunctionToolResult, function_tool |
33 | 39 |
|
34 | 40 | from .fake_model import FakeModel |
35 | 41 | from .test_responses import ( |
@@ -699,6 +705,58 @@ def guardrail_function( |
699 | 705 | await Runner.run(agent, input="user_message") |
700 | 706 |
|
701 | 707 |
|
@pytest.mark.asyncio
async def test_input_guardrail_no_tripwire_continues_execution():
    """An input guardrail whose tripwire stays off must not interrupt the run."""

    def passive_guardrail(
        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
    ) -> GuardrailFunctionOutput:
        # Tripwire deliberately left untriggered; the runner should proceed.
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("response")])

    agent = Agent(
        name="test",
        model=fake_model,
        input_guardrails=[InputGuardrail(guardrail_function=passive_guardrail)],
    )

    # No exception expected; the model's text becomes the final output.
    result = await Runner.run(agent, input="user_message")
    assert result.final_output == "response"
| 733 | + |
@pytest.mark.asyncio
async def test_output_guardrail_no_tripwire_continues_execution():
    """An output guardrail whose tripwire stays off must not interrupt the run."""

    def passive_guardrail(
        context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
    ) -> GuardrailFunctionOutput:
        # Tripwire deliberately left untriggered; the runner should proceed.
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("response")])

    agent = Agent(
        name="test",
        model=fake_model,
        output_guardrails=[OutputGuardrail(guardrail_function=passive_guardrail)],
    )

    # No exception expected; the model's text becomes the final output.
    result = await Runner.run(agent, input="user_message")
    assert result.final_output == "response"
| 758 | + |
| 759 | + |
@function_tool
def test_tool_one():
    # Fixture tool: always returns the same Foo payload so tests can assert on it.
    return Foo(bar="tool_one_result")
@@ -1351,3 +1409,259 @@ async def echo_tool(text: str) -> str: |
1351 | 1409 | assert (await session.get_items()) == expected_items |
1352 | 1410 |
|
1353 | 1411 | session.close() |
| 1412 | + |
| 1413 | + |
@pytest.mark.asyncio
async def test_execute_approved_tools_with_non_function_tool():
    """An approved call to a non-FunctionTool yields an explanatory error item."""

    class _FakeComputer(Computer):
        """Minimal Computer implementation; every action is a no-op."""

        @property
        def environment(self) -> str:  # type: ignore[override]
            return "mac"

        @property
        def dimensions(self) -> tuple[int, int]:
            return (1920, 1080)

        def screenshot(self) -> str:
            return "screenshot"

        def click(self, x: int, y: int, button: str) -> None:
            pass

        def double_click(self, x: int, y: int) -> None:
            pass

        def drag(self, path: list[tuple[int, int]]) -> None:
            pass

        def keypress(self, keys: list[str]) -> None:
            pass

        def move(self, x: int, y: int) -> None:
            pass

        def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
            pass

        def type(self, text: str) -> None:
            pass

        def wait(self) -> None:
            pass

    computer_tool = ComputerTool(computer=_FakeComputer())
    agent = Agent(name="TestAgent", model=FakeModel(), tools=[computer_tool])

    # ComputerTool is registered under the name "computer_use_preview".
    tool_call = get_function_tool_call("computer_use_preview", "{}")
    assert isinstance(tool_call, ResponseFunctionToolCall)
    approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)

    # Record the approval through a RunState bound to the same context.
    context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    run_state = RunState(
        context=context_wrapper,
        original_input="test",
        starting_agent=agent,
        max_turns=1,
    )
    run_state.approve(approval_item)

    generated_items: list[RunItem] = []

    await AgentRunner._execute_approved_tools_static(
        agent=agent,
        interruptions=[approval_item],
        context_wrapper=context_wrapper,
        generated_items=generated_items,
        run_config=RunConfig(),
        hooks=RunHooks(),
    )

    # The runner should report that the approved tool is not a function tool.
    assert len(generated_items) == 1
    assert isinstance(generated_items[0], ToolCallOutputItem)
    assert "not a function tool" in generated_items[0].output.lower()
| 1493 | + |
| 1494 | + |
@pytest.mark.asyncio
async def test_execute_approved_tools_with_rejected_tool():
    """Test _execute_approved_tools handles rejected tools.

    A rejected tool call must produce a "not approved" output item and the
    underlying tool function must never run.
    """
    model = FakeModel()
    tool_called = False

    async def test_tool() -> str:
        nonlocal tool_called
        tool_called = True
        return "tool_result"

    tool = function_tool(test_tool, name_override="test_tool")
    agent = Agent(name="TestAgent", model=model, tools=[tool])

    # Create a tool call and mark it rejected via RunState.
    tool_call = get_function_tool_call("test_tool", "{}")
    assert isinstance(tool_call, ResponseFunctionToolCall)
    approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)

    context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    state = RunState(
        context=context_wrapper,
        original_input="test",
        starting_agent=agent,
        max_turns=1,
    )
    state.reject(approval_item)

    # Typed as list[RunItem] for consistency with the sibling approval tests.
    generated_items: list[RunItem] = []

    # Execute approved tools
    await AgentRunner._execute_approved_tools_static(
        agent=agent,
        interruptions=[approval_item],
        context_wrapper=context_wrapper,
        generated_items=generated_items,
        run_config=RunConfig(),
        hooks=RunHooks(),
    )

    # Should add a rejection message instead of executing the tool.
    assert len(generated_items) == 1
    assert isinstance(generated_items[0], ToolCallOutputItem)
    assert "not approved" in generated_items[0].output.lower()
    assert not tool_called  # Tool should not have been executed
| 1540 | + |
| 1541 | + |
@pytest.mark.asyncio
async def test_execute_approved_tools_with_unclear_status():
    """Test _execute_approved_tools handles unclear approval status.

    A tool call that was neither approved nor rejected must produce an
    "unclear" output item and the underlying tool function must never run.
    """
    model = FakeModel()
    tool_called = False

    async def test_tool() -> str:
        nonlocal tool_called
        tool_called = True
        return "tool_result"

    tool = function_tool(test_tool, name_override="test_tool")
    agent = Agent(name="TestAgent", model=model, tools=[tool])

    # Create a tool call with unclear status (neither approved nor rejected).
    tool_call = get_function_tool_call("test_tool", "{}")
    assert isinstance(tool_call, ResponseFunctionToolCall)
    approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)

    context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    # Deliberately no approve()/reject() call - the status stays None.

    # Typed as list[RunItem] for consistency with the sibling approval tests.
    generated_items: list[RunItem] = []

    # Execute approved tools
    await AgentRunner._execute_approved_tools_static(
        agent=agent,
        interruptions=[approval_item],
        context_wrapper=context_wrapper,
        generated_items=generated_items,
        run_config=RunConfig(),
        hooks=RunHooks(),
    )

    # Should add an unclear-status message instead of executing the tool.
    assert len(generated_items) == 1
    assert isinstance(generated_items[0], ToolCallOutputItem)
    assert "unclear" in generated_items[0].output.lower()
    assert not tool_called  # Tool should not have been executed
| 1580 | + |
| 1581 | + |
@pytest.mark.asyncio
async def test_execute_approved_tools_with_missing_tool():
    """Approving a call to a tool the agent does not have yields a not-found item."""
    # Agent deliberately has no tools registered.
    agent = Agent(name="TestAgent", model=FakeModel())

    # Build and approve a call targeting a tool name the agent cannot resolve.
    tool_call = get_function_tool_call("nonexistent_tool", "{}")
    assert isinstance(tool_call, ResponseFunctionToolCall)
    approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)

    context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    run_state = RunState(
        context=context_wrapper,
        original_input="test",
        starting_agent=agent,
        max_turns=1,
    )
    run_state.approve(approval_item)

    generated_items: list[RunItem] = []

    await AgentRunner._execute_approved_tools_static(
        agent=agent,
        interruptions=[approval_item],
        context_wrapper=context_wrapper,
        generated_items=generated_items,
        run_config=RunConfig(),
        hooks=RunHooks(),
    )

    # The runner should emit a single error item naming the missing tool.
    assert len(generated_items) == 1
    assert isinstance(generated_items[0], ToolCallOutputItem)
    assert "not found" in generated_items[0].output.lower()
| 1620 | + |
| 1621 | + |
@pytest.mark.asyncio
async def test_execute_approved_tools_instance_method():
    """The instance-method wrapper executes an approved tool end to end."""
    tool_called = False

    async def test_tool() -> str:
        nonlocal tool_called
        tool_called = True
        return "tool_result"

    tool = function_tool(test_tool, name_override="test_tool")
    agent = Agent(name="TestAgent", model=FakeModel(), tools=[tool])

    tool_call = get_function_tool_call("test_tool", "{}")
    assert isinstance(tool_call, ResponseFunctionToolCall)
    approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call)

    # Approve the call through a RunState bound to the same context.
    context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={})
    run_state = RunState(
        context=context_wrapper,
        original_input="test",
        starting_agent=agent,
        max_turns=1,
    )
    run_state.approve(approval_item)

    generated_items: list[RunItem] = []

    # Drive execution via the instance method rather than the static helper.
    runner = AgentRunner()
    await runner._execute_approved_tools(
        agent=agent,
        interruptions=[approval_item],
        context_wrapper=context_wrapper,
        generated_items=generated_items,
        run_config=RunConfig(),
        hooks=RunHooks(),
    )

    # The tool ran and its result was recorded as a tool-call output item.
    assert tool_called is True
    assert len(generated_items) == 1
    assert isinstance(generated_items[0], ToolCallOutputItem)
    assert generated_items[0].output == "tool_result"
0 commit comments