Skip to content

Commit 48094d2

Browse files
committed
add support for model instances, deps and settings
1 parent 32cc5c5 commit 48094d2

File tree

6 files changed

+242
-94
lines changed

6 files changed

+242
-94
lines changed

pydantic_ai_slim/pydantic_ai/_cli/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@
5757

5858
# CLI-supported tool IDs (excludes deprecated and config-requiring tools)
5959
_CLI_TOOL_IDS = sorted(
60-
k.kind
61-
for k in get_builtin_tool_types()
62-
if k not in {'mcp_server', 'memory', 'unknown_builtin_tool'}
60+
k.kind for k in get_builtin_tool_types() if k not in {'mcp_server', 'memory', 'unknown_builtin_tool'}
6361
)
6462

6563

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,9 @@ def to_web(
14881488
| dict[str, models.Model | models.KnownModelName | str]
14891489
| None = None,
14901490
builtin_tools: list[AbstractBuiltinTool] | None = None,
1491+
deps: AgentDepsT = None,
1492+
model_settings: ModelSettings | None = None,
1493+
instructions: str | None = None,
14911494
) -> Starlette:
14921495
"""Create a Starlette app that serves a web chat UI for this agent.
14931496
@@ -1498,16 +1501,22 @@ def to_web(
14981501
The returned Starlette application can be mounted into a FastAPI app or run directly
14991502
with any ASGI server (uvicorn, hypercorn, etc.).
15001503
1504+
Note that the `deps` and `model_settings` will be the same for each request.
1505+
To provide different `deps` for each request use the lower-level adapters directly.
1506+
15011507
Args:
15021508
models: Models to make available in the UI. Can be:
15031509
- A sequence of model names/instances (e.g., `['openai:gpt-5', 'anthropic:claude-sonnet-4-5']`)
15041510
- A dict mapping display labels to model names/instances
15051511
(e.g., `{'GPT 5': 'openai:gpt-5', 'Claude': 'anthropic:claude-sonnet-4-5'}`)
1506-
If not provided, the UI will have no model options.
1512+
If not provided, uses the agent's configured model.
15071513
Builtin tool support is automatically determined from each model's profile.
1508-
builtin_tools: Builtin tools to make available. If not provided, no tools
1509-
will be available. Tool labels in the UI are derived from the tool's
1510-
`label` property.
1514+
builtin_tools: Builtin tools to make available. If not provided, uses the
1515+
agent's configured builtin tools. Tool labels in the UI are derived
1516+
from the tool's `label` property.
1517+
deps: Optional dependencies to use for all requests.
1518+
model_settings: Optional settings to use for all model requests.
1519+
instructions: Optional extra instructions to pass to each agent run.
15111520
15121521
Returns:
15131522
A configured Starlette application ready to be served (e.g., with uvicorn)
@@ -1517,30 +1526,32 @@ def to_web(
15171526
from pydantic_ai import Agent
15181527
from pydantic_ai.builtin_tools import WebSearchTool
15191528
1520-
agent = Agent('openai:gpt-5')
1529+
agent = Agent('openai:gpt-5', builtin_tools=[WebSearchTool()])
15211530
1522-
@agent.tool_plain
1523-
def get_weather(city: str) -> str:
1524-
return f'The weather in {city} is sunny'
1525-
1526-
# With model names (display names auto-generated)
1527-
app = agent.to_web(
1528-
models=['openai:gpt-5', 'anthropic:claude-sonnet-4-5'],
1529-
builtin_tools=[WebSearchTool()],
1530-
)
1531+
# Simple usage - uses agent's model and builtin tools
1532+
app = agent.to_web()
15311533
1532-
# Or with custom display labels
1533-
app = agent.to_web(
1534-
models={'GPT 5': 'openai:gpt-5', 'Claude': 'anthropic:claude-sonnet-4-5'},
1535-
builtin_tools=[WebSearchTool()],
1536-
)
1534+
# Or provide additional models for UI selection
1535+
app = agent.to_web(models=['openai:gpt-5', 'anthropic:claude-sonnet-4-5'])
15371536
15381537
# Then run with: uvicorn app:app --reload
15391538
```
15401539
"""
1541-
from ..ui._web import create_web_app
1540+
from ..ui._web import ModelsParam, create_web_app
1541+
1542+
# weird ternary for typing purposes
1543+
resolved_models: ModelsParam = models or (self._model and [self._model])
1544+
1545+
resolved_builtin_tools = builtin_tools or list(self._builtin_tools)
15421546

