Skip to content

Commit 164013f

Browse files
committed
fix: update RunState with current turn persisted item tracking
1 parent 466c306 commit 164013f

File tree

2 files changed

+60
-28
lines changed

2 files changed

+60
-28
lines changed

src/agents/run.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -795,9 +795,6 @@ async def run(
795795
# If resuming from an interrupted state, execute approved tools first
796796
if is_resumed_state and run_state is not None and run_state._current_step is not None:
797797
if isinstance(run_state._current_step, NextStepInterruption):
798-
# Track items before executing approved tools
799-
items_before_execution = len(generated_items)
800-
801798
# We're resuming from an interruption - execute approved tools
802799
await self._execute_approved_tools(
803800
agent=current_agent,
@@ -808,14 +805,9 @@ async def run(
808805
hooks=hooks,
809806
)
810807

811-
# Save the newly executed tool outputs to the session
812-
new_tool_outputs: list[RunItem] = [
813-
item
814-
for item in generated_items[items_before_execution:]
815-
if item.type == "tool_call_output_item"
816-
]
817-
if new_tool_outputs and session is not None:
818-
await self._save_result_to_session(session, [], new_tool_outputs)
808+
# Save new items (counter tracks what's already saved)
809+
if session is not None:
810+
await self._save_result_to_session(session, [], generated_items, run_state)
819811

820812
# Clear the current step since we've handled it
821813
run_state._current_step = None
@@ -845,6 +837,9 @@ async def run(
845837
current_span.span_data.tools = [t.name for t in all_tools]
846838

847839
current_turn += 1
840+
if run_state is not None:
841+
run_state._current_turn_persisted_item_count = 0
842+
848843
if current_turn > max_turns:
849844
_error_tracing.attach_error_to_span(
850845
current_span,
@@ -934,7 +929,7 @@ async def run(
934929
for guardrail_result in input_guardrail_results
935930
):
936931
await self._save_result_to_session(
937-
session, [], turn_result.new_step_items
932+
session, [], turn_result.new_step_items, run_state
938933
)
939934

940935
return result
@@ -966,7 +961,7 @@ async def run(
966961
for guardrail_result in input_guardrail_results
967962
):
968963
await self._save_result_to_session(
969-
session, [], turn_result.new_step_items
964+
session, [], turn_result.new_step_items, run_state
970965
)
971966
else:
972967
raise AgentsException(
@@ -1318,9 +1313,6 @@ async def _start_streaming(
13181313
# If resuming from an interrupted state, execute approved tools first
13191314
if run_state is not None and run_state._current_step is not None:
13201315
if isinstance(run_state._current_step, NextStepInterruption):
1321-
# Track items before executing approved tools
1322-
items_before_execution = len(streamed_result.new_items)
1323-
13241316
# We're resuming from an interruption - execute approved tools
13251317
await cls._execute_approved_tools_static(
13261318
agent=current_agent,
@@ -1331,14 +1323,11 @@ async def _start_streaming(
13311323
hooks=hooks,
13321324
)
13331325

1334-
# Save the newly executed tool outputs to the session
1335-
new_tool_outputs: list[RunItem] = [
1336-
item
1337-
for item in streamed_result.new_items[items_before_execution:]
1338-
if item.type == "tool_call_output_item"
1339-
]
1340-
if new_tool_outputs and session is not None:
1341-
await cls._save_result_to_session(session, [], new_tool_outputs)
1326+
# Save new items (counter tracks what's already saved)
1327+
if session is not None:
1328+
await cls._save_result_to_session(
1329+
session, [], streamed_result.new_items, run_state
1330+
)
13421331

13431332
# Clear the current step since we've handled it
13441333
run_state._current_step = None
@@ -1377,6 +1366,8 @@ async def _start_streaming(
13771366
current_span.span_data.tools = tool_names
13781367
current_turn += 1
13791368
streamed_result.current_turn = current_turn
1369+
if run_state is not None:
1370+
run_state._current_turn_persisted_item_count = 0
13801371

13811372
if current_turn > max_turns:
13821373
_error_tracing.attach_error_to_span(
@@ -1482,7 +1473,7 @@ async def _start_streaming(
14821473
)
14831474
if should_skip_session_save is False:
14841475
await AgentRunner._save_result_to_session(
1485-
session, [], turn_result.new_step_items
1476+
session, [], turn_result.new_step_items, run_state
14861477
)
14871478

14881479
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1501,7 +1492,7 @@ async def _start_streaming(
15011492
)
15021493
if should_skip_session_save is False:
15031494
await AgentRunner._save_result_to_session(
1504-
session, [], turn_result.new_step_items
1495+
session, [], turn_result.new_step_items, run_state
15051496
)
15061497

15071498
# Check for soft cancel after turn completion
@@ -2360,9 +2351,14 @@ async def _save_result_to_session(
23602351
session: Session | None,
23612352
original_input: str | list[TResponseInputItem],
23622353
new_items: list[RunItem],
2354+
run_state: RunState[Any] | None = None,
23632355
) -> None:
23642356
"""
2365-
Save the conversation turn to session.
2357+
Save the conversation turn to session with incremental tracking.
2358+
2359+
Uses run_state._current_turn_persisted_item_count to track which items
2360+
have already been persisted, allowing partial saves within a turn.
2361+
23662362
It does not account for any filtering or modification performed by
23672363
`RunConfig.session_input_callback`.
23682364
"""
@@ -2372,13 +2368,34 @@ async def _save_result_to_session(
23722368
# Convert original input to list format if needed
23732369
input_list = ItemHelpers.input_to_new_input_list(original_input)
23742370

2371+
# Track which items have already been persisted this turn
2372+
already_persisted = 0
2373+
if run_state is not None:
2374+
already_persisted = run_state._current_turn_persisted_item_count
2375+
2376+
# Only save items that haven't been persisted yet
2377+
new_run_items = new_items[already_persisted:]
2378+
23752379
# Convert new items to input format
2376-
new_items_as_input = [item.to_input_item() for item in new_items]
2380+
new_items_as_input = [item.to_input_item() for item in new_run_items]
23772381

23782382
# Save all items from this turn
23792383
items_to_save = input_list + new_items_as_input
2384+
2385+
if len(items_to_save) == 0:
2386+
# Update counter even if nothing to save
2387+
if run_state is not None:
2388+
run_state._current_turn_persisted_item_count = already_persisted + len(
2389+
new_run_items
2390+
)
2391+
return
2392+
23802393
await session.add_items(items_to_save)
23812394

2395+
# Update the counter after successful save
2396+
if run_state is not None:
2397+
run_state._current_turn_persisted_item_count = already_persisted + len(new_run_items)
2398+
23822399
@staticmethod
23832400
async def _input_guardrail_tripwire_triggered_for_stream(
23842401
streamed_result: RunResultStreaming,

src/agents/run_state.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51+
_current_turn_persisted_item_count: int = 0
52+
"""Tracks how many generated run items from this turn were already persisted to session.
53+
54+
When saving to session, we slice off only new entries. When a turn is interrupted
55+
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56+
continuing so pending tool outputs still get stored.
57+
"""
58+
5159
_current_agent: TAgent | None = None
5260
"""The agent currently handling the conversation."""
5361

@@ -337,6 +345,7 @@ def to_json(self) -> dict[str, Any]:
337345
if self._last_processed_response
338346
else None
339347
)
348+
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
340349
result["trace"] = None
341350

342351
return result
@@ -571,6 +580,9 @@ async def from_string(
571580
)
572581

573582
state._current_turn = state_json["currentTurn"]
583+
state._current_turn_persisted_item_count = state_json.get(
584+
"currentTurnPersistedItemCount", 0
585+
)
574586

575587
# Reconstruct model responses
576588
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -676,6 +688,9 @@ async def from_json(
676688
)
677689

678690
state._current_turn = state_json["currentTurn"]
691+
state._current_turn_persisted_item_count = state_json.get(
692+
"currentTurnPersistedItemCount", 0
693+
)
679694

680695
# Reconstruct model responses
681696
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

0 commit comments

Comments
 (0)