Skip to content

Commit a228bdf

Browse files
Add suggested corrections
1 parent bc352cc commit a228bdf

File tree

5 files changed

+204
-25
lines changed

5 files changed

+204
-25
lines changed

src/agents/items.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import abc
44
import weakref
55
from dataclasses import dataclass, field
6-
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union
6+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast
77

88
import pydantic
99
from openai.types.responses import (
@@ -92,15 +92,22 @@ class RunItemBase(Generic[T], abc.ABC):
9292
)
9393

9494
def __post_init__(self) -> None:
95-
# Store the producing agent weakly to avoid keeping it alive after the run.
95+
# Store a weak reference so we can release the strong reference later if desired.
9696
self._agent_ref = weakref.ref(self.agent)
97-
object.__delattr__(self, "agent")
9897

9998
def __getattr__(self, name: str) -> Any:
10099
if name == "agent":
101100
return self._agent_ref() if self._agent_ref else None
102101
raise AttributeError(name)
103102

103+
def release_agent(self) -> None:
104+
"""Release the strong reference to the agent while keeping a weak reference."""
105+
if "agent" not in self.__dict__:
106+
return
107+
agent = self.__dict__["agent"]
108+
self._agent_ref = weakref.ref(agent) if agent is not None else None
109+
object.__delattr__(self, "agent")
110+
104111
def to_input_item(self) -> TResponseInputItem:
105112
"""Converts this item into an input item suitable for passing to the model."""
106113
if isinstance(self.raw_item, dict):
@@ -161,11 +168,9 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]):
161168

162169
def __post_init__(self) -> None:
163170
super().__post_init__()
164-
# Handoff metadata should not hold strong references to the agents either.
171+
# Maintain weak references so downstream code can release the strong references when safe.
165172
self._source_agent_ref = weakref.ref(self.source_agent)
166173
self._target_agent_ref = weakref.ref(self.target_agent)
167-
object.__delattr__(self, "source_agent")
168-
object.__delattr__(self, "target_agent")
169174

170175
def __getattr__(self, name: str) -> Any:
171176
if name == "source_agent":
@@ -174,6 +179,17 @@ def __getattr__(self, name: str) -> Any:
174179
return self._target_agent_ref() if self._target_agent_ref else None
175180
return super().__getattr__(name)
176181

182+
def release_agent(self) -> None:
183+
super().release_agent()
184+
if "source_agent" in self.__dict__:
185+
source_agent = self.__dict__["source_agent"]
186+
self._source_agent_ref = weakref.ref(source_agent) if source_agent is not None else None
187+
object.__delattr__(self, "source_agent")
188+
if "target_agent" in self.__dict__:
189+
target_agent = self.__dict__["target_agent"]
190+
self._target_agent_ref = weakref.ref(target_agent) if target_agent is not None else None
191+
object.__delattr__(self, "target_agent")
192+
177193

178194
ToolCallItemTypes: TypeAlias = Union[
179195
ResponseFunctionToolCall,
@@ -184,12 +200,13 @@ def __getattr__(self, name: str) -> Any:
184200
ResponseCodeInterpreterToolCall,
185201
ImageGenerationCall,
186202
LocalShellCall,
203+
dict[str, Any],
187204
]
188205
"""A type that represents a tool call item."""
189206

190207

191208
@dataclass
192-
class ToolCallItem(RunItemBase[ToolCallItemTypes]):
209+
class ToolCallItem(RunItemBase[Any]):
193210
"""Represents a tool call e.g. a function call or computer action call."""
194211

195212
raw_item: ToolCallItemTypes
@@ -198,13 +215,19 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]):
198215
type: Literal["tool_call_item"] = "tool_call_item"
199216

200217

218+
ToolCallOutputTypes: TypeAlias = Union[
219+
FunctionCallOutput,
220+
ComputerCallOutput,
221+
LocalShellCallOutput,
222+
dict[str, Any],
223+
]
224+
225+
201226
@dataclass
202-
class ToolCallOutputItem(
203-
RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]]
204-
):
227+
class ToolCallOutputItem(RunItemBase[Any]):
205228
"""Represents the output of a tool call."""
206229

