Skip to content

Commit e92c62e

Browse files
committed
fix: update RunState with current turn persisted item tracking
1 parent 39d2b04 commit e92c62e

File tree

2 files changed

+60
-28
lines changed

2 files changed

+60
-28
lines changed

src/agents/run.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -808,9 +808,6 @@ async def run(
808808
# If resuming from an interrupted state, execute approved tools first
809809
if is_resumed_state and run_state is not None and run_state._current_step is not None:
810810
if isinstance(run_state._current_step, NextStepInterruption):
811-
# Track items before executing approved tools
812-
items_before_execution = len(generated_items)
813-
814811
# We're resuming from an interruption - execute approved tools
815812
await self._execute_approved_tools(
816813
agent=current_agent,
@@ -821,14 +818,9 @@ async def run(
821818
hooks=hooks,
822819
)
823820

824-
# Save the newly executed tool outputs to the session
825-
new_tool_outputs: list[RunItem] = [
826-
item
827-
for item in generated_items[items_before_execution:]
828-
if item.type == "tool_call_output_item"
829-
]
830-
if new_tool_outputs and session is not None:
831-
await self._save_result_to_session(session, [], new_tool_outputs)
821+
# Save new items (counter tracks what's already saved)
822+
if session is not None:
823+
await self._save_result_to_session(session, [], generated_items, run_state)
832824

833825
# Clear the current step since we've handled it
834826
run_state._current_step = None
@@ -858,6 +850,9 @@ async def run(
858850
current_span.span_data.tools = [t.name for t in all_tools]
859851

860852
current_turn += 1
853+
if run_state is not None:
854+
run_state._current_turn_persisted_item_count = 0
855+
861856
if current_turn > max_turns:
862857
_error_tracing.attach_error_to_span(
863858
current_span,
@@ -947,7 +942,7 @@ async def run(
947942
for guardrail_result in input_guardrail_results
948943
):
949944
await self._save_result_to_session(
950-
session, [], turn_result.new_step_items
945+
session, [], turn_result.new_step_items, run_state
951946
)
952947

953948
return result
@@ -979,7 +974,7 @@ async def run(
979974
for guardrail_result in input_guardrail_results
980975
):
981976
await self._save_result_to_session(
982-
session, [], turn_result.new_step_items
977+
session, [], turn_result.new_step_items, run_state
983978
)
984979
else:
985980
raise AgentsException(
@@ -1331,9 +1326,6 @@ async def _start_streaming(
13311326
# If resuming from an interrupted state, execute approved tools first
13321327
if run_state is not None and run_state._current_step is not None:
13331328
if isinstance(run_state._current_step, NextStepInterruption):
1334-
# Track items before executing approved tools
1335-
items_before_execution = len(streamed_result.new_items)
1336-
13371329
# We're resuming from an interruption - execute approved tools
13381330
await cls._execute_approved_tools_static(
13391331
agent=current_agent,
@@ -1344,14 +1336,11 @@ async def _start_streaming(
13441336
hooks=hooks,
13451337
)
13461338

1347-
# Save the newly executed tool outputs to the session
1348-
new_tool_outputs: list[RunItem] = [
1349-
item
1350-
for item in streamed_result.new_items[items_before_execution:]
1351-
if item.type == "tool_call_output_item"
1352-
]
1353-
if new_tool_outputs and session is not None:
1354-
await cls._save_result_to_session(session, [], new_tool_outputs)
1339+
# Save new items (counter tracks what's already saved)
1340+
if session is not None:
1341+
await cls._save_result_to_session(
1342+
session, [], streamed_result.new_items, run_state
1343+
)
13551344

13561345
# Clear the current step since we've handled it
13571346
run_state._current_step = None
@@ -1390,6 +1379,8 @@ async def _start_streaming(
13901379
current_span.span_data.tools = tool_names
13911380
current_turn += 1
13921381
streamed_result.current_turn = current_turn
1382+
if run_state is not None:
1383+
run_state._current_turn_persisted_item_count = 0
13931384

13941385
if current_turn > max_turns:
13951386
_error_tracing.attach_error_to_span(
@@ -1495,7 +1486,7 @@ async def _start_streaming(
14951486
)
14961487
if should_skip_session_save is False:
14971488
await AgentRunner._save_result_to_session(
1498-
session, [], turn_result.new_step_items
1489+
session, [], turn_result.new_step_items, run_state
14991490
)
15001491

15011492
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1514,7 +1505,7 @@ async def _start_streaming(
15141505
)
15151506
if should_skip_session_save is False:
15161507
await AgentRunner._save_result_to_session(
1517-
session, [], turn_result.new_step_items
1508+
session, [], turn_result.new_step_items, run_state
15181509
)
15191510

15201511
# Check for soft cancel after turn completion
@@ -2373,9 +2364,14 @@ async def _save_result_to_session(
23732364
session: Session | None,
23742365
original_input: str | list[TResponseInputItem],
23752366
new_items: list[RunItem],
2367+
run_state: RunState[Any] | None = None,
23762368
) -> None:
23772369
"""
2378-
Save the conversation turn to session.
2370+
Save the conversation turn to session with incremental tracking.
2371+
2372+
Uses run_state._current_turn_persisted_item_count to track which items
2373+
have already been persisted, allowing partial saves within a turn.
2374+
23792375
It does not account for any filtering or modification performed by
23802376
`RunConfig.session_input_callback`.
23812377
"""
@@ -2385,13 +2381,34 @@ async def _save_result_to_session(
23852381
# Convert original input to list format if needed
23862382
input_list = ItemHelpers.input_to_new_input_list(original_input)
23872383

2384+
# Track which items have already been persisted this turn
2385+
already_persisted = 0
2386+
if run_state is not None:
2387+
already_persisted = run_state._current_turn_persisted_item_count
2388+
2389+
# Only save items that haven't been persisted yet
2390+
new_run_items = new_items[already_persisted:]
2391+
23882392
# Convert new items to input format
2389-
new_items_as_input = [item.to_input_item() for item in new_items]
2393+
new_items_as_input = [item.to_input_item() for item in new_run_items]
23902394

23912395
# Save all items from this turn
23922396
items_to_save = input_list + new_items_as_input
2397+
2398+
if len(items_to_save) == 0:
2399+
# Update counter even if nothing to save
2400+
if run_state is not None:
2401+
run_state._current_turn_persisted_item_count = already_persisted + len(
2402+
new_run_items
2403+
)
2404+
return
2405+
23932406
await session.add_items(items_to_save)
23942407

2408+
# Update the counter after successful save
2409+
if run_state is not None:
2410+
run_state._current_turn_persisted_item_count = already_persisted + len(new_run_items)
2411+
23952412
@staticmethod
23962413
async def _input_guardrail_tripwire_triggered_for_stream(
23972414
streamed_result: RunResultStreaming,

src/agents/run_state.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51+
_current_turn_persisted_item_count: int = 0
52+
"""Tracks how many generated run items from this turn were already persisted to session.
53+
54+
When saving to session, we slice off only new entries. When a turn is interrupted
55+
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56+
continuing so pending tool outputs still get stored.
57+
"""
58+
5159
_current_agent: TAgent | None = None
5260
"""The agent currently handling the conversation."""
5361

@@ -337,6 +345,7 @@ def to_json(self) -> dict[str, Any]:
337345
if self._last_processed_response
338346
else None
339347
)
348+
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
340349
result["trace"] = None
341350

342351
return result
@@ -571,6 +580,9 @@ async def from_string(
571580
)
572581

573582
state._current_turn = state_json["currentTurn"]
583+
state._current_turn_persisted_item_count = state_json.get(
584+
"currentTurnPersistedItemCount", 0
585+
)
574586

575587
# Reconstruct model responses
576588
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -676,6 +688,9 @@ async def from_json(
676688
)
677689

678690
state._current_turn = state_json["currentTurn"]
691+
state._current_turn_persisted_item_count = state_json.get(
692+
"currentTurnPersistedItemCount", 0
693+
)
679694

680695
# Reconstruct model responses
681696
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

0 commit comments

Comments
 (0)