Skip to content

Commit b3a7e3a

Browse files
committed
fix: addressing edge cases when resuming (continued)
1 parent dd79cb6 commit b3a7e3a

File tree

3 files changed

+226
-33
lines changed

3 files changed

+226
-33
lines changed

src/agents/run.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,13 @@ def prepare_input(
161161

162162
# On first call (when there are no generated items yet), include the original input
163163
if not generated_items:
164-
input_items.extend(ItemHelpers.input_to_new_input_list(original_input))
164+
# Normalize original_input items to ensure field names are in snake_case
165+
# (items from RunState deserialization may have camelCase)
166+
raw_input_list = ItemHelpers.input_to_new_input_list(original_input)
167+
# Filter out function_call items that don't have corresponding function_call_output
168+
# (API requires every function_call to have a function_call_output)
169+
filtered_input_list = AgentRunner._filter_incomplete_function_calls(raw_input_list)
170+
input_items.extend(AgentRunner._normalize_input_items(filtered_input_list))
165171

166172
# First, collect call_ids from tool_call_output_item items
167173
# (completed tool calls with outputs) and build a map of
@@ -737,8 +743,8 @@ async def run(
737743
original_user_input = run_state._original_input
738744
# Normalize items to remove top-level providerData (API doesn't accept it there)
739745
if isinstance(original_user_input, list):
740-
prepared_input: str | list[TResponseInputItem] = (
741-
AgentRunner._normalize_input_items(original_user_input)
746+
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
747+
original_user_input
742748
)
743749
else:
744750
prepared_input = original_user_input
@@ -833,8 +839,7 @@ async def run(
833839
if session is not None and generated_items:
834840
# Save tool_call_output_item items (the outputs)
835841
tool_output_items: list[RunItem] = [
836-
item for item in generated_items
837-
if item.type == "tool_call_output_item"
842+
item for item in generated_items if item.type == "tool_call_output_item"
838843
]
839844
# Also find and save the corresponding function_call items
840845
# (they might not be in session if the run was interrupted before saving)
@@ -1411,9 +1416,12 @@ async def _start_streaming(
14111416
# state's input, causing duplicate items.
14121417
if run_state is not None:
14131418
# Resuming from state - normalize items to remove top-level providerData
1419+
# and filter incomplete function_call pairs
14141420
if isinstance(starting_input, list):
1421+
# Filter incomplete function_call pairs before normalizing
1422+
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
14151423
prepared_input: str | list[TResponseInputItem] = (
1416-
AgentRunner._normalize_input_items(starting_input)
1424+
AgentRunner._normalize_input_items(filtered)
14171425
)
14181426
else:
14191427
prepared_input = starting_input
@@ -2414,20 +2422,82 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
24142422

24152423
return run_config.model_provider.get_model(agent.model)
24162424

2425+
@staticmethod
2426+
def _filter_incomplete_function_calls(
2427+
items: list[TResponseInputItem],
2428+
) -> list[TResponseInputItem]:
2429+
"""Filter out function_call items that don't have corresponding function_call_output.
2430+
2431+
The OpenAI API requires every function_call in an assistant message to have a
2432+
corresponding function_call_output (tool message). This function ensures only
2433+
complete pairs are included to prevent API errors.
2434+
2435+
IMPORTANT: This only filters incomplete function_call items. All other items
2436+
(messages, complete function_call pairs, etc.) are preserved to maintain
2437+
conversation history integrity.
2438+
2439+
Args:
2440+
items: List of input items to filter
2441+
2442+
Returns:
2443+
Filtered list with only complete function_call pairs. All non-function_call
2444+
items and complete function_call pairs are preserved.
2445+
"""
2446+
# First pass: collect call_ids from function_call_output/function_call_result items
2447+
completed_call_ids: set[str] = set()
2448+
for item in items:
2449+
if isinstance(item, dict):
2450+
item_type = item.get("type")
2451+
# Handle both API format (function_call_output) and
2452+
# protocol format (function_call_result)
2453+
if item_type in ("function_call_output", "function_call_result"):
2454+
call_id = item.get("call_id") or item.get("callId")
2455+
if call_id and isinstance(call_id, str):
2456+
completed_call_ids.add(call_id)
2457+
2458+
# Second pass: only include function_call items that have corresponding outputs
2459+
filtered: list[TResponseInputItem] = []
2460+
for item in items:
2461+
if isinstance(item, dict):
2462+
item_type = item.get("type")
2463+
if item_type == "function_call":
2464+
call_id = item.get("call_id") or item.get("callId")
2465+
# Only include if there's a corresponding
2466+
# function_call_output/function_call_result
2467+
if call_id and call_id in completed_call_ids:
2468+
filtered.append(item)
2469+
else:
2470+
# Include all non-function_call items
2471+
filtered.append(item)
2472+
else:
2473+
# Include non-dict items as-is
2474+
filtered.append(item)
2475+
2476+
return filtered
2477+
24172478
@staticmethod
24182479
def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
2419-
"""Normalize input items by removing top-level providerData/provider_data.
2420-
2480+
"""Normalize input items by removing top-level providerData/provider_data
2481+
and normalizing field names (callId -> call_id).
2482+
24212483
The OpenAI API doesn't accept providerData at the top level of input items.
24222484
providerData should only be in content where it belongs. This function removes
24232485
top-level providerData while preserving it in content.
2424-
2486+
2487+
Also normalizes field names from camelCase (callId) to snake_case (call_id)
2488+
to match API expectations.
2489+
2490+
Normalizes item types: converts 'function_call_result' to 'function_call_output'
2491+
to match API expectations.
2492+
24252493
Args:
24262494
items: List of input items to normalize
2427-
2495+
24282496
Returns:
24292497
Normalized list of input items
24302498
"""
2499+
from .run_state import _normalize_field_names
2500+
24312501
normalized: list[TResponseInputItem] = []
24322502
for item in items:
24332503
if isinstance(item, dict):
@@ -2437,6 +2507,18 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
24372507
# The API doesn't accept providerData at the top level of input items
24382508
normalized_item.pop("providerData", None)
24392509
normalized_item.pop("provider_data", None)
2510+
# Normalize item type: API expects 'function_call_output',
2511+
# not 'function_call_result'
2512+
item_type = normalized_item.get("type")
2513+
if item_type == "function_call_result":
2514+
normalized_item["type"] = "function_call_output"
2515+
item_type = "function_call_output"
2516+
# Remove invalid fields based on item type
2517+
# function_call_output items should not have 'name' field
2518+
if item_type == "function_call_output":
2519+
normalized_item.pop("name", None)
2520+
# Normalize field names (callId -> call_id, responseId -> response_id)
2521+
normalized_item = _normalize_field_names(normalized_item)
24402522
normalized.append(cast(TResponseInputItem, normalized_item))
24412523
else:
24422524
# For non-dict items, keep as-is (they should already be in correct format)
@@ -2483,10 +2565,14 @@ async def _prepare_input_with_session(
24832565
f"Invalid `session_input_callback` value: {session_input_callback}. "
24842566
"Choose between `None` or a custom callable function."
24852567
)
2486-
2568+
2569+
# Filter incomplete function_call pairs before normalizing
2570+
# (API requires every function_call to have a function_call_output)
2571+
filtered = cls._filter_incomplete_function_calls(merged)
2572+
24872573
# Normalize items to remove top-level providerData and deduplicate by ID
2488-
normalized = cls._normalize_input_items(merged)
2489-
2574+
normalized = cls._normalize_input_items(filtered)
2575+
24902576
# Deduplicate items by ID to prevent sending duplicate items to the API
24912577
# This can happen when resuming from state and items are already in the session
24922578
seen_ids: set[str] = set()
@@ -2498,13 +2584,13 @@ async def _prepare_input_with_session(
24982584
item_id = cast(str | None, item.get("id"))
24992585
elif hasattr(item, "id"):
25002586
item_id = cast(str | None, getattr(item, "id", None))
2501-
2587+
25022588
# Only add items we haven't seen before (or items without IDs)
25032589
if item_id is None or item_id not in seen_ids:
25042590
deduplicated.append(item)
25052591
if item_id:
25062592
seen_ids.add(item_id)
2507-
2593+
25082594
return deduplicated
25092595

25102596
@classmethod

src/agents/run_state.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51-
_current_turn_persisted_item_count: int = 0
52-
"""Tracks how many generated run items from this turn were already persisted to session.
53-
54-
When saving to session, we slice off only new entries. When a turn is interrupted
55-
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56-
continuing so pending tool outputs still get stored.
57-
"""
58-
5951
_current_agent: TAgent | None = None
6052
"""The agent currently handling the conversation."""
6153

@@ -250,13 +242,63 @@ def to_json(self) -> dict[str, Any]:
250242
}
251243
model_responses.append(response_dict)
252244

245+
# Normalize and camelize originalInput if it's a list of items
246+
# Convert API format to protocol format to match TypeScript schema
247+
# Protocol expects function_call_result (not function_call_output)
248+
original_input_serialized = self._original_input
249+
if isinstance(original_input_serialized, list):
250+
# First pass: build a map of call_id -> function_call name
251+
# to help convert function_call_output to function_call_result
252+
call_id_to_name: dict[str, str] = {}
253+
for item in original_input_serialized:
254+
if isinstance(item, dict):
255+
item_type = item.get("type")
256+
call_id = item.get("call_id") or item.get("callId")
257+
name = item.get("name")
258+
if item_type == "function_call" and call_id and name:
259+
call_id_to_name[call_id] = name
260+
261+
normalized_items = []
262+
for item in original_input_serialized:
263+
if isinstance(item, dict):
264+
# Create a copy to avoid modifying the original
265+
normalized_item = dict(item)
266+
# Remove session/conversation metadata fields that shouldn't be in originalInput
267+
# These are not part of the input protocol schema
268+
normalized_item.pop("id", None)
269+
normalized_item.pop("created_at", None)
270+
# Remove top-level providerData/provider_data (protocol allows it but
271+
# we remove it for cleaner serialization)
272+
normalized_item.pop("providerData", None)
273+
normalized_item.pop("provider_data", None)
274+
# Convert API format to protocol format
275+
# API uses function_call_output, protocol uses function_call_result
276+
item_type = normalized_item.get("type")
277+
call_id = normalized_item.get("call_id") or normalized_item.get("callId")
278+
if item_type == "function_call_output":
279+
# Convert to protocol format: function_call_result
280+
normalized_item["type"] = "function_call_result"
281+
# Protocol format requires status field (default to 'completed')
282+
if "status" not in normalized_item:
283+
normalized_item["status"] = "completed"
284+
# Protocol format requires name field
285+
# Look it up from the corresponding function_call if missing
286+
if "name" not in normalized_item and call_id:
287+
normalized_item["name"] = call_id_to_name.get(call_id, "")
288+
# Normalize field names to camelCase for JSON (call_id -> callId)
289+
normalized_item = self._camelize_field_names(normalized_item)
290+
normalized_items.append(normalized_item)
291+
else:
292+
normalized_items.append(item)
293+
original_input_serialized = normalized_items
294+
253295
result = {
254296
"$schemaVersion": CURRENT_SCHEMA_VERSION,
255297
"currentTurn": self._current_turn,
256298
"currentAgent": {
257299
"name": self._current_agent.name,
258300
},
259-
"originalInput": self._original_input,
301+
"originalInput": original_input_serialized,
260302
"modelResponses": model_responses,
261303
"context": {
262304
"usage": {
@@ -345,7 +387,6 @@ def to_json(self) -> dict[str, Any]:
345387
if self._last_processed_response
346388
else None
347389
)
348-
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
349390
result["trace"] = None
350391

351392
return result
@@ -571,18 +612,29 @@ async def from_string(
571612
context.usage = usage
572613
context._rebuild_approvals(context_data.get("approvals", {}))
573614

615+
# Normalize originalInput to remove providerData fields that may have been
616+
# included by TypeScript serialization. These fields are metadata and should
617+
# not be sent to the API.
618+
original_input_raw = state_json["originalInput"]
619+
if isinstance(original_input_raw, list):
620+
# Normalize each item in the list to remove providerData fields
621+
normalized_original_input = [
622+
_normalize_field_names(item) if isinstance(item, dict) else item
623+
for item in original_input_raw
624+
]
625+
else:
626+
# If it's a string, use it as-is
627+
normalized_original_input = original_input_raw
628+
574629
# Create the RunState instance
575630
state = RunState(
576631
context=context,
577-
original_input=state_json["originalInput"],
632+
original_input=normalized_original_input,
578633
starting_agent=current_agent,
579634
max_turns=state_json["maxTurns"],
580635
)
581636

582637
state._current_turn = state_json["currentTurn"]
583-
state._current_turn_persisted_item_count = state_json.get(
584-
"currentTurnPersistedItemCount", 0
585-
)
586638

587639
# Reconstruct model responses
588640
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -679,18 +731,29 @@ async def from_json(
679731
context.usage = usage
680732
context._rebuild_approvals(context_data.get("approvals", {}))
681733

734+
# Normalize originalInput to remove providerData fields that may have been
735+
# included by TypeScript serialization. These fields are metadata and should
736+
# not be sent to the API.
737+
original_input_raw = state_json["originalInput"]
738+
if isinstance(original_input_raw, list):
739+
# Normalize each item in the list to remove providerData fields
740+
normalized_original_input = [
741+
_normalize_field_names(item) if isinstance(item, dict) else item
742+
for item in original_input_raw
743+
]
744+
else:
745+
# If it's a string, use it as-is
746+
normalized_original_input = original_input_raw
747+
682748
# Create the RunState instance
683749
state = RunState(
684750
context=context,
685-
original_input=state_json["originalInput"],
751+
original_input=normalized_original_input,
686752
starting_agent=current_agent,
687753
max_turns=state_json["maxTurns"],
688754
)
689755

690756
state._current_turn = state_json["currentTurn"]
691-
state._current_turn_persisted_item_count = state_json.get(
692-
"currentTurnPersistedItemCount", 0
693-
)
694757

695758
# Reconstruct model responses
696759
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

tests/test_run_state.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,50 @@ async def test_deserializes_various_item_types(self):
507507
assert isinstance(new_state._generated_items[1], ToolCallItem)
508508
assert isinstance(new_state._generated_items[2], ToolCallOutputItem)
509509

510+
async def test_serializes_original_input_with_function_call_output(self):
511+
"""Test that originalInput with function_call_output items is converted to protocol."""
512+
context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
513+
agent = Agent(name="TestAgent")
514+
515+
# Create originalInput with function_call_output (API format)
516+
# This simulates items from session that are in API format
517+
original_input = [
518+
{
519+
"type": "function_call",
520+
"call_id": "call_123",
521+
"name": "test_tool",
522+
"arguments": '{"arg": "value"}',
523+
},
524+
{
525+
"type": "function_call_output",
526+
"call_id": "call_123",
527+
"output": "result",
528+
},
529+
]
530+
531+
state = RunState(
532+
context=context, original_input=original_input, starting_agent=agent, max_turns=5
533+
)
534+
535+
# Serialize - should convert function_call_output to function_call_result
536+
json_data = state.to_json()
537+
538+
# Verify originalInput was converted to protocol format
539+
assert isinstance(json_data["originalInput"], list)
540+
assert len(json_data["originalInput"]) == 2
541+
542+
# First item should remain function_call (with camelCase)
543+
assert json_data["originalInput"][0]["type"] == "function_call"
544+
assert json_data["originalInput"][0]["callId"] == "call_123"
545+
assert json_data["originalInput"][0]["name"] == "test_tool"
546+
547+
# Second item should be converted to function_call_result (protocol format)
548+
assert json_data["originalInput"][1]["type"] == "function_call_result"
549+
assert json_data["originalInput"][1]["callId"] == "call_123"
550+
assert json_data["originalInput"][1]["name"] == "test_tool" # Looked up from function_call
551+
assert json_data["originalInput"][1]["status"] == "completed" # Added default
552+
assert json_data["originalInput"][1]["output"] == "result"
553+
510554
async def test_deserialization_handles_unknown_agent_gracefully(self):
511555
"""Test that deserialization skips items with unknown agents."""
512556
context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})

0 commit comments

Comments
 (0)