|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import abc |
4 | | -from dataclasses import dataclass |
| 4 | +import weakref |
| 5 | +from dataclasses import dataclass, field |
5 | 6 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union |
6 | 7 |
|
7 | 8 | import pydantic |
@@ -84,6 +85,22 @@ class RunItemBase(Generic[T], abc.ABC): |
84 | 85 | (i.e. `openai.types.responses.ResponseInputItemParam`). |
85 | 86 | """ |
86 | 87 |
|
| 88 | + _agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( |
| 89 | + init=False, |
| 90 | + repr=False, |
| 91 | + default=None, |
| 92 | + ) |
| 93 | + |
| 94 | + def __post_init__(self) -> None: |
| 95 | + # Store the producing agent weakly to avoid keeping it alive after the run. |
| 96 | + self._agent_ref = weakref.ref(self.agent) |
| 97 | + object.__delattr__(self, "agent") |
| 98 | + |
| 99 | + def __getattr__(self, name: str) -> Any: |
| 100 | + if name == "agent": |
| 101 | + return self._agent_ref() if self._agent_ref else None |
| 102 | + raise AttributeError(name) |
| 103 | + |
87 | 104 | def to_input_item(self) -> TResponseInputItem: |
88 | 105 | """Converts this item into an input item suitable for passing to the model.""" |
89 | 106 | if isinstance(self.raw_item, dict): |
@@ -131,6 +148,32 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): |
131 | 148 |
|
132 | 149 | type: Literal["handoff_output_item"] = "handoff_output_item" |
133 | 150 |
|
| 151 | + _source_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( |
| 152 | + init=False, |
| 153 | + repr=False, |
| 154 | + default=None, |
| 155 | + ) |
| 156 | + _target_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( |
| 157 | + init=False, |
| 158 | + repr=False, |
| 159 | + default=None, |
| 160 | + ) |
| 161 | + |
| 162 | + def __post_init__(self) -> None: |
| 163 | + super().__post_init__() |
| 164 | + # Handoff metadata should not hold strong references to the agents either. |
| 165 | + self._source_agent_ref = weakref.ref(self.source_agent) |
| 166 | + self._target_agent_ref = weakref.ref(self.target_agent) |
| 167 | + object.__delattr__(self, "source_agent") |
| 168 | + object.__delattr__(self, "target_agent") |
| 169 | + |
| 170 | + def __getattr__(self, name: str) -> Any: |
| 171 | + if name == "source_agent": |
| 172 | + return self._source_agent_ref() if self._source_agent_ref else None |
| 173 | + if name == "target_agent": |
| 174 | + return self._target_agent_ref() if self._target_agent_ref else None |
| 175 | + return super().__getattr__(name) |
| 176 | + |
134 | 177 |
|
135 | 178 | ToolCallItemTypes: TypeAlias = Union[ |
136 | 179 | ResponseFunctionToolCall, |
|
0 commit comments