1543-
return create_web_app(self, models=models, builtin_tools=builtin_tools)
1547+
return create_web_app(
1548+
self,
1549+
models=resolved_models,
1550+
builtin_tools=resolved_builtin_tools,
1551+
deps=deps,
1552+
model_settings=model_settings,
1553+
instructions=instructions,
1554+
)
15441555

15451556
@asynccontextmanager
15461557
@deprecated(

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,13 @@ def prepare_request(
489489
# Check if builtin tools are supported
490490
if params.builtin_tools:
491491
supported_types = self.profile.supported_builtin_tools
492-
for tool in params.builtin_tools:
493-
if not isinstance(tool, tuple(supported_types)):
494-
raise UserError(
495-
f'Builtin tool {type(tool).__name__} is not supported by this model. '
496-
f'Supported tools: {[t.__name__ for t in supported_types]}'
497-
)
492+
unsupported = [tool for tool in params.builtin_tools if not isinstance(tool, tuple(supported_types))]
493+
if unsupported:
494+
unsupported_names = [type(tool).__name__ for tool in unsupported]
495+
supported_names = [t.__name__ for t in supported_types]
496+
raise UserError(
497+
f'Builtin tool(s) {unsupported_names} not supported by this model. Supported: {supported_names}'
498+
)
498499

499500
return model_settings, params
500501

@@ -506,8 +507,34 @@ def model_name(self) -> str:
506507

507508
@property
508509
def label(self) -> str:
509-
"""Human-friendly display label for the model."""
510-
return _format_model_label(self.model_name)
510+
"""Human-friendly display label for the model.
511+
512+
Handles common patterns:
513+
- gpt-5 -> GPT 5
514+
- claude-sonnet-4-5 -> Claude Sonnet 4.5
515+
- gemini-2.5-pro -> Gemini 2.5 Pro
516+
- meta-llama/llama-3-70b -> Llama 3 70b (OpenRouter style)
517+
"""
518+
label = self.model_name
519+
# Handle OpenRouter-style names with / (e.g., meta-llama/llama-3-70b)
520+
if '/' in label:
521+
label = label.split('/')[-1]
522+
523+
parts = label.split('-')
524+
result: list[str] = []
525+
526+
for i, part in enumerate(parts):
527+
if i == 0 and part.lower() == 'gpt':
528+
result.append(part.upper())
529+
elif part.replace('.', '').isdigit():
530+
if result and result[-1].replace('.', '').isdigit():
531+
result[-1] = f'{result[-1]}.{part}'
532+
else:
533+
result.append(part)
534+
else:
535+
result.append(part.capitalize())
536+
537+
return ' '.join(result)
511538

512539
@classmethod
513540
def supported_builtin_tools(cls) -> frozenset[type[AbstractBuiltinTool]]:
@@ -756,36 +783,6 @@ def timestamp(self) -> datetime:
756783
"""
757784

758785

759-
def _format_model_label(model_name: str) -> str:
760-
"""Format model name for display in UI.
761-
762-
Handles common patterns:
763-
- gpt-5 -> GPT 5
764-
- claude-sonnet-4-5 -> Claude Sonnet 4.5
765-
- gemini-2.5-pro -> Gemini 2.5 Pro
766-
- meta-llama/llama-3-70b -> Llama 3 70b (OpenRouter style)
767-
"""
768-
# Handle OpenRouter-style names with / (e.g., meta-llama/llama-3-70b)
769-
if '/' in model_name:
770-
model_name = model_name.split('/')[-1]
771-
772-
parts = model_name.split('-')
773-
result: list[str] = []
774-
775-
for i, part in enumerate(parts):
776-
if i == 0 and part.lower() == 'gpt':
777-
result.append(part.upper())
778-
elif part.replace('.', '').isdigit():
779-
if result and result[-1].replace('.', '').isdigit():
780-
result[-1] = f'{result[-1]}.{part}'
781-
else:
782-
result.append(part)
783-
else:
784-
result.append(part.capitalize())
785-
786-
return ' '.join(result)
787-
788-
789786
def check_allow_model_requests() -> None:
790787
"""Check if model requests are allowed.
791788

pydantic_ai_slim/pydantic_ai/ui/_web/api.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""API routes for the web chat UI."""
22

33
from collections.abc import Sequence
4-
from typing import Any
4+
from typing import TypeVar
55

66
from pydantic import BaseModel
77
from pydantic.alias_generators import to_camel
@@ -11,9 +11,13 @@
1111

1212
from pydantic_ai import Agent
1313
from pydantic_ai.builtin_tools import AbstractBuiltinTool
14+
from pydantic_ai.settings import ModelSettings
1415
from pydantic_ai.toolsets import AbstractToolset
1516
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
1617

18+
AgentDepsT = TypeVar('AgentDepsT')
19+
OutputDataT = TypeVar('OutputDataT')
20+
1721

1822
class ModelInfo(BaseModel, alias_generator=to_camel, populate_by_name=True):
1923
"""Defines an AI model with its associated built-in tools."""
@@ -30,14 +34,14 @@ class BuiltinToolInfo(BaseModel, alias_generator=to_camel, populate_by_name=True
3034
name: str
3135

3236

33-
class _ConfigureFrontend(BaseModel, alias_generator=to_camel, populate_by_name=True):
37+
class ConfigureFrontend(BaseModel, alias_generator=to_camel, populate_by_name=True):
3438
"""Response model for frontend configuration."""
3539

3640
models: list[ModelInfo]
3741
builtin_tools: list[BuiltinToolInfo]
3842

3943

40-
class _ChatRequestExtra(BaseModel, extra='ignore', alias_generator=to_camel):
44+
class ChatRequestExtra(BaseModel, extra='ignore', alias_generator=to_camel):
4145
"""Extra data extracted from chat request."""
4246

4347
model: str | None = None
@@ -46,12 +50,39 @@ class _ChatRequestExtra(BaseModel, extra='ignore', alias_generator=to_camel):
4650
"""Tool IDs selected by the user, e.g. ['web_search', 'code_execution']. Maps to JSON field 'builtinTools'."""
4751

4852

53+
def validate_request_options(
54+
extra_data: ChatRequestExtra,
55+
model_ids: set[str],
56+
allowed_tool_ids: set[str],
57+
) -> str | None:
58+
"""Validate that requested model and tools are in the allowed lists.
59+
60+
Returns an error message if validation fails, or None if valid.
61+
"""
62+
if extra_data.model and model_ids and extra_data.model not in model_ids:
63+
return f'Model "{extra_data.model}" is not in the allowed models list'
64+
65+
# the base model also validates this, but it makes sense to have an API check too, since one could be a UI bug/misbehavior
66+
# the other would be a pydantic-ai bug
67+
# also as future proofing since we don't know how users will use this feature in the future
68+
invalid_tools = [t for t in extra_data.builtin_tools if t not in allowed_tool_ids]
69+
if invalid_tools:
70+
return f'Builtin tool(s) {invalid_tools} not in the allowed tools list'
71+
72+
return None
73+
74+
75+
# TODO remove the app arg and return a router instead (refactor the upstream logic to mount the router)
76+
# https://github.com/pydantic/pydantic-ai/pull/3456/files#r2582659204
4977
def add_api_routes(
5078
app: Starlette,
51-
agent: Agent,
79+
agent: Agent[AgentDepsT, OutputDataT],
5280
models: list[ModelInfo] | None = None,
5381
builtin_tools: list[AbstractBuiltinTool] | None = None,
54-
toolsets: Sequence[AbstractToolset[Any]] | None = None,
82+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
83+
deps: AgentDepsT = None,
84+
model_settings: ModelSettings | None = None,
85+
instructions: str | None = None,
5586
) -> None:
5687
"""Add API routes to a Starlette app.
5788
@@ -61,22 +92,25 @@ def add_api_routes(
6192
models: Optional list of AI models. If not provided, the UI will have no model options.
6293
builtin_tools: Optional list of builtin tools. If not provided, no tools will be available.
6394
toolsets: Optional list of toolsets (e.g., MCP servers). These provide additional tools.
95+
deps: Optional dependencies to use for all requests.
96+
model_settings: Optional settings to use for all model requests.
97+
instructions: Optional extra instructions to pass to each agent run.
6498
"""
65-
_models = models or []
66-
_model_ids = {m.id for m in _models}
67-
_builtin_tools = builtin_tools or []
68-
_toolsets = list(toolsets) if toolsets else None
69-
_tools_by_id: dict[str, AbstractBuiltinTool] = {tool.unique_id: tool for tool in _builtin_tools}
99+
models = models or []
100+
model_ids = {m.id for m in models}
101+
builtin_tools = builtin_tools or []
102+
allowed_tool_ids = {tool.unique_id for tool in builtin_tools}
103+
toolsets = list(toolsets) if toolsets else None
70104

71105
async def options_chat(request: Request) -> Response:
72106
"""Handle CORS preflight requests."""
73107
return Response()
74108

75109
async def configure_frontend(request: Request) -> Response:
76110
"""Endpoint to configure the frontend with available models and tools."""
77-
config = _ConfigureFrontend(
78-
models=_models,
79-
builtin_tools=[BuiltinToolInfo(id=tool.unique_id, name=tool.label) for tool in _builtin_tools],
111+
config = ConfigureFrontend(
112+
models=models,
113+
builtin_tools=[BuiltinToolInfo(id=tool.unique_id, name=tool.label) for tool in builtin_tools],
80114
)
81115
return JSONResponse(config.model_dump(by_alias=True))
82116

@@ -86,25 +120,22 @@ async def health(request: Request) -> Response:
86120

87121
async def post_chat(request: Request) -> Response:
88122
"""Handle chat requests via Vercel AI Adapter."""
89-
adapter = await VercelAIAdapter.from_request(request, agent=agent)
90-
extra_data = _ChatRequestExtra.model_validate(adapter.run_input.__pydantic_extra__)
91-
92-
# Validate model is in allowed list
93-
if extra_data.model and _model_ids and extra_data.model not in _model_ids:
94-
return JSONResponse(
95-
{'error': f'Model "{extra_data.model}" is not in the allowed models list'},
96-
status_code=400,
97-
)
98-
99-
request_builtin_tools = [
100-
_tools_by_id[tool_id] for tool_id in extra_data.builtin_tools if tool_id in _tools_by_id
101-
]
102-
streaming_response = await VercelAIAdapter.dispatch_request(
123+
adapter = await VercelAIAdapter[AgentDepsT, OutputDataT].from_request(request, agent=agent)
124+
extra_data = ChatRequestExtra.model_validate(adapter.run_input.__pydantic_extra__)
125+
126+
if error := validate_request_options(extra_data, model_ids, allowed_tool_ids):
127+
return JSONResponse({'error': error}, status_code=400)
128+
129+
request_builtin_tools = [tool for tool in builtin_tools if tool.unique_id in extra_data.builtin_tools]
130+
streaming_response = await VercelAIAdapter[AgentDepsT, OutputDataT].dispatch_request(
103131
request,
104132
agent=agent,
105133
model=extra_data.model,
106134
builtin_tools=request_builtin_tools,
107-
toolsets=_toolsets,
135+
toolsets=toolsets,
136+
deps=deps,
137+
model_settings=model_settings,
138+
instructions=instructions,
108139
)
109140
return streaming_response
110141

0 commit comments

Comments
 (0)