Skip to content

Commit e7f12d7

Browse files
committed
refactor: create agent router factory
Signed-off-by: Ion Koutsouris <[email protected]>
1 parent 67d47cd commit e7f12d7

File tree

4 files changed

+59
-66
lines changed

4 files changed

+59
-66
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
1+
from pydantic_ai.fastapi.agent_router import create_agent_router
22
from pydantic_ai.fastapi.registry import AgentRegistry
33

44
__all__ = [
55
'AgentRegistry',
6-
'AgentAPIRouter',
6+
'create_agent_router',
77
]
Lines changed: 48 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Any
2-
31
try:
42
from fastapi import APIRouter, HTTPException
53
from fastapi.responses import StreamingResponse
@@ -21,64 +19,58 @@
2119
from pydantic_ai.fastapi.registry import AgentRegistry
2220

2321

24-
class AgentAPIRouter(APIRouter):
25-
"""FastAPI Router for Pydantic Agent."""
26-
27-
def __init__(
28-
self,
29-
agent_registry: AgentRegistry,
30-
disable_response_api: bool = False,
31-
disable_completions_api: bool = False,
32-
*args: Any,
33-
**kwargs: Any,
34-
):
35-
super().__init__(*args, **kwargs)
36-
self.registry = agent_registry
37-
self.responses_api = AgentResponsesAPI(self.registry)
38-
self.completions_api = AgentChatCompletionsAPI(self.registry)
39-
self.models_api = AgentModelsAPI(self.registry)
40-
self.enable_responses_api = not disable_response_api
41-
self.enable_completions_api = not disable_completions_api
22+
def create_agent_router(
23+
agent_registry: AgentRegistry,
24+
disable_responses_api: bool = False,
25+
disable_completions_api: bool = False,
26+
api_router: APIRouter | None = None,
27+
) -> APIRouter:
28+
"""FastAPI Router factory for Pydantic Agent exposure as OpenAI endpoint."""
29+
if api_router is None:
30+
api_router = APIRouter()
31+
responses_api = AgentResponsesAPI(agent_registry)
32+
completions_api = AgentChatCompletionsAPI(agent_registry)
33+
models_api = AgentModelsAPI(agent_registry)
34+
enable_responses_api = not disable_responses_api
35+
enable_completions_api = not disable_completions_api
4236

43-
# Registers OpenAI/v1 API routes
44-
self._register_routes()
37+
if enable_completions_api:
4538

46-
def _register_routes(self) -> None:
47-
if self.enable_completions_api:
39+
@api_router.post('/v1/chat/completions', response_model=ChatCompletion)
40+
async def chat_completions( # type: ignore[reportUnusedFunction]
41+
request: ChatCompletionRequest,
42+
) -> ChatCompletion | StreamingResponse:
43+
if getattr(request, 'stream', False):
44+
return StreamingResponse(
45+
completions_api.create_streaming_completion(request),
46+
media_type='text/event-stream',
47+
headers={
48+
'Cache-Control': 'no-cache',
49+
'Connection': 'keep-alive',
50+
'Content-Type': 'text/plain; charset=utf-8',
51+
},
52+
)
53+
else:
54+
return await completions_api.create_completion(request)
4855

49-
@self.post('/v1/chat/completions', response_model=ChatCompletion)
50-
async def chat_completions(
51-
request: ChatCompletionRequest,
52-
) -> ChatCompletion | StreamingResponse:
53-
if getattr(request, 'stream', False):
54-
return StreamingResponse(
55-
self.completions_api.create_streaming_completion(request),
56-
media_type='text/event-stream',
57-
headers={
58-
'Cache-Control': 'no-cache',
59-
'Connection': 'keep-alive',
60-
'Content-Type': 'text/plain; charset=utf-8',
61-
},
62-
)
63-
else:
64-
return await self.completions_api.create_completion(request)
56+
if enable_responses_api:
6557

66-
if self.enable_responses_api:
58+
@api_router.post('/v1/responses', response_model=OpenAIResponse)
59+
async def responses( # type: ignore[reportUnusedFunction]
60+
request: ResponsesRequest,
61+
) -> OpenAIResponse:
62+
if getattr(request, 'stream', False):
63+
# TODO: add streaming support for responses api
64+
raise HTTPException(status_code=501)
65+
else:
66+
return await responses_api.create_response(request)
6767

