
Commit 521b3c4

Support returning multi-modal content from tools (#1517)
1 parent ea2bbc5 commit 521b3c4

19 files changed: +1803 -67 lines changed

docs/tools.md

Lines changed: 64 additions & 3 deletions
@@ -15,6 +15,8 @@ There are a number of ways to register tools with an agent:
 * via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext]
 * via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool]

+## Registering Function Tools via Decorator
+
 `@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context.

 Here's an example using both:
@@ -188,7 +190,7 @@ sequenceDiagram
     Note over Agent: Game session complete
 ```

-## Registering Function Tools via kwarg
+## Registering Function Tools via Agent Argument

 As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to reuse tools, and can also give more fine-grained control over the tools.

@@ -244,6 +246,67 @@ print(dice_result['b'].output)

 _(This example is complete, it can be run "as is")_

+## Function Tool Output
+
+Tools can return anything that Pydantic can serialize to JSON, as well as audio, video, image or document content depending on the types of [multi-modal input](input.md) the model supports:
+
+```python {title="function_tool_output.py"}
+from datetime import datetime
+
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, DocumentUrl, ImageUrl
+from pydantic_ai.models.openai import OpenAIResponsesModel
+
+
+class User(BaseModel):
+    name: str
+    age: int
+
+
+agent = Agent(model=OpenAIResponsesModel('gpt-4o'))
+
+
+@agent.tool_plain
+def get_current_time() -> datetime:
+    return datetime.now()
+
+
+@agent.tool_plain
+def get_user() -> User:
+    return User(name='John', age=30)
+
+
+@agent.tool_plain
+def get_company_logo() -> ImageUrl:
+    return ImageUrl(url='https://iili.io/3Hs4FMg.png')
+
+
+@agent.tool_plain
+def get_document() -> DocumentUrl:
+    return DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')
+
+
+result = agent.run_sync('What time is it?')
+print(result.output)
+#> The current time is 10:45 PM on April 17, 2025.
+
+result = agent.run_sync('What is the user name?')
+print(result.output)
+#> The user's name is John.
+
+result = agent.run_sync('What is the company name in the logo?')
+print(result.output)
+#> The company name in the logo is "Pydantic."
+
+result = agent.run_sync('What is the main content of the document?')
+print(result.output)
+#> The document contains just the text "Dummy PDF file."
+```
+_(This example is complete, it can be run "as is")_
+
+Some models (e.g. Gemini) natively support semi-structured return values, while some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
+
 ## Function Tools vs. Structured Outputs

 As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output.
@@ -307,8 +370,6 @@ agent.run_sync('hello', model=FunctionModel(print_schema))

 _(This example is complete, it can be run "as is")_

-The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
-
 If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object.

 Here's an example where we use [`TestModel.last_model_request_parameters`][pydantic_ai.models.test.TestModel.last_model_request_parameters] to inspect the tool schema that would be passed to the model.
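
The new docs say that a Python object returned from a tool is serialized to JSON when the model only accepts text. As a rough sketch of what that serialization amounts to (an illustration built on pydantic's `TypeAdapter`, not the code path pydantic-ai actually uses; `serialize_tool_return` is a hypothetical helper):

```python
from pydantic import BaseModel, TypeAdapter


class User(BaseModel):
    name: str
    age: int


def serialize_tool_return(value: object) -> str:
    """Pass strings through unchanged; dump anything else to JSON via pydantic."""
    if isinstance(value, str):
        return value
    return TypeAdapter(type(value)).dump_json(value).decode()


print(serialize_tool_return(User(name='John', age=30)))
#> {"name":"John","age":30}
print(serialize_tool_return('already text'))
#> already text
```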

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 24 additions & 2 deletions
@@ -546,7 +546,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
     )


-async def process_function_tools(
+async def process_function_tools(  # noqa C901
     tool_calls: list[_messages.ToolCallPart],
     output_tool_name: str | None,
     output_tool_call_id: str | None,
@@ -632,6 +632,8 @@ async def process_function_tools(
     if not calls_to_run:
         return

+    user_parts: list[_messages.UserPromptPart] = []
+
     # Run all tool tasks in parallel
     results_by_index: dict[int, _messages.ModelRequestPart] = {}
     with ctx.deps.tracer.start_as_current_span(
@@ -645,14 +647,32 @@ async def process_function_tools(
             asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name)
             for tool, call in calls_to_run
         ]
+
+        file_index = 1
+
         pending = tasks
         while pending:
             done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
             for task in done:
                 index = tasks.index(task)
                 result = task.result()
                 yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
-                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+
+                if isinstance(result, _messages.RetryPromptPart):
+                    results_by_index[index] = result
+                elif isinstance(result, _messages.ToolReturnPart):
+                    if isinstance(result.content, _messages.MultiModalContentTypes):
+                        user_parts.append(
+                            _messages.UserPromptPart(
+                                content=[f'This is file {file_index}:', result.content],
+                                timestamp=result.timestamp,
+                                part_kind='user-prompt',
+                            )
+                        )
+
+                        result.content = f'See file {file_index}.'
+                        file_index += 1
+
                     results_by_index[index] = result
                 else:
                     assert_never(result)
@@ -662,6 +682,8 @@ async def process_function_tools(
     for k in sorted(results_by_index):
         output_parts.append(results_by_index[k])

+    output_parts.extend(user_parts)
+

 async def _tool_from_mcp_server(
     tool_name: str,
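
The change to `process_function_tools` above implements an indirection: when a tool returns multi-modal content, the tool-return part keeps only a short text placeholder (`'See file 1.'`), while the actual content is appended afterwards as a user prompt part prefixed with `'This is file 1:'`. A simplified, self-contained sketch of that transformation, using hypothetical stand-in dataclasses rather than the real `_messages` types:

```python
from dataclasses import dataclass


# Hypothetical stand-ins that only mirror the shape of pydantic_ai's message parts;
# the real code works with _messages.ToolReturnPart, _messages.UserPromptPart and
# the _messages.MultiModalContentTypes tuple.
@dataclass
class ImageUrl:
    url: str


@dataclass
class ToolReturnPart:
    tool_name: str
    content: object


@dataclass
class UserPromptPart:
    content: list[object]


MULTI_MODAL_TYPES = (ImageUrl,)  # the real tuple also covers audio, video, documents and binary content


def attach_files(tool_returns: list[ToolReturnPart]) -> tuple[list[ToolReturnPart], list[UserPromptPart]]:
    """Swap multi-modal tool results for a text placeholder and collect the files as user parts."""
    user_parts: list[UserPromptPart] = []
    file_index = 1
    for part in tool_returns:
        if isinstance(part.content, MULTI_MODAL_TYPES):
            user_parts.append(UserPromptPart(content=[f'This is file {file_index}:', part.content]))
            part.content = f'See file {file_index}.'
            file_index += 1
    return tool_returns, user_parts


returns, files = attach_files([ToolReturnPart('get_company_logo', ImageUrl('https://iili.io/3Hs4FMg.png'))])
print(returns[0].content)
#> See file 1.
print(files[0].content)
#> ['This is file 1:', ImageUrl(url='https://iili.io/3Hs4FMg.png')]
```

The collected user parts are appended after all tool-return parts (`output_parts.extend(user_parts)`), so the model sees every tool result first and then the referenced files.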

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 3 additions & 0 deletions
@@ -253,6 +253,9 @@ def format(self) -> str:

 UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'

+# Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance`.
+MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)
+

 def _document_format(media_type: str) -> DocumentFormat:
     if media_type == 'application/pdf':
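
The comment added to `messages.py` points at a runtime limitation: `UserContent` has to be a string alias so the `X | Y` syntax parses on Python 3.9, and a string cannot be passed to `isinstance()`, whereas a plain tuple of classes can. A minimal sketch of the distinction, with placeholder classes standing in for the real content types:

```python
# Placeholder classes standing in for pydantic_ai's ImageUrl, AudioUrl, etc.
class ImageUrl: ...
class AudioUrl: ...
class BinaryContent: ...


# A string alias (as used for `UserContent`) works as a type annotation,
# but it is not a type, so isinstance() rejects it.
UserContentLike = 'str | ImageUrl | AudioUrl | BinaryContent'

# A tuple of classes is exactly what isinstance() accepts.
MultiModalContentTypes = (ImageUrl, AudioUrl, BinaryContent)

value = ImageUrl()
print(isinstance(value, MultiModalContentTypes))
#> True

try:
    isinstance(value, UserContentLike)  # raises: a str is not a type
except TypeError:
    print('isinstance() rejects the string alias')
    #> isinstance() rejects the string alias
```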

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 14 additions & 1 deletion
@@ -483,7 +483,20 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]:
                 assert_never(message)
         if instructions := self._get_instructions(messages):
             mistral_messages.insert(0, MistralSystemMessage(content=instructions))
-        return mistral_messages
+
+        # Post-process messages to insert fake assistant message after tool message if followed by user message
+        # to work around `Unexpected role 'user' after role 'tool'` error.
+        processed_messages: list[MistralMessages] = []
+        for i, current_message in enumerate(mistral_messages):
+            processed_messages.append(current_message)
+
+            if isinstance(current_message, MistralToolMessage) and i + 1 < len(mistral_messages):
+                next_message = mistral_messages[i + 1]
+                if isinstance(next_message, MistralUserMessage):
+                    # Insert a dummy assistant message
+                    processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')]))
+
+        return processed_messages

     def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage:
         content: str | list[MistralContentChunk]
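
The Mistral change exists because a multi-modal tool result now produces a user message directly after a tool message, which the Mistral API rejects (`Unexpected role 'user' after role 'tool'`); the mapper therefore inserts a dummy assistant message between the two. A simplified sketch of that post-processing pass, using plain role/content dicts instead of the Mistral SDK message classes:

```python
# Messages are modelled as plain dicts here purely for illustration; the real code
# works with MistralToolMessage / MistralUserMessage / MistralAssistantMessage.
Message = dict[str, str]


def insert_dummy_assistant(messages: list[Message]) -> list[Message]:
    """Insert an assistant 'OK' between a tool message and a user message that directly follows it."""
    processed: list[Message] = []
    for i, current in enumerate(messages):
        processed.append(current)
        if current['role'] == 'tool' and i + 1 < len(messages) and messages[i + 1]['role'] == 'user':
            processed.append({'role': 'assistant', 'content': 'OK'})
    return processed


history = [
    {'role': 'user', 'content': 'What is the company name in the logo?'},
    {'role': 'tool', 'content': 'See file 1.'},
    {'role': 'user', 'content': 'This is file 1: <image>'},
]
print([m['role'] for m in insert_dummy_assistant(history)])
#> ['user', 'tool', 'assistant', 'user']
```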

tests/models/cassettes/test_anthropic/test_image_as_binary_content_tool_response.yaml

Lines changed: 153 additions & 0 deletions
Large diffs are not rendered by default.

tests/models/cassettes/test_gemini/test_image_as_binary_content_tool_response.yaml

Lines changed: 150 additions & 0 deletions
Large diffs are not rendered by default.

tests/models/cassettes/test_groq/test_image_as_binary_content_input.yaml

Lines changed: 21 additions & 19 deletions
Large diffs are not rendered by default.
