|
4 | 4 | import copy
|
5 | 5 | import inspect
|
6 | 6 | from dataclasses import dataclass, field
|
7 |
| -from typing import Any, Generic, cast |
| 7 | +from typing import Any, Generic, cast, get_args |
8 | 8 |
|
9 | 9 | from openai.types.responses import (
|
10 | 10 | ResponseCompletedEvent,
|
11 |
| - ResponseComputerToolCall, |
12 |
| - ResponseFileSearchToolCall, |
13 |
| - ResponseFunctionToolCall, |
14 | 11 | ResponseOutputItemAddedEvent,
|
15 | 12 | )
|
16 |
| -from openai.types.responses.response_code_interpreter_tool_call import ( |
17 |
| - ResponseCodeInterpreterToolCall, |
18 |
| -) |
19 |
| -from openai.types.responses.response_output_item import ( |
20 |
| - ImageGenerationCall, |
21 |
| - LocalShellCall, |
22 |
| - McpCall, |
23 |
| -) |
24 | 13 | from openai.types.responses.response_prompt_param import (
|
25 | 14 | ResponsePromptParam,
|
26 | 15 | )
|
|
55 | 44 | OutputGuardrailResult,
|
56 | 45 | )
|
57 | 46 | from .handoffs import Handoff, HandoffInputFilter, handoff
|
58 |
| -from .items import ItemHelpers, ModelResponse, RunItem, ToolCallItem, TResponseInputItem |
| 47 | +from .items import ( |
| 48 | + ItemHelpers, |
| 49 | + ModelResponse, |
| 50 | + RunItem, |
| 51 | + ToolCallItem, |
| 52 | + ToolCallItemTypes, |
| 53 | + TResponseInputItem, |
| 54 | +) |
59 | 55 | from .lifecycle import RunHooks
|
60 | 56 | from .logger import logger
|
61 | 57 | from .memory import Session
|
@@ -922,18 +918,7 @@ async def _run_single_turn_streamed(
|
922 | 918 | if isinstance(event, ResponseOutputItemAddedEvent):
|
923 | 919 | output_item = event.item
|
924 | 920 |
|
925 |
| - if isinstance( |
926 |
| - output_item, |
927 |
| - ( |
928 |
| - ResponseFunctionToolCall, |
929 |
| - ResponseFileSearchToolCall, |
930 |
| - ResponseComputerToolCall, |
931 |
| - ResponseCodeInterpreterToolCall, |
932 |
| - ImageGenerationCall, |
933 |
| - LocalShellCall, |
934 |
| - McpCall, |
935 |
| - ), |
936 |
| - ): |
| 921 | + if isinstance(output_item, _TOOL_CALL_TYPES): |
937 | 922 | call_id = getattr(output_item, "call_id", getattr(output_item, "id", None))
|
938 | 923 |
|
939 | 924 | if call_id not in emitted_tool_call_ids:
|
@@ -1310,3 +1295,5 @@ async def _save_result_to_session(
|
1310 | 1295 |
|
1311 | 1296 |
|
1312 | 1297 | DEFAULT_AGENT_RUNNER = AgentRunner()
|
| 1298 | + |
| 1299 | +_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) |
0 commit comments