68-
@self.post('/v1/responses', response_model=OpenAIResponse)
69-
async def responses(
70-
request: ResponsesRequest,
71-
) -> OpenAIResponse:
72-
if getattr(request, 'stream', False):
73-
# TODO: add streaming support for responses api
74-
raise HTTPException(status_code=501)
75-
else:
76-
return await self.responses_api.create_response(request)
68+
@api_router.get('/v1/models', response_model=ModelsResponse)
69+
async def get_models() -> ModelsResponse: # type: ignore[reportUnusedFunction]
70+
return await models_api.list_models()
7771

78-
@self.get('/v1/models', response_model=ModelsResponse)
79-
async def get_models() -> ModelsResponse:
80-
return await self.models_api.list_models()
72+
@api_router.get('/v1/models' + '/{model_id}', response_model=Model)
73+
async def get_model(model_id: str) -> Model: # type: ignore[reportUnusedFunction]
74+
return await models_api.get_model(model_id)
8175

82-
@self.get('/v1/models' + '/{model_id}', response_model=Model)
83-
async def get_model(model_id: str) -> Model:
84-
return await self.models_api.get_model(model_id)
76+
return api_router

tests/agent_to_fastapi/integration_tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77
import pytest_asyncio
8+
from fastapi import APIRouter
89

910
from ...conftest import try_import
1011

@@ -14,7 +15,7 @@
1415
from openai import AsyncOpenAI, DefaultAioHttpClient
1516

1617
from pydantic_ai import Agent
17-
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
18+
from pydantic_ai.fastapi.agent_router import create_agent_router
1819
from pydantic_ai.fastapi.registry import AgentRegistry
1920
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel
2021
from pydantic_ai.providers.openai import OpenAIProvider
@@ -45,7 +46,7 @@ def app() -> FastAPI:
4546
registry.chat_completions_agents['test-model'] = cast(Any, object())
4647
registry.responses_agents['test-model'] = cast(Any, object())
4748

48-
router = AgentAPIRouter(agent_registry=registry)
49+
router = create_agent_router(agent_registry=registry)
4950

5051
for route in list(getattr(router, 'routes', [])):
5152
if getattr(route, 'path', None) == '/v1/responses':
@@ -61,7 +62,7 @@ def app() -> FastAPI:
6162

6263

6364
@pytest.fixture
64-
def agent_router(app: FastAPI) -> AgentAPIRouter:
65+
def agent_router(app: FastAPI) -> APIRouter:
6566
"""Return the AgentAPIRouter instance attached to the app by the `app` fixture.
6667
Tests can use this to stub `completions_api` and `responses_api` coroutine methods.
6768
"""

tests/agent_to_fastapi/integration_tests/test_api_integration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastapi.routing import APIRoute
1212
from httpx import ASGITransport, AsyncClient
1313

14-
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
14+
from pydantic_ai.fastapi.agent_router import create_agent_router
1515
from pydantic_ai.fastapi.registry import AgentRegistry
1616

1717

@@ -28,7 +28,7 @@ async def test_models_list_and_get(
2828
"""Verify model listing and retrieval endpoints behave as expected when real Agents are registered."""
2929
registry = registry_with_openai_clients
3030

31-
router = AgentAPIRouter(agent_registry=registry)
31+
router = create_agent_router(agent_registry=registry)
3232

3333
app = FastAPI()
3434
app.include_router(router)
@@ -70,7 +70,7 @@ async def test_routers_disabled(
7070
"""Verify whether disabling apis actually effectively not adds APIRoutes to the app."""
7171
registry = registry_with_openai_clients
7272

73-
router = AgentAPIRouter(agent_registry=registry, disable_completions_api=True, disable_response_api=True)
73+
router = create_agent_router(agent_registry=registry, disable_completions_api=True, disable_responses_api=True)
7474

7575
app = FastAPI()
7676
app.include_router(router)
@@ -128,7 +128,7 @@ async def test_chat_completions_e2e_with_mocked_openai(
128128
fake_openai_base = 'https://api.openai.test/v1'
129129
registry = registry_with_openai_clients
130130

131-
router = AgentAPIRouter(agent_registry=registry)
131+
router = create_agent_router(agent_registry=registry)
132132

133133
app = FastAPI()
134134
app.include_router(router)
@@ -199,7 +199,7 @@ async def test_responses_e2e_with_mocked_openai(
199199
fake_openai_base = 'https://api.openai.test/v1'
200200
registry = registry_with_openai_clients
201201

202-
router = AgentAPIRouter(agent_registry=registry)
202+
router = create_agent_router(agent_registry=registry)
203203

204204
# Disable response_model on the /v1/responses route so tests can return simple dicts if needed
205205
for route in list(getattr(router, 'routes', [])):

0 commit comments

Comments
 (0)