Skip to content

Commit 7593e92

Browse files
committed
refactor
1 parent a7089a4 commit 7593e92

File tree

7 files changed

+236
-268
lines changed

7 files changed

+236
-268
lines changed

src/agents/run_internal/items.py

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,6 @@ def drop_orphan_function_calls(items: list[TResponseInputItem]) -> list[TRespons
4646
not replay stale tool calls.
4747
"""
4848

49-
def _completed_call_ids(payload: list[TResponseInputItem]) -> set[str]:
50-
completed: set[str] = set()
51-
for entry in payload:
52-
if not isinstance(entry, dict):
53-
continue
54-
item_type = entry.get("type")
55-
if item_type not in ("function_call_output", "function_call_result"):
56-
continue
57-
call_id = entry.get("call_id") or entry.get("callId")
58-
if call_id and isinstance(call_id, str):
59-
completed.add(call_id)
60-
return completed
61-
6249
completed_call_ids = _completed_call_ids(items)
6350

6451
filtered: list[TResponseInputItem] = []
@@ -77,19 +64,7 @@ def _completed_call_ids(payload: list[TResponseInputItem]) -> set[str]:
7764

7865
def ensure_input_item_format(item: TResponseInputItem) -> TResponseInputItem:
7966
"""Ensure a single item is normalized for model input (function_call_output, snake_case)."""
80-
81-
def _coerce_dict(value: TResponseInputItem) -> dict[str, Any] | None:
82-
"""Convert dataclass/Pydantic items into plain dicts when possible."""
83-
if isinstance(value, dict):
84-
return dict(value)
85-
if hasattr(value, "model_dump"):
86-
try:
87-
return cast(dict[str, Any], value.model_dump(exclude_unset=True))
88-
except Exception:
89-
return None
90-
return None
91-
92-
coerced = _coerce_dict(item)
67+
coerced = _coerce_to_dict(item)
9368
if coerced is None:
9469
return item
9570

@@ -100,17 +75,6 @@ def _coerce_dict(value: TResponseInputItem) -> dict[str, Any] | None:
10075
def normalize_input_items_for_api(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
10176
"""Normalize input items for API submission and strip provider data for downstream services."""
10277

103-
def _coerce_to_dict(value: TResponseInputItem) -> dict[str, Any] | None:
104-
"""Convert model items to dicts so fields can be renamed and sanitized."""
105-
if isinstance(value, dict):
106-
return dict(value)
107-
if hasattr(value, "model_dump"):
108-
try:
109-
return cast(dict[str, Any], value.model_dump(exclude_unset=True))
110-
except Exception:
111-
return None
112-
return None
113-
11478
normalized: list[TResponseInputItem] = []
11579
for item in items:
11680
coerced = _coerce_to_dict(item)
@@ -223,17 +187,33 @@ def extract_mcp_request_id_from_run(mcp_run: Any) -> str | None:
223187
return candidate if isinstance(candidate, str) else None
224188

225189

226-
__all__ = [
227-
"REJECTION_MESSAGE",
228-
"copy_input_items",
229-
"drop_orphan_function_calls",
230-
"ensure_input_item_format",
231-
"normalize_input_items_for_api",
232-
"fingerprint_input_item",
233-
"deduplicate_input_items",
234-
"function_rejection_item",
235-
"shell_rejection_item",
236-
"apply_patch_rejection_item",
237-
"extract_mcp_request_id",
238-
"extract_mcp_request_id_from_run",
239-
]
190+
# --------------------------
191+
# Private helpers
192+
# --------------------------
193+
194+
195+
def _completed_call_ids(payload: list[TResponseInputItem]) -> set[str]:
196+
"""Return the call ids that already have outputs."""
197+
completed: set[str] = set()
198+
for entry in payload:
199+
if not isinstance(entry, dict):
200+
continue
201+
item_type = entry.get("type")
202+
if item_type not in ("function_call_output", "function_call_result"):
203+
continue
204+
call_id = entry.get("call_id") or entry.get("callId")
205+
if call_id and isinstance(call_id, str):
206+
completed.add(call_id)
207+
return completed
208+
209+
210+
def _coerce_to_dict(value: TResponseInputItem) -> dict[str, Any] | None:
211+
"""Convert model items to dicts so fields can be renamed and sanitized."""
212+
if isinstance(value, dict):
213+
return dict(value)
214+
if hasattr(value, "model_dump"):
215+
try:
216+
return cast(dict[str, Any], value.model_dump(exclude_unset=True))
217+
except Exception:
218+
return None
219+
return None

src/agents/run_internal/session_persistence.py

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -85,57 +85,19 @@ async def prepare_input_with_session(
8585
if not isinstance(combined, list):
8686
raise UserError("Session input callback must return a list of input items.")
8787

88-
def session_item_key(item: Any) -> str:
89-
try:
90-
if hasattr(item, "model_dump"):
91-
payload = item.model_dump(exclude_unset=True)
92-
elif isinstance(item, dict):
93-
payload = item
94-
else:
95-
payload = ensure_input_item_format(item)
96-
return json.dumps(payload, sort_keys=True, default=str)
97-
except Exception:
98-
return repr(item)
99-
100-
def build_reference_map(items: Sequence[Any]) -> dict[str, list[Any]]:
101-
refs: dict[str, list[Any]] = {}
102-
for item in items:
103-
key = session_item_key(item)
104-
refs.setdefault(key, []).append(item)
105-
return refs
106-
107-
def consume_reference(ref_map: dict[str, list[Any]], key: str, candidate: Any) -> bool:
108-
candidates = ref_map.get(key)
109-
if not candidates:
110-
return False
111-
for idx, existing in enumerate(candidates):
112-
if existing is candidate:
113-
candidates.pop(idx)
114-
if not candidates:
115-
ref_map.pop(key, None)
116-
return True
117-
return False
118-
119-
def build_frequency_map(items: Sequence[Any]) -> dict[str, int]:
120-
freq: dict[str, int] = {}
121-
for item in items:
122-
key = session_item_key(item)
123-
freq[key] = freq.get(key, 0) + 1
124-
return freq
125-
126-
history_refs = build_reference_map(history_for_callback)
127-
new_refs = build_reference_map(new_items_for_callback)
128-
history_counts = build_frequency_map(history_for_callback)
129-
new_counts = build_frequency_map(new_items_for_callback)
88+
history_refs = _build_reference_map(history_for_callback)
89+
new_refs = _build_reference_map(new_items_for_callback)
90+
history_counts = _build_frequency_map(history_for_callback)
91+
new_counts = _build_frequency_map(new_items_for_callback)
13092

13193
appended: list[Any] = []
13294
for item in combined:
133-
key = session_item_key(item)
134-
if consume_reference(new_refs, key, item):
95+
key = _session_item_key(item)
96+
if _consume_reference(new_refs, key, item):
13597
new_counts[key] = max(new_counts.get(key, 0) - 1, 0)
13698
appended.append(item)
13799
continue
138-
if consume_reference(history_refs, key, item):
100+
if _consume_reference(history_refs, key, item):
139101
history_counts[key] = max(history_counts.get(key, 0) - 1, 0)
140102
continue
141103
if history_counts.get(key, 0) > 0:
@@ -440,3 +402,54 @@ async def wait_for_session_cleanup(
440402
logger.debug(
441403
"Session cleanup verification exhausted attempts; targets may still linger temporarily"
442404
)
405+
406+
407+
# --------------------------
408+
# Private helpers
409+
# --------------------------
410+
411+
412+
def _session_item_key(item: Any) -> str:
413+
"""Return a stable representation of a session item for comparison."""
414+
try:
415+
if hasattr(item, "model_dump"):
416+
payload = item.model_dump(exclude_unset=True)
417+
elif isinstance(item, dict):
418+
payload = item
419+
else:
420+
payload = ensure_input_item_format(item)
421+
return json.dumps(payload, sort_keys=True, default=str)
422+
except Exception:
423+
return repr(item)
424+
425+
426+
def _build_reference_map(items: Sequence[Any]) -> dict[str, list[Any]]:
427+
"""Map serialized keys to the concrete session items used to build them."""
428+
refs: dict[str, list[Any]] = {}
429+
for item in items:
430+
key = _session_item_key(item)
431+
refs.setdefault(key, []).append(item)
432+
return refs
433+
434+
435+
def _consume_reference(ref_map: dict[str, list[Any]], key: str, candidate: Any) -> bool:
436+
"""Remove a specific candidate from a reference map when it is consumed."""
437+
candidates = ref_map.get(key)
438+
if not candidates:
439+
return False
440+
for idx, existing in enumerate(candidates):
441+
if existing is candidate:
442+
candidates.pop(idx)
443+
if not candidates:
444+
ref_map.pop(key, None)
445+
return True
446+
return False
447+
448+
449+
def _build_frequency_map(items: Sequence[Any]) -> dict[str, int]:
450+
"""Count how many times each serialized key appears in a collection."""
451+
freq: dict[str, int] = {}
452+
for item in items:
453+
key = _session_item_key(item)
454+
freq[key] = freq.get(key, 0) + 1
455+
return freq

src/agents/run_internal/tool_actions.py

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,6 @@ async def execute(
9090
"""Run a computer action, capturing a screenshot and notifying hooks."""
9191
computer = await resolve_computer(tool=action.computer_tool, run_context=context_wrapper)
9292
agent_hooks = agent.hooks
93-
output_func = (
94-
cls._get_screenshot_async(computer, action.tool_call)
95-
if hasattr(computer, "screenshot_async")
96-
else cls._get_screenshot_sync(computer, action.tool_call)
97-
)
9893
await asyncio.gather(
9994
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
10095
(
@@ -104,7 +99,7 @@ async def execute(
10499
),
105100
)
106101

107-
output = await output_func
102+
output = await cls._execute_action_and_capture(computer, action.tool_call)
108103

109104
await asyncio.gather(
110105
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
@@ -131,62 +126,40 @@ async def execute(
131126
)
132127

133128
@classmethod
134-
async def _get_screenshot_sync(
135-
cls,
136-
computer: Any,
137-
tool_call: ResponseComputerToolCall,
129+
async def _execute_action_and_capture(
130+
cls, computer: Any, tool_call: ResponseComputerToolCall
138131
) -> str:
139-
"""Execute the computer action for sync drivers and return the screenshot."""
140-
action = tool_call.action
141-
if isinstance(action, ActionClick):
142-
computer.click(action.x, action.y, action.button)
143-
elif isinstance(action, ActionDoubleClick):
144-
computer.double_click(action.x, action.y)
145-
elif isinstance(action, ActionDrag):
146-
computer.drag([(p.x, p.y) for p in action.path])
147-
elif isinstance(action, ActionKeypress):
148-
computer.keypress(action.keys)
149-
elif isinstance(action, ActionMove):
150-
computer.move(action.x, action.y)
151-
elif isinstance(action, ActionScreenshot):
152-
computer.screenshot()
153-
elif isinstance(action, ActionScroll):
154-
computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
155-
elif isinstance(action, ActionType):
156-
computer.type(action.text)
157-
elif isinstance(action, ActionWait):
158-
computer.wait()
132+
"""Execute the computer action (sync or async drivers) and return the screenshot."""
159133

160-
return cast(str, computer.screenshot())
134+
async def maybe_call(method_name: str, *args: Any) -> Any:
135+
method = getattr(computer, method_name, None)
136+
if method is None or not callable(method):
137+
raise ModelBehaviorError(f"Computer driver missing method {method_name}")
138+
result = method(*args)
139+
return await result if inspect.isawaitable(result) else result
161140

162-
@classmethod
163-
async def _get_screenshot_async(
164-
cls,
165-
computer: Any,
166-
tool_call: ResponseComputerToolCall,
167-
) -> str:
168-
"""Execute the computer action for async drivers and return the screenshot."""
169141
action = tool_call.action
170142
if isinstance(action, ActionClick):
171-
await computer.click(action.x, action.y, action.button)
143+
await maybe_call("click", action.x, action.y, action.button)
172144
elif isinstance(action, ActionDoubleClick):
173-
await computer.double_click(action.x, action.y)
145+
await maybe_call("double_click", action.x, action.y)
174146
elif isinstance(action, ActionDrag):
175-
await computer.drag([(p.x, p.y) for p in action.path])
147+
await maybe_call("drag", [(p.x, p.y) for p in action.path])
176148
elif isinstance(action, ActionKeypress):
177-
await computer.keypress(action.keys)
149+
await maybe_call("keypress", action.keys)
178150
elif isinstance(action, ActionMove):
179-
await computer.move(action.x, action.y)
151+
await maybe_call("move", action.x, action.y)
180152
elif isinstance(action, ActionScreenshot):
181-
await computer.screenshot()
153+
await maybe_call("screenshot")
182154
elif isinstance(action, ActionScroll):
183-
await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
155+
await maybe_call("scroll", action.x, action.y, action.scroll_x, action.scroll_y)
184156
elif isinstance(action, ActionType):
185-
await computer.type(action.text)
157+
await maybe_call("type", action.text)
186158
elif isinstance(action, ActionWait):
187-
await computer.wait()
159+
await maybe_call("wait")
188160

189-
return cast(str, await computer.screenshot())
161+
screenshot_result = await maybe_call("screenshot")
162+
return cast(str, screenshot_result)
190163

191164

192165
class LocalShellAction:

0 commit comments

Comments
 (0)