207-
raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput
230+
raw_item: ToolCallOutputTypes
208231
"""The raw item from the model."""
209232

210233
output: Any
@@ -214,6 +237,25 @@ class ToolCallOutputItem(
214237

215238
type: Literal["tool_call_output_item"] = "tool_call_output_item"
216239

240+
def to_input_item(self) -> TResponseInputItem:
241+
"""Converts the tool output into an input item for the next model turn.
242+
243+
Hosted tool outputs (e.g. shell/apply_patch) carry a `status` field for the SDK's
244+
book-keeping, but the Responses API does not yet accept that parameter. Strip it from the
245+
payload we send back to the model while keeping the original raw item intact.
246+
"""
247+
248+
if isinstance(self.raw_item, dict):
249+
payload = dict(self.raw_item)
250+
payload_type = payload.get("type")
251+
if payload_type == "shell_call_output":
252+
payload.pop("status", None)
253+
payload.pop("shell_output", None)
254+
payload.pop("provider_data", None)
255+
return cast(TResponseInputItem, payload)
256+
257+
return super().to_input_item()
258+
217259

218260
@dataclass
219261
class ReasoningItem(RunItemBase[ResponseReasoningItem]):

src/agents/result.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
import weakref
56
from collections.abc import AsyncIterator
67
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -74,6 +75,32 @@ class RunResultBase(abc.ABC):
7475
def last_agent(self) -> Agent[Any]:
7576
"""The last agent that was run."""
7677

78+
def release_agents(self) -> None:
79+
"""
80+
Release strong references to agents held by this result. After calling this method,
81+
accessing `item.agent` or `last_agent` may return `None` if the agent has been garbage
82+
collected. Callers can use this when they are done inspecting the result and want to
83+
eagerly drop any associated agent graph.
84+
"""
85+
for item in self.new_items:
86+
release = getattr(item, "release_agent", None)
87+
if callable(release):
88+
release()
89+
self._release_last_agent_reference()
90+
91+
def __del__(self) -> None:
92+
try:
93+
# Fall back to releasing agents automatically in case the caller never invoked
94+
# `release_agents()` explicitly. This keeps the no-leak guarantee confirmed by tests.
95+
self.release_agents()
96+
except Exception:
97+
# Avoid raising from __del__.
98+
pass
99+
100+
@abc.abstractmethod
101+
def _release_last_agent_reference(self) -> None:
102+
"""Release stored agent reference specific to the concrete result type."""
103+
77104
def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
78105
"""A convenience method to cast the final output to a specific type. By default, the cast
79106
is only for the typechecker. If you set `raise_if_incorrect_type` to True, we'll raise a
@@ -111,11 +138,33 @@ def last_response_id(self) -> str | None:
111138
@dataclass
112139
class RunResult(RunResultBase):
113140
_last_agent: Agent[Any]
141+
_last_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
142+
init=False,
143+
repr=False,
144+
default=None,
145+
)
146+
147+
def __post_init__(self) -> None:
148+
self._last_agent_ref = weakref.ref(self._last_agent)
114149

115150
@property
116151
def last_agent(self) -> Agent[Any]:
117152
"""The last agent that was run."""
118-
return self._last_agent
153+
agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent"))
154+
if agent is not None:
155+
return agent
156+
if self._last_agent_ref:
157+
agent = self._last_agent_ref()
158+
if agent is not None:
159+
return agent
160+
raise AgentsException("Last agent reference is no longer available.")
161+
162+
def _release_last_agent_reference(self) -> None:
163+
agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent"))
164+
if agent is None:
165+
return
166+
self._last_agent_ref = weakref.ref(agent)
167+
object.__delattr__(self, "_last_agent")
119168

120169
def __str__(self) -> str:
121170
return pretty_print_result(self)
@@ -150,6 +199,12 @@ class RunResultStreaming(RunResultBase):
150199
is_complete: bool = False
151200
"""Whether the agent has finished running."""
152201

202+
_current_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
203+
init=False,
204+
repr=False,
205+
default=None,
206+
)
207+
153208
# Queues that the background run_loop writes to
154209
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
155210
default_factory=asyncio.Queue, repr=False
@@ -167,12 +222,29 @@ class RunResultStreaming(RunResultBase):
167222
# Soft cancel state
168223
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
169224

225+
def __post_init__(self) -> None:
226+
self._current_agent_ref = weakref.ref(self.current_agent)
227+
170228
@property
171229
def last_agent(self) -> Agent[Any]:
172230
"""The last agent that was run. Updates as the agent run progresses, so the true last agent
173231
is only available after the agent run is complete.
174232
"""
175-
return self.current_agent
233+
agent = cast("Agent[Any] | None", self.__dict__.get("current_agent"))
234+
if agent is not None:
235+
return agent
236+
if self._current_agent_ref:
237+
agent = self._current_agent_ref()
238+
if agent is not None:
239+
return agent
240+
raise AgentsException("Last agent reference is no longer available.")
241+
242+
def _release_last_agent_reference(self) -> None:
243+
agent = cast("Agent[Any] | None", self.__dict__.get("current_agent"))
244+
if agent is None:
245+
return
246+
self._current_agent_ref = weakref.ref(agent)
247+
object.__delattr__(self, "current_agent")
176248

177249
def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None:
178250
"""Cancel the streaming run.

tests/test_agent_memory_leak.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ def _make_message(text: str) -> ResponseOutputMessage:
2323
@pytest.mark.asyncio
2424
async def test_agent_is_released_after_run() -> None:
2525
fake_model = FakeModel(initial_output=[_make_message("Paris")])
26-
agent = Agent(name="leaker", instructions="Answer questions.", model=fake_model)
26+
agent = Agent(name="leak-test-agent", instructions="Answer questions.", model=fake_model)
2727
agent_ref = weakref.ref(agent)
2828

29+
# Running the agent should not leave behind strong references once the result goes out of scope.
2930
await Runner.run(agent, "What is the capital of France?")
3031

3132
del agent

tests/test_items_helpers.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import gc
34
import json
45

56
from openai.types.responses.response_computer_tool_call import (
@@ -57,16 +58,18 @@ def make_message(
5758

5859
def test_extract_last_content_of_text_message() -> None:
5960
# Build a message containing two text segments.
60-
content1 = ResponseOutputText(annotations=[], text="Hello ", type="output_text")
61-
content2 = ResponseOutputText(annotations=[], text="world!", type="output_text")
61+
content1 = ResponseOutputText(annotations=[], text="Hello ", type="output_text", logprobs=[])
62+
content2 = ResponseOutputText(annotations=[], text="world!", type="output_text", logprobs=[])
6263
message = make_message([content1, content2])
6364
# Helpers should yield the last segment's text.
6465
assert ItemHelpers.extract_last_content(message) == "world!"
6566

6667

6768
def test_extract_last_content_of_refusal_message() -> None:
6869
# Build a message whose last content entry is a refusal.
69-
content1 = ResponseOutputText(annotations=[], text="Before refusal", type="output_text")
70+
content1 = ResponseOutputText(
71+
annotations=[], text="Before refusal", type="output_text", logprobs=[]
72+
)
7073
refusal = ResponseOutputRefusal(refusal="I cannot do that", type="refusal")
7174
message = make_message([content1, refusal])
7275
# Helpers should extract the refusal string when last content is a refusal.
@@ -87,8 +90,8 @@ def test_extract_last_content_non_message_returns_empty() -> None:
8790

8891
def test_extract_last_text_returns_text_only() -> None:
8992
# A message whose last segment is text yields the text.
90-
first_text = ResponseOutputText(annotations=[], text="part1", type="output_text")
91-
second_text = ResponseOutputText(annotations=[], text="part2", type="output_text")
93+
first_text = ResponseOutputText(annotations=[], text="part1", type="output_text", logprobs=[])
94+
second_text = ResponseOutputText(annotations=[], text="part2", type="output_text", logprobs=[])
9295
message = make_message([first_text, second_text])
9396
assert ItemHelpers.extract_last_text(message) == "part2"
9497
# Whereas when last content is a refusal, extract_last_text returns None.
@@ -116,9 +119,9 @@ def test_input_to_new_input_list_deep_copies_lists() -> None:
116119
def test_text_message_output_concatenates_text_segments() -> None:
117120
# Build a message with both text and refusal segments, only text segments are concatenated.
118121
pieces: list[ResponseOutputText | ResponseOutputRefusal] = []
119-
pieces.append(ResponseOutputText(annotations=[], text="a", type="output_text"))
122+
pieces.append(ResponseOutputText(annotations=[], text="a", type="output_text", logprobs=[]))
120123
pieces.append(ResponseOutputRefusal(refusal="denied", type="refusal"))
121-
pieces.append(ResponseOutputText(annotations=[], text="b", type="output_text"))
124+
pieces.append(ResponseOutputText(annotations=[], text="b", type="output_text", logprobs=[]))
122125
message = make_message(pieces)
123126
# Wrap into MessageOutputItem to feed into text_message_output.
124127
item = MessageOutputItem(agent=Agent(name="test"), raw_item=message)
@@ -131,8 +134,12 @@ def test_text_message_outputs_across_list_of_runitems() -> None:
131134
that only MessageOutputItem instances contribute any text. The non-message
132135
(ReasoningItem) should be ignored by Helpers.text_message_outputs.
133136
"""
134-
message1 = make_message([ResponseOutputText(annotations=[], text="foo", type="output_text")])
135-
message2 = make_message([ResponseOutputText(annotations=[], text="bar", type="output_text")])
137+
message1 = make_message(
138+
[ResponseOutputText(annotations=[], text="foo", type="output_text", logprobs=[])]
139+
)
140+
message2 = make_message(
141+
[ResponseOutputText(annotations=[], text="bar", type="output_text", logprobs=[])]
142+
)
136143
item1: RunItem = MessageOutputItem(agent=Agent(name="test"), raw_item=message1)
137144
item2: RunItem = MessageOutputItem(agent=Agent(name="test"), raw_item=message2)
138145
# Create a non-message run item of a different type, e.g., a reasoning trace.
@@ -142,6 +149,19 @@ def test_text_message_outputs_across_list_of_runitems() -> None:
142149
assert ItemHelpers.text_message_outputs([item1, non_message_item, item2]) == "foobar"
143150

144151

152+
def test_message_output_item_retains_agent_until_release() -> None:
153+
# Construct the run item with an inline agent to ensure the run item keeps a strong reference.
154+
message = make_message([ResponseOutputText(annotations=[], text="hello", type="output_text")])
155+
item = MessageOutputItem(agent=Agent(name="inline"), raw_item=message)
156+
assert item.agent is not None
157+
assert item.agent.name == "inline"
158+
159+
# After explicitly releasing, the weak reference should drop once GC runs.
160+
item.release_agent()
161+
gc.collect()
162+
assert item.agent is None
163+
164+
145165
def test_tool_call_output_item_constructs_function_call_output_dict():
146166
# Build a simple ResponseFunctionToolCall.
147167
call = ResponseFunctionToolCall(
@@ -171,7 +191,9 @@ def test_tool_call_output_item_constructs_function_call_output_dict():
171191

172192
def test_to_input_items_for_message() -> None:
173193
"""An output message should convert into an input dict matching the message's own structure."""
174-
content = ResponseOutputText(annotations=[], text="hello world", type="output_text")
194+
content = ResponseOutputText(
195+
annotations=[], text="hello world", type="output_text", logprobs=[]
196+
)
175197
message = ResponseOutputMessage(
176198
id="m1", content=[content], role="assistant", status="completed", type="message"
177199
)
@@ -184,6 +206,7 @@ def test_to_input_items_for_message() -> None:
184206
"content": [
185207
{
186208
"annotations": [],
209+
"logprobs": [],
187210
"text": "hello world",
188211
"type": "output_text",
189212
}
@@ -305,6 +328,7 @@ def test_input_to_new_input_list_copies_the_ones_produced_by_pydantic() -> None:
305328
type="output_text",
306329
text="Hey, what's up?",
307330
annotations=[],
331+
logprobs=[],
308332
)
309333
],
310334
role="assistant",

0 commit comments

Comments
 (0)