Skip to content

Commit 96be03d

Browse files
Kludex and dmontagu
authored
Support multimodal inputs (#971)
Co-authored-by: David Montague <[email protected]>
1 parent 4cd2603 commit 96be03d

40 files changed

+8814
-168
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,6 @@ repos:
4545
rev: v2.3.0
4646
hooks:
4747
- id: codespell
48+
args: ['--skip', 'tests/models/cassettes/*']
4849
additional_dependencies:
4950
- tomli

docs/input.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Image and Audio Input
2+
3+
Some LLMs are now capable of understanding both audio and image content.
4+
5+
## Image Input
6+
7+
!!! info
8+
Some models do not support image input. Please check the model's documentation to confirm whether it supports image input.
9+
10+
If you have a direct URL for the image, you can use [`ImageUrl`][pydantic_ai.ImageUrl]:
11+
12+
```py {title="main.py" test="skip" lint="skip"}
13+
from pydantic_ai import Agent, ImageUrl
14+
15+
image_url = ImageUrl(url='https://iili.io/3Hs4FMg.png')
16+
17+
agent = Agent(model='openai:gpt-4o')
18+
result = agent.run_sync(
19+
[
20+
'What company is this logo from?',
21+
ImageUrl(url='https://iili.io/3Hs4FMg.png'),
22+
]
23+
)
24+
print(result.data)
25+
#> This is the logo for Pydantic, a data validation and settings management library in Python.
26+
```
27+
28+
If you have the image locally, you can also use [`BinaryContent`][pydantic_ai.BinaryContent]:
29+
30+
```py {title="main.py" test="skip" lint="skip"}
31+
import httpx
32+
33+
from pydantic_ai import Agent, BinaryContent
34+
35+
image_response = httpx.get('https://iili.io/3Hs4FMg.png') # Pydantic logo
36+
37+
agent = Agent(model='openai:gpt-4o')
38+
result = agent.run_sync(
39+
[
40+
'What company is this logo from?',
41+
BinaryContent(data=image_response.content, media_type='image/png'), # (1)!
42+
]
43+
)
44+
print(result.data)
45+
#> This is the logo for Pydantic, a data validation and settings management library in Python.
46+
```
47+
48+
1. To ensure the example is runnable we download this image from the web, but you can also use `Path().read_bytes()` to read a local file's contents.
49+
50+
## Audio Input
51+
52+
!!! info
53+
Some models do not support audio input. Please check the model's documentation to confirm whether it supports audio input.
54+
55+
You can provide audio input using either [`AudioUrl`][pydantic_ai.AudioUrl] or [`BinaryContent`][pydantic_ai.BinaryContent]. The process is analogous to the examples above.

examples/pydantic_ai_examples/chat_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def to_chat_message(m: ModelMessage) -> ChatMessage:
8989
first_part = m.parts[0]
9090
if isinstance(m, ModelRequest):
9191
if isinstance(first_part, UserPromptPart):
92+
assert isinstance(first_part.content, str)
9293
return {
9394
'role': 'user',
9495
'timestamp': first_part.timestamp.isoformat(),

mkdocs.yml

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,52 +16,53 @@ nav:
1616
- contributing.md
1717
- troubleshooting.md
1818
- Documentation:
19-
- agents.md
20-
- models.md
21-
- dependencies.md
22-
- tools.md
23-
- results.md
24-
- message-history.md
25-
- testing-evals.md
26-
- logfire.md
27-
- multi-agent-applications.md
28-
- graph.md
19+
- agents.md
20+
- models.md
21+
- dependencies.md
22+
- tools.md
23+
- results.md
24+
- message-history.md
25+
- testing-evals.md
26+
- logfire.md
27+
- multi-agent-applications.md
28+
- graph.md
29+
- input.md
2930
- Examples:
30-
- examples/index.md
31-
- examples/pydantic-model.md
32-
- examples/weather-agent.md
33-
- examples/bank-support.md
34-
- examples/sql-gen.md
35-
- examples/flight-booking.md
36-
- examples/rag.md
37-
- examples/stream-markdown.md
38-
- examples/stream-whales.md
39-
- examples/chat-app.md
40-
- examples/question-graph.md
31+
- examples/index.md
32+
- examples/pydantic-model.md
33+
- examples/weather-agent.md
34+
- examples/bank-support.md
35+
- examples/sql-gen.md
36+
- examples/flight-booking.md
37+
- examples/rag.md
38+
- examples/stream-markdown.md
39+
- examples/stream-whales.md
40+
- examples/chat-app.md
41+
- examples/question-graph.md
4142
- API Reference:
42-
- api/agent.md
43-
- api/tools.md
44-
- api/result.md
45-
- api/messages.md
46-
- api/exceptions.md
47-
- api/settings.md
48-
- api/usage.md
49-
- api/format_as_xml.md
50-
- api/models/base.md
51-
- api/models/openai.md
52-
- api/models/anthropic.md
53-
- api/models/cohere.md
54-
- api/models/gemini.md
55-
- api/models/vertexai.md
56-
- api/models/groq.md
57-
- api/models/mistral.md
58-
- api/models/test.md
59-
- api/models/function.md
60-
- api/pydantic_graph/graph.md
61-
- api/pydantic_graph/nodes.md
62-
- api/pydantic_graph/state.md
63-
- api/pydantic_graph/mermaid.md
64-
- api/pydantic_graph/exceptions.md
43+
- api/agent.md
44+
- api/tools.md
45+
- api/result.md
46+
- api/messages.md
47+
- api/exceptions.md
48+
- api/settings.md
49+
- api/usage.md
50+
- api/format_as_xml.md
51+
- api/models/base.md
52+
- api/models/openai.md
53+
- api/models/anthropic.md
54+
- api/models/cohere.md
55+
- api/models/gemini.md
56+
- api/models/vertexai.md
57+
- api/models/groq.md
58+
- api/models/mistral.md
59+
- api/models/test.md
60+
- api/models/function.md
61+
- api/pydantic_graph/graph.md
62+
- api/pydantic_graph/nodes.md
63+
- api/pydantic_graph/state.md
64+
- api/pydantic_graph/mermaid.md
65+
- api/pydantic_graph/exceptions.md
6566

6667
extra:
6768
# hide the "Made with Material for MkDocs" message
@@ -100,12 +101,12 @@ theme:
100101
- content.code.copy
101102
- content.code.select
102103
- navigation.path
103-
# - navigation.expand
104+
# - navigation.expand
104105
- navigation.indexes
105106
- navigation.sections
106107
- navigation.tracking
107108
- toc.follow
108-
# - navigation.tabs # don't use navbar tabs
109+
# - navigation.tabs # don't use navbar tabs
109110
logo: "img/logo-white.svg"
110111
favicon: "favicon.ico"
111112

@@ -151,7 +152,7 @@ markdown_extensions:
151152
emoji_generator: !!python/name:material.extensions.emoji.to_svg
152153
options:
153154
custom_icons:
154-
- docs/.overrides/.icons
155+
- docs/.overrides/.icons
155156
- pymdownx.tabbed:
156157
alternate_style: true
157158
- pymdownx.tasklist:
@@ -190,6 +191,6 @@ plugins:
190191
# waiting for https://github.com/encode/httpx/discussions/3091#discussioncomment-11205594
191192

192193
hooks:
193-
- 'docs/.hooks/main.py'
194-
- 'docs/.hooks/build_llms_txt.py'
195-
- 'docs/.hooks/algolia.py'
194+
- "docs/.hooks/main.py"
195+
- "docs/.hooks/build_llms_txt.py"
196+
- "docs/.hooks/algolia.py"

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,30 @@
22

33
from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
44
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
5+
from .messages import AudioUrl, BinaryContent, ImageUrl
56
from .tools import RunContext, Tool
67

78
__all__ = (
9+
'__version__',
10+
# agent
811
'Agent',
912
'EndStrategy',
1013
'HandleResponseNode',
1114
'ModelRequestNode',
1215
'UserPromptNode',
1316
'capture_run_messages',
14-
'RunContext',
15-
'Tool',
17+
# exceptions
1618
'AgentRunError',
1719
'ModelRetry',
1820
'UnexpectedModelBehavior',
1921
'UsageLimitExceeded',
2022
'UserError',
21-
'__version__',
23+
# messages
24+
'ImageUrl',
25+
'AudioUrl',
26+
'BinaryContent',
27+
# tools
28+
'Tool',
29+
'RunContext',
2230
)
2331
__version__ = version('pydantic_ai_slim')

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import dataclasses
55
from abc import ABC
6-
from collections.abc import AsyncIterator, Iterator
6+
from collections.abc import AsyncIterator, Iterator, Sequence
77
from contextlib import asynccontextmanager, contextmanager
88
from contextvars import ContextVar
99
from dataclasses import field
@@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
8989

9090
user_deps: DepsT
9191

92-
prompt: str
92+
prompt: str | Sequence[_messages.UserContent]
9393
new_message_index: int
9494

9595
model: models.Model
@@ -109,7 +109,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
109109

110110
@dataclasses.dataclass
111111
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
112-
user_prompt: str
112+
user_prompt: str | Sequence[_messages.UserContent]
113113

114114
system_prompts: tuple[str, ...]
115115
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
@@ -135,7 +135,10 @@ async def _get_first_message(
135135
return next_message
136136

137137
async def _prepare_messages(
138-
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
138+
self,
139+
user_prompt: str | Sequence[_messages.UserContent],
140+
message_history: list[_messages.ModelMessage] | None,
141+
run_context: RunContext[DepsT],
139142
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
140143
try:
141144
ctx_messages = get_captured_run_messages()

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def __init__(
220220
@overload
221221
async def run(
222222
self,
223-
user_prompt: str,
223+
user_prompt: str | Sequence[_messages.UserContent],
224224
*,
225225
result_type: None = None,
226226
message_history: list[_messages.ModelMessage] | None = None,
@@ -235,7 +235,7 @@ async def run(
235235
@overload
236236
async def run(
237237
self,
238-
user_prompt: str,
238+
user_prompt: str | Sequence[_messages.UserContent],
239239
*,
240240
result_type: type[RunResultDataT],
241241
message_history: list[_messages.ModelMessage] | None = None,
@@ -249,7 +249,7 @@ async def run(
249249

250250
async def run(
251251
self,
252-
user_prompt: str,
252+
user_prompt: str | Sequence[_messages.UserContent],
253253
*,
254254
result_type: type[RunResultDataT] | None = None,
255255
message_history: list[_messages.ModelMessage] | None = None,
@@ -313,7 +313,7 @@ async def main():
313313
@contextmanager
314314
def iter(
315315
self,
316-
user_prompt: str,
316+
user_prompt: str | Sequence[_messages.UserContent],
317317
*,
318318
result_type: type[RunResultDataT] | None = None,
319319
message_history: list[_messages.ModelMessage] | None = None,
@@ -466,7 +466,7 @@ async def main():
466466
@overload
467467
def run_sync(
468468
self,
469-
user_prompt: str,
469+
user_prompt: str | Sequence[_messages.UserContent],
470470
*,
471471
message_history: list[_messages.ModelMessage] | None = None,
472472
model: models.Model | models.KnownModelName | None = None,
@@ -480,7 +480,7 @@ def run_sync(
480480
@overload
481481
def run_sync(
482482
self,
483-
user_prompt: str,
483+
user_prompt: str | Sequence[_messages.UserContent],
484484
*,
485485
result_type: type[RunResultDataT] | None,
486486
message_history: list[_messages.ModelMessage] | None = None,
@@ -494,7 +494,7 @@ def run_sync(
494494

495495
def run_sync(
496496
self,
497-
user_prompt: str,
497+
user_prompt: str | Sequence[_messages.UserContent],
498498
*,
499499
result_type: type[RunResultDataT] | None = None,
500500
message_history: list[_messages.ModelMessage] | None = None,
@@ -555,7 +555,7 @@ def run_sync(
555555
@overload
556556
def run_stream(
557557
self,
558-
user_prompt: str,
558+
user_prompt: str | Sequence[_messages.UserContent],
559559
*,
560560
result_type: None = None,
561561
message_history: list[_messages.ModelMessage] | None = None,
@@ -570,7 +570,7 @@ def run_stream(
570570
@overload
571571
def run_stream(
572572
self,
573-
user_prompt: str,
573+
user_prompt: str | Sequence[_messages.UserContent],
574574
*,
575575
result_type: type[RunResultDataT],
576576
message_history: list[_messages.ModelMessage] | None = None,
@@ -585,7 +585,7 @@ def run_stream(
585585
@asynccontextmanager
586586
async def run_stream( # noqa C901
587587
self,
588-
user_prompt: str,
588+
user_prompt: str | Sequence[_messages.UserContent],
589589
*,
590590
result_type: type[RunResultDataT] | None = None,
591591
message_history: list[_messages.ModelMessage] | None = None,

0 commit comments

Comments
 (0)