Skip to content

Commit 43909a8

Browse files
committed
use ToolCallItemTypes's args better maintainability
1 parent 84f55ba commit 43909a8

File tree

1 file changed

+12
-25
lines changed

1 file changed

+12
-25
lines changed

src/agents/run.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,12 @@
44
import copy
55
import inspect
66
from dataclasses import dataclass, field
7-
from typing import Any, Generic, cast
7+
from typing import Any, Generic, cast, get_args
88

99
from openai.types.responses import (
1010
ResponseCompletedEvent,
11-
ResponseComputerToolCall,
12-
ResponseFileSearchToolCall,
13-
ResponseFunctionToolCall,
1411
ResponseOutputItemAddedEvent,
1512
)
16-
from openai.types.responses.response_code_interpreter_tool_call import (
17-
ResponseCodeInterpreterToolCall,
18-
)
19-
from openai.types.responses.response_output_item import (
20-
ImageGenerationCall,
21-
LocalShellCall,
22-
McpCall,
23-
)
2413
from openai.types.responses.response_prompt_param import (
2514
ResponsePromptParam,
2615
)
@@ -55,7 +44,14 @@
5544
OutputGuardrailResult,
5645
)
5746
from .handoffs import Handoff, HandoffInputFilter, handoff
58-
from .items import ItemHelpers, ModelResponse, RunItem, ToolCallItem, TResponseInputItem
47+
from .items import (
48+
ItemHelpers,
49+
ModelResponse,
50+
RunItem,
51+
ToolCallItem,
52+
ToolCallItemTypes,
53+
TResponseInputItem,
54+
)
5955
from .lifecycle import RunHooks
6056
from .logger import logger
6157
from .memory import Session
@@ -922,18 +918,7 @@ async def _run_single_turn_streamed(
922918
if isinstance(event, ResponseOutputItemAddedEvent):
923919
output_item = event.item
924920

925-
if isinstance(
926-
output_item,
927-
(
928-
ResponseFunctionToolCall,
929-
ResponseFileSearchToolCall,
930-
ResponseComputerToolCall,
931-
ResponseCodeInterpreterToolCall,
932-
ImageGenerationCall,
933-
LocalShellCall,
934-
McpCall,
935-
),
936-
):
921+
if isinstance(output_item, _TOOL_CALL_TYPES):
937922
call_id = getattr(output_item, "call_id", getattr(output_item, "id", None))
938923

939924
if call_id not in emitted_tool_call_ids:
@@ -1310,3 +1295,5 @@ async def _save_result_to_session(
13101295

13111296

13121297
DEFAULT_AGENT_RUNNER = AgentRunner()
1298+
1299+
_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)

0 commit comments

Comments
 (0)