import logging
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from openai.types import ErrorObject
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.model import Model
from openai.types.responses import Response

from pydantic_ai.fastapi.api import AgentChatCompletionsAPI, AgentModelsAPI, AgentResponsesAPI
from pydantic_ai.fastapi.data_models import (
    ChatCompletionRequest,
    ErrorResponse,
    ModelsResponse,
    ResponsesRequest,
)
from pydantic_ai.fastapi.registry import AgentRegistry

logger = logging.getLogger(__name__)


class AgentAPIRouter(APIRouter):
    """FastAPI router exposing OpenAI-compatible endpoints for pydantic_ai agents."""

    def __init__(
        self,
        agent_registry: AgentRegistry,
        disable_response_api: bool = False,
        disable_completions_api: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.registry = agent_registry
        self.responses_api = AgentResponsesAPI(self.registry)
        self.completions_api = AgentChatCompletionsAPI(self.registry)
        self.models_api = AgentModelsAPI(self.registry)
        self.enable_responses_api = not disable_response_api
        self.enable_completions_api = not disable_completions_api

        # Register the OpenAI-compatible /v1 API routes
        self._register_routes()

    def _register_routes(self) -> None:  # noqa: C901
        if self.enable_completions_api:

            @self.post(
                '/v1/chat/completions',
                response_model=ChatCompletion,
            )
            async def chat_completions(  # type: ignore
                request: ChatCompletionRequest,
            ) -> ChatCompletion | StreamingResponse:
                if not request.messages:
                    raise HTTPException(
                        status_code=400,
                        detail=ErrorResponse(
                            error=ErrorObject(
                                type='invalid_request_error',
                                message='Messages cannot be empty',
                            ),
                        ).model_dump(),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # Stream the completion as Server-Sent Events
                        return StreamingResponse(
                            self.completions_api.create_streaming_completion(request),
                            media_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                            },
                        )
                    else:
                        return await self.completions_api.create_completion(request)
                except HTTPException:
                    raise
                except Exception as e:
                    logger.error(f'Error in chat completion: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=ErrorResponse(
                            error=ErrorObject(
                                type='internal_server_error',
                                message=str(e),
                            ),
                        ).model_dump(),
                    )

        if self.enable_responses_api:

            @self.post(
                '/v1/responses',
                response_model=Response,
            )
            async def responses(  # type: ignore
                request: ResponsesRequest,
            ) -> Response:
                if not request.input:
                    raise HTTPException(
                        status_code=400,
                        detail=ErrorResponse(
                            error=ErrorObject(
                                type='invalid_request_error',
                                message='Input cannot be empty',
                            ),
                        ).model_dump(),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # TODO: add streaming support for the responses API
                        raise HTTPException(
                            status_code=501,
                            detail='Streaming is not yet supported for the responses API',
                        )
                    else:
                        return await self.responses_api.create_response(request)
                except HTTPException:
                    raise
                except Exception as e:
                    logger.error(f'Error in responses: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=ErrorResponse(
                            error=ErrorObject(
                                type='internal_server_error',
                                message=str(e),
                            ),
                        ).model_dump(),
                    )

        @self.get('/v1/models', response_model=ModelsResponse)
        async def get_models() -> ModelsResponse:  # type: ignore
            try:
                return await self.models_api.list_models()
            except Exception as e:
                logger.error(f'Error listing models: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=ErrorResponse(
                        error=ErrorObject(
                            type='internal_server_error',
                            message=f'Error retrieving models: {e}',
                        ),
                    ).model_dump(),
                )

        @self.get('/v1/models/{model_id}', response_model=Model)
        async def get_model(model_id: str) -> Model:  # type: ignore
            try:
                return await self.models_api.get_model(model_id)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f'Error fetching model info: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=ErrorResponse(
                        error=ErrorObject(
                            type='internal_server_error',
                            message=f'Error retrieving model: {e}',
                        ),
                    ).model_dump(),
                )
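
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module): one way this router might be mounted
# on a FastAPI app. The AgentRegistry construction and the agent-registration
# step are assumptions -- the registry's exact API is not shown in this diff --
# so treat the lines marked "hypothetical" as illustrative only.
#
#     from fastapi import FastAPI
#
#     from pydantic_ai.fastapi.registry import AgentRegistry
#
#     registry = AgentRegistry()  # hypothetical: no-arg construction
#     # ... register one or more pydantic_ai agents on `registry` here ...
#
#     app = FastAPI()
#     app.include_router(AgentAPIRouter(agent_registry=registry))
#
# The app would then expose the OpenAI-compatible endpoints registered above
# (POST /v1/chat/completions, POST /v1/responses, GET /v1/models,
# GET /v1/models/{model_id}), so any OpenAI-style client pointed at the
# server's /v1 base URL can talk to the registered agents.
# ---------------------------------------------------------------------------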