1- import logging
21from typing import Any
32
43try :
54 from fastapi import APIRouter , HTTPException
65 from fastapi .responses import StreamingResponse
7- from openai .types import ErrorObject
86 from openai .types .chat .chat_completion import ChatCompletion
97 from openai .types .model import Model
10- from openai .types .responses import Response
8+ from openai .types .responses import Response as OpenAIResponse
119except ImportError as _import_error : # pragma: no cover
1210 raise ImportError (
13- 'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
14- 'you can use the `openai` and `fastapi` optional group — `pip install "pydantic-ai-slim[openai,fastapi ]"`'
11+ 'Please install the `openai` and `fastapi` packages to enable the fastapi openai compatible endpoint, '
12+ 'you can use the `chat-completion` optional group — `pip install "pydantic-ai-slim[chat-completion ]"`'
1513 ) from _import_error
1614
1715from pydantic_ai .fastapi .api import AgentChatCompletionsAPI , AgentModelsAPI , AgentResponsesAPI
1816from pydantic_ai .fastapi .data_models import (
1917 ChatCompletionRequest ,
20- ErrorResponse ,
2118 ModelsResponse ,
2219 ResponsesRequest ,
2320)
2421from pydantic_ai .fastapi .registry import AgentRegistry
2522
26- logger = logging .getLogger (__name__ )
27-
2823
2924class AgentAPIRouter (APIRouter ):
3025 """FastAPI Router for Pydantic Agent."""
@@ -34,8 +29,8 @@ def __init__(
3429 agent_registry : AgentRegistry ,
3530 disable_response_api : bool = False ,
3631 disable_completions_api : bool = False ,
37- * args : tuple [ Any ] ,
38- ** kwargs : tuple [ Any ] ,
32+ * args : Any ,
33+ ** kwargs : Any ,
3934 ):
4035 super ().__init__ (* args , ** kwargs )
4136 self .registry = agent_registry
@@ -48,118 +43,42 @@ def __init__(
4843 # Registers OpenAI/v1 API routes
4944 self ._register_routes ()
5045
51- def _register_routes (self ) -> None : # noqa: C901
46+ def _register_routes (self ) -> None :
5247 if self .enable_completions_api :
5348
54- @self .post (
55- '/v1/chat/completions' ,
56- response_model = ChatCompletion ,
57- )
58- async def chat_completions ( # type: ignore
49+ @self .post ('/v1/chat/completions' , response_model = ChatCompletion )
50+ async def chat_completions (
5951 request : ChatCompletionRequest ,
6052 ) -> ChatCompletion | StreamingResponse :
61- if not request .messages :
62- raise HTTPException (
63- status_code = 400 ,
64- detail = ErrorResponse (
65- error = ErrorObject (
66- type = 'invalid_request_error' ,
67- message = 'Messages cannot be empty' ,
68- ),
69- ).model_dump (),
70- )
71- try :
72- if getattr (request , 'stream' , False ):
73- return StreamingResponse (
74- self .completions_api .create_streaming_completion (request ),
75- media_type = 'text/event-stream' ,
76- headers = {
77- 'Cache-Control' : 'no-cache' ,
78- 'Connection' : 'keep-alive' ,
79- 'Content-Type' : 'text/plain; charset=utf-8' ,
80- },
81- )
82- else :
83- return await self .completions_api .create_completion (request )
84- except Exception as e :
85- logger .error (f'Error in chat completion: { e } ' , exc_info = True )
86- raise HTTPException (
87- status_code = 500 ,
88- detail = ErrorResponse (
89- error = ErrorObject (
90- type = 'internal_server_error' ,
91- message = str (e ),
92- ),
93- ).model_dump (),
53+ if getattr (request , 'stream' , False ):
54+ return StreamingResponse (
55+ self .completions_api .create_streaming_completion (request ),
56+ media_type = 'text/event-stream' ,
57+ headers = {
58+ 'Cache-Control' : 'no-cache' ,
59+ 'Connection' : 'keep-alive' ,
60+ 'Content-Type' : 'text/plain; charset=utf-8' ,
61+ },
9462 )
63+ else :
64+ return await self .completions_api .create_completion (request )
9565
9666 if self .enable_responses_api :
9767
98- @self .post (
99- '/v1/responses' ,
100- response_model = Response ,
101- )
102- async def responses ( # type: ignore
68+ @self .post ('/v1/responses' , response_model = OpenAIResponse )
69+ async def responses (
10370 request : ResponsesRequest ,
104- ) -> Response :
105- if not request .input :
106- raise HTTPException (
107- status_code = 400 ,
108- detail = ErrorResponse (
109- error = ErrorObject (
110- type = 'invalid_request_error' ,
111- message = 'Messages cannot be empty' ,
112- ),
113- ).model_dump (),
114- )
115- try :
116- if getattr (request , 'stream' , False ):
117- # TODO: add streaming support for responses api
118- raise HTTPException (status_code = 501 )
119- else :
120- return await self .responses_api .create_response (request )
121- except Exception as e :
122- logger .error (f'Error in responses: { e } ' , exc_info = True )
123- raise HTTPException (
124- status_code = 500 ,
125- detail = ErrorResponse (
126- error = ErrorObject (
127- type = 'internal_server_error' ,
128- message = str (e ),
129- ),
130- ).model_dump (),
131- )
71+ ) -> OpenAIResponse :
72+ if getattr (request , 'stream' , False ):
73+ # TODO: add streaming support for responses api
74+ raise HTTPException (status_code = 501 )
75+ else :
76+ return await self .responses_api .create_response (request )
13277
13378 @self .get ('/v1/models' , response_model = ModelsResponse )
134- async def get_models () -> ModelsResponse : # type: ignore
135- try :
136- return await self .models_api .list_models ()
137- except Exception as e :
138- logger .error (f'Error listing models: { e } ' , exc_info = True )
139- raise HTTPException (
140- status_code = 500 ,
141- detail = ErrorResponse (
142- error = ErrorObject (
143- type = 'internal_server_error' ,
144- message = f'Error retrieving models: { str (e )} ' ,
145- ),
146- ).model_dump (),
147- )
79+ async def get_models () -> ModelsResponse :
80+ return await self .models_api .list_models ()
14881
14982 @self .get ('/v1/models' + '/{model_id}' , response_model = Model )
150- async def get_model (model_id : str ) -> Model : # type: ignore
151- try :
152- return await self .models_api .get_model (model_id )
153- except HTTPException :
154- raise
155- except Exception as e :
156- logger .error (f'Error fetching model info: { e } ' , exc_info = True )
157- raise HTTPException (
158- status_code = 500 ,
159- detail = ErrorResponse (
160- error = ErrorObject (
161- type = 'internal_server_error' ,
162- message = f'Error retrieving model: { str (e )} ' ,
163- ),
164- ).model_dump (),
165- )
83+ async def get_model (model_id : str ) -> Model :
84+ return await self .models_api .get_model (model_id )
0 commit comments