Skip to content

Commit 9a4d876

Browse files
committed
fix: address issues around resuming run state with conversation history
1 parent ae7ea4f commit 9a4d876

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

src/agents/run.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,12 +728,17 @@ async def run(
728728
# Check if we're resuming from a RunState
729729
is_resumed_state = isinstance(input, RunState)
730730
run_state: RunState[TContext] | None = None
731+
prepared_input: str | list[TResponseInputItem]
731732

732733
if is_resumed_state:
733734
# Resuming from a saved state
734735
run_state = cast(RunState[TContext], input)
735736
original_user_input = run_state._original_input
736-
prepared_input = run_state._original_input
737+
738+
if isinstance(run_state._original_input, list):
739+
prepared_input = self._merge_provider_data_in_items(run_state._original_input)
740+
else:
741+
prepared_input = run_state._original_input
737742

738743
# Override context with the state's context if not provided
739744
if context is None and run_state._context is not None:
@@ -803,6 +808,9 @@ async def run(
803808
# If resuming from an interrupted state, execute approved tools first
804809
if is_resumed_state and run_state is not None and run_state._current_step is not None:
805810
if isinstance(run_state._current_step, NextStepInterruption):
811+
# Track items before executing approved tools
812+
items_before_execution = len(generated_items)
813+
806814
# We're resuming from an interruption - execute approved tools
807815
await self._execute_approved_tools(
808816
agent=current_agent,
@@ -812,6 +820,16 @@ async def run(
812820
run_config=run_config,
813821
hooks=hooks,
814822
)
823+
824+
# Save the newly executed tool outputs to the session
825+
new_tool_outputs: list[RunItem] = [
826+
item
827+
for item in generated_items[items_before_execution:]
828+
if item.type == "tool_call_output_item"
829+
]
830+
if new_tool_outputs and session is not None:
831+
await self._save_result_to_session(session, [], new_tool_outputs)
832+
815833
# Clear the current step since we've handled it
816834
run_state._current_step = None
817835

@@ -1102,7 +1120,14 @@ def run_streamed(
11021120

11031121
if is_resumed_state:
11041122
run_state = cast(RunState[TContext], input)
1105-
input_for_result = run_state._original_input
1123+
1124+
if isinstance(run_state._original_input, list):
1125+
input_for_result = AgentRunner._merge_provider_data_in_items(
1126+
run_state._original_input
1127+
)
1128+
else:
1129+
input_for_result = run_state._original_input
1130+
11061131
# Use context from RunState if not provided
11071132
if context is None and run_state._context is not None:
11081133
context = run_state._context.context
@@ -1302,6 +1327,9 @@ async def _start_streaming(
13021327
# If resuming from an interrupted state, execute approved tools first
13031328
if run_state is not None and run_state._current_step is not None:
13041329
if isinstance(run_state._current_step, NextStepInterruption):
1330+
# Track items before executing approved tools
1331+
items_before_execution = len(streamed_result.new_items)
1332+
13051333
# We're resuming from an interruption - execute approved tools
13061334
await cls._execute_approved_tools_static(
13071335
agent=current_agent,
@@ -1311,6 +1339,16 @@ async def _start_streaming(
13111339
run_config=run_config,
13121340
hooks=hooks,
13131341
)
1342+
1343+
# Save the newly executed tool outputs to the session
1344+
new_tool_outputs: list[RunItem] = [
1345+
item
1346+
for item in streamed_result.new_items[items_before_execution:]
1347+
if item.type == "tool_call_output_item"
1348+
]
1349+
if new_tool_outputs and session is not None:
1350+
await cls._save_result_to_session(session, [], new_tool_outputs)
1351+
13141352
# Clear the current step since we've handled it
13151353
run_state._current_step = None
13161354

@@ -1581,6 +1619,8 @@ async def _run_single_turn_streamed(
15811619
input_item = item.to_input_item()
15821620
input.append(input_item)
15831621

1622+
input = cls._merge_provider_data_in_items(input)
1623+
15841624
# THIS IS THE RESOLVED CONFLICT BLOCK
15851625
filtered = await cls._maybe_filter_model_input(
15861626
agent=agent,
@@ -1920,6 +1960,8 @@ async def _run_single_turn(
19201960
input_item = generated_item.to_input_item()
19211961
input.append(input_item)
19221962

1963+
input = cls._merge_provider_data_in_items(input)
1964+
19231965
new_response = await cls._get_new_response(
19241966
agent,
19251967
system_prompt,
@@ -2254,6 +2296,30 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
22542296

22552297
return run_config.model_provider.get_model(agent.model)
22562298

2299+
@classmethod
2300+
def _merge_provider_data_in_items(
2301+
cls, items: list[TResponseInputItem]
2302+
) -> list[TResponseInputItem]:
2303+
"""Remove providerData fields from items."""
2304+
result = []
2305+
for item in items:
2306+
if isinstance(item, dict):
2307+
merged_item = dict(item)
2308+
# Pop both possible keys (providerData and provider_data)
2309+
provider_data = merged_item.pop("providerData", None)
2310+
if provider_data is None:
2311+
provider_data = merged_item.pop("provider_data", None)
2312+
# Merge contents if providerData exists and is a dict
2313+
if isinstance(provider_data, dict):
2314+
# Merge provider_data contents, with existing fields taking precedence
2315+
for key, value in provider_data.items():
2316+
if key not in merged_item:
2317+
merged_item[key] = value
2318+
result.append(cast(TResponseInputItem, merged_item))
2319+
else:
2320+
result.append(item)
2321+
return result
2322+
22572323
@classmethod
22582324
async def _prepare_input_with_session(
22592325
cls,
@@ -2277,6 +2343,7 @@ async def _prepare_input_with_session(
22772343

22782344
# Get previous conversation history
22792345
history = await session.get_items()
2346+
history = cls._merge_provider_data_in_items(history)
22802347

22812348
# Convert input to list format
22822349
new_input_list = ItemHelpers.input_to_new_input_list(input)
@@ -2286,7 +2353,9 @@ async def _prepare_input_with_session(
22862353
elif callable(session_input_callback):
22872354
res = session_input_callback(history, new_input_list)
22882355
if inspect.isawaitable(res):
2289-
return await res
2356+
res = await res
2357+
if isinstance(res, list):
2358+
res = cls._merge_provider_data_in_items(res)
22902359
return res
22912360
else:
22922361
raise UserError(

0 commit comments

Comments
 (0)