Skip to content

Commit ec87153

Browse files
authored
Fix OpenAI Responses API tool calls with reasoning (#2869)
1 parent 5bab307 commit ec87153

File tree

3 files changed

+513
-27
lines changed

3 files changed

+513
-27
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,9 @@ def _process_response(self, response: responses.Response) -> ModelResponse:
878878
if isinstance(content, responses.ResponseOutputText): # pragma: no branch
879879
items.append(TextPart(content.text))
880880
elif isinstance(item, responses.ResponseFunctionToolCall):
881-
items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
881+
items.append(
882+
ToolCallPart(item.name, item.arguments, tool_call_id=_combine_tool_call_ids(item.call_id, item.id))
883+
)
882884

883885
finish_reason: FinishReason | None = None
884886
provider_details: dict[str, Any] | None = None
@@ -1084,27 +1086,29 @@ async def _map_messages( # noqa: C901
10841086
elif isinstance(part, UserPromptPart):
10851087
openai_messages.append(await self._map_user_prompt(part))
10861088
elif isinstance(part, ToolReturnPart):
1087-
openai_messages.append(
1088-
FunctionCallOutput(
1089-
type='function_call_output',
1090-
call_id=_guard_tool_call_id(t=part),
1091-
output=part.model_response_str(),
1092-
)
1089+
call_id = _guard_tool_call_id(t=part)
1090+
call_id, _ = _split_combined_tool_call_id(call_id)
1091+
item = FunctionCallOutput(
1092+
type='function_call_output',
1093+
call_id=call_id,
1094+
output=part.model_response_str(),
10931095
)
1096+
openai_messages.append(item)
10941097
elif isinstance(part, RetryPromptPart):
10951098
# TODO(Marcelo): How do we test this conditional branch?
10961099
if part.tool_name is None: # pragma: no cover
10971100
openai_messages.append(
10981101
Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}])
10991102
)
11001103
else:
1101-
openai_messages.append(
1102-
FunctionCallOutput(
1103-
type='function_call_output',
1104-
call_id=_guard_tool_call_id(t=part),
1105-
output=part.model_response(),
1106-
)
1104+
call_id = _guard_tool_call_id(t=part)
1105+
call_id, _ = _split_combined_tool_call_id(call_id)
1106+
item = FunctionCallOutput(
1107+
type='function_call_output',
1108+
call_id=call_id,
1109+
output=part.model_response(),
11071110
)
1111+
openai_messages.append(item)
11081112
else:
11091113
assert_never(part)
11101114
elif isinstance(message, ModelResponse):
@@ -1141,12 +1145,18 @@ async def _map_messages( # noqa: C901
11411145

11421146
@staticmethod
def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
    """Convert a `ToolCallPart` into a Responses API function-call param.

    The part's tool call id may be a combined ``call_id|id`` string (produced by
    `_combine_tool_call_ids` when the response carried a reasoning item id); any
    ``id`` component is split back out and restored onto the param so the
    Responses API can pair the function call with its reasoning item.
    """
    combined = _guard_tool_call_id(t=t)
    call_id, item_id = _split_combined_tool_call_id(combined)

    param = responses.ResponseFunctionToolCallParam(
        name=t.tool_name,
        arguments=t.args_as_json_str(),
        call_id=call_id,
        type='function_call',
    )
    if item_id:  # pragma: no branch
        param['id'] = item_id
    return param
11501160

11511161
def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam:
11521162
response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = {
@@ -1365,7 +1375,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
13651375
vendor_part_id=chunk.item.id,
13661376
tool_name=chunk.item.name,
13671377
args=chunk.item.arguments,
1368-
tool_call_id=chunk.item.call_id,
1378+
tool_call_id=_combine_tool_call_ids(chunk.item.call_id, chunk.item.id),
13691379
)
13701380
elif isinstance(chunk.item, responses.ResponseReasoningItem):
13711381
pass
@@ -1506,3 +1516,17 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
15061516
u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
15071517
u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
15081518
return u
1519+
1520+
1521+
def _combine_tool_call_ids(call_id: str, id: str | None) -> str:
1522+
# When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
1523+
# Our `ToolCallPart` has only the `call_id` field, so we combine the two fields into a single string.
1524+
return f'{call_id}|{id}' if id else call_id
1525+
1526+
1527+
def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
1528+
if '|' in combined_id:
1529+
call_id, id = combined_id.split('|', 1)
1530+
return call_id, id
1531+
else:
1532+
return combined_id, None # pragma: no cover

0 commit comments

Comments (0)