11"""API routes for the web chat UI."""
22
33from collections .abc import Sequence
4- from typing import Any
4+ from typing import TypeVar
55
66from pydantic import BaseModel
77from pydantic .alias_generators import to_camel
1111
1212from pydantic_ai import Agent
1313from pydantic_ai .builtin_tools import AbstractBuiltinTool
14+ from pydantic_ai .settings import ModelSettings
1415from pydantic_ai .toolsets import AbstractToolset
1516from pydantic_ai .ui .vercel_ai import VercelAIAdapter
1617
18+ AgentDepsT = TypeVar ('AgentDepsT' )
19+ OutputDataT = TypeVar ('OutputDataT' )
20+
1721
1822class ModelInfo (BaseModel , alias_generator = to_camel , populate_by_name = True ):
1923 """Defines an AI model with its associated built-in tools."""
@@ -30,14 +34,14 @@ class BuiltinToolInfo(BaseModel, alias_generator=to_camel, populate_by_name=True
3034 name : str
3135
3236
33- class _ConfigureFrontend (BaseModel , alias_generator = to_camel , populate_by_name = True ):
37+ class ConfigureFrontend (BaseModel , alias_generator = to_camel , populate_by_name = True ):
3438 """Response model for frontend configuration."""
3539
3640 models : list [ModelInfo ]
3741 builtin_tools : list [BuiltinToolInfo ]
3842
3943
40- class _ChatRequestExtra (BaseModel , extra = 'ignore' , alias_generator = to_camel ):
44+ class ChatRequestExtra (BaseModel , extra = 'ignore' , alias_generator = to_camel ):
4145 """Extra data extracted from chat request."""
4246
4347 model : str | None = None
@@ -46,12 +50,39 @@ class _ChatRequestExtra(BaseModel, extra='ignore', alias_generator=to_camel):
4650 """Tool IDs selected by the user, e.g. ['web_search', 'code_execution']. Maps to JSON field 'builtinTools'."""
4751
4852
53+ def validate_request_options (
54+ extra_data : ChatRequestExtra ,
55+ model_ids : set [str ],
56+ allowed_tool_ids : set [str ],
57+ ) -> str | None :
58+ """Validate that requested model and tools are in the allowed lists.
59+
60+ Returns an error message if validation fails, or None if valid.
61+ """
62+ if extra_data .model and model_ids and extra_data .model not in model_ids :
63+ return f'Model "{ extra_data .model } " is not in the allowed models list'
64+
65+ # base model also valdiates this but makes sesne to have an api check, since one could be a UI bug/misbehavior
66+ # the other would be a pydantic-ai bug
67+ # also as future proofing since we don't know how users will use this feature in the future
68+ invalid_tools = [t for t in extra_data .builtin_tools if t not in allowed_tool_ids ]
69+ if invalid_tools :
70+ return f'Builtin tool(s) { invalid_tools } not in the allowed tools list'
71+
72+ return None
73+
74+
75+ # TODO remove the app arg and return a router instead (refactor the upstream logic to mount the router)
76+ # https://github.com/pydantic/pydantic-ai/pull/3456/files#r2582659204
4977def add_api_routes (
5078 app : Starlette ,
51- agent : Agent ,
79+ agent : Agent [ AgentDepsT , OutputDataT ] ,
5280 models : list [ModelInfo ] | None = None ,
5381 builtin_tools : list [AbstractBuiltinTool ] | None = None ,
54- toolsets : Sequence [AbstractToolset [Any ]] | None = None ,
82+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
83+ deps : AgentDepsT = None ,
84+ model_settings : ModelSettings | None = None ,
85+ instructions : str | None = None ,
5586) -> None :
5687 """Add API routes to a Starlette app.
5788
@@ -61,22 +92,25 @@ def add_api_routes(
6192 models: Optional list of AI models. If not provided, the UI will have no model options.
6293 builtin_tools: Optional list of builtin tools. If not provided, no tools will be available.
6394 toolsets: Optional list of toolsets (e.g., MCP servers). These provide additional tools.
95+ deps: Optional dependencies to use for all requests.
96+ model_settings: Optional settings to use for all model requests.
97+ instructions: Optional extra instructions to pass to each agent run.
6498 """
65- _models = models or []
66- _model_ids = {m .id for m in _models }
67- _builtin_tools = builtin_tools or []
68- _toolsets = list ( toolsets ) if toolsets else None
69- _tools_by_id : dict [ str , AbstractBuiltinTool ] = { tool . unique_id : tool for tool in _builtin_tools }
99+ models = models or []
100+ model_ids = {m .id for m in models }
101+ builtin_tools = builtin_tools or []
102+ allowed_tool_ids = { tool . unique_id for tool in builtin_tools }
103+ toolsets = list ( toolsets ) if toolsets else None
70104
71105 async def options_chat (request : Request ) -> Response :
72106 """Handle CORS preflight requests."""
73107 return Response ()
74108
75109 async def configure_frontend (request : Request ) -> Response :
76110 """Endpoint to configure the frontend with available models and tools."""
77- config = _ConfigureFrontend (
78- models = _models ,
79- builtin_tools = [BuiltinToolInfo (id = tool .unique_id , name = tool .label ) for tool in _builtin_tools ],
111+ config = ConfigureFrontend (
112+ models = models ,
113+ builtin_tools = [BuiltinToolInfo (id = tool .unique_id , name = tool .label ) for tool in builtin_tools ],
80114 )
81115 return JSONResponse (config .model_dump (by_alias = True ))
82116
@@ -86,25 +120,22 @@ async def health(request: Request) -> Response:
86120
87121 async def post_chat (request : Request ) -> Response :
88122 """Handle chat requests via Vercel AI Adapter."""
89- adapter = await VercelAIAdapter .from_request (request , agent = agent )
90- extra_data = _ChatRequestExtra .model_validate (adapter .run_input .__pydantic_extra__ )
91-
92- # Validate model is in allowed list
93- if extra_data .model and _model_ids and extra_data .model not in _model_ids :
94- return JSONResponse (
95- {'error' : f'Model "{ extra_data .model } " is not in the allowed models list' },
96- status_code = 400 ,
97- )
98-
99- request_builtin_tools = [
100- _tools_by_id [tool_id ] for tool_id in extra_data .builtin_tools if tool_id in _tools_by_id
101- ]
102- streaming_response = await VercelAIAdapter .dispatch_request (
123+ adapter = await VercelAIAdapter [AgentDepsT , OutputDataT ].from_request (request , agent = agent )
124+ extra_data = ChatRequestExtra .model_validate (adapter .run_input .__pydantic_extra__ )
125+
126+ if error := validate_request_options (extra_data , model_ids , allowed_tool_ids ):
127+ return JSONResponse ({'error' : error }, status_code = 400 )
128+
129+ request_builtin_tools = [tool for tool in builtin_tools if tool .unique_id in extra_data .builtin_tools ]
130+ streaming_response = await VercelAIAdapter [AgentDepsT , OutputDataT ].dispatch_request (
103131 request ,
104132 agent = agent ,
105133 model = extra_data .model ,
106134 builtin_tools = request_builtin_tools ,
107- toolsets = _toolsets ,
135+ toolsets = toolsets ,
136+ deps = deps ,
137+ model_settings = model_settings ,
138+ instructions = instructions ,
108139 )
109140 return streaming_response
110141
0 commit comments