Skip to content

Commit b3e4832

Browse files
feat: Agent to OpenAI endpoint router
Co-authored-by: Ion Koutsouris <[email protected]>
1 parent 063278e commit b3e4832

File tree

19 files changed

+6762
-4213
lines changed

19 files changed

+6762
-4213
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Public API for the pydantic-ai FastAPI integration.

Exposes the OpenAI-compatible router (`AgentAPIRouter`) and the registry
(`AgentRegistry`) used to map model names to agents.
"""

from pydantic_ai.fastapi.agent_router import AgentAPIRouter
from pydantic_ai.fastapi.registry import AgentRegistry

__all__ = [
    'AgentRegistry',
    'AgentAPIRouter',
]
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import logging
2+
from typing import Any
3+
4+
try:
5+
from fastapi import APIRouter, HTTPException
6+
from fastapi.responses import StreamingResponse
7+
except ImportError as _import_error:
8+
raise ImportError(
9+
'Please install the `fastapi` package to use the fastapi router. '
10+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
11+
)
12+
try:
13+
from openai.types import ErrorObject
14+
from openai.types.chat.chat_completion import ChatCompletion
15+
from openai.types.model import Model
16+
from openai.types.responses import Response
17+
except ImportError as _import_error: # pragma: no cover
18+
raise ImportError(
19+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
20+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
21+
) from _import_error
22+
23+
from pydantic_ai.fastapi.api import AgentChatCompletionsAPI, AgentModelsAPI, AgentResponsesAPI
24+
from pydantic_ai.fastapi.data_models import (
25+
ChatCompletionRequest,
26+
ErrorResponse,
27+
ModelsResponse,
28+
ResponsesRequest,
29+
)
30+
from pydantic_ai.fastapi.registry import AgentRegistry
31+
32+
logger = logging.getLogger(__name__)
33+
34+
35+
class AgentAPIRouter(APIRouter):
    """FastAPI router exposing pydantic-ai agents via OpenAI-compatible endpoints.

    At construction time it registers the OpenAI ``/v1`` surface on itself:
    ``POST /v1/chat/completions`` (unless disabled), ``POST /v1/responses``
    (unless disabled), and always ``GET /v1/models`` / ``GET /v1/models/{model_id}``.
    """

    def __init__(
        self,
        agent_registry: AgentRegistry,
        disable_response_api: bool = False,
        disable_completions_api: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the router and register the OpenAI/v1 routes.

        Args:
            agent_registry: Registry resolving OpenAI model names to agents.
            disable_response_api: If True, skip registering ``/v1/responses``.
            disable_completions_api: If True, skip registering ``/v1/chat/completions``.
            *args: Forwarded verbatim to ``APIRouter``.
            **kwargs: Forwarded verbatim to ``APIRouter``.
        """
        super().__init__(*args, **kwargs)
        self.registry = agent_registry
        self.responses_api = AgentResponsesAPI(self.registry)
        self.completions_api = AgentChatCompletionsAPI(self.registry)
        self.models_api = AgentModelsAPI(self.registry)
        self.enable_responses_api = not disable_response_api
        self.enable_completions_api = not disable_completions_api

        # Registers OpenAI/v1 API routes
        self._register_routes()

    @staticmethod
    def _error_detail(error_type: str, message: str) -> dict[str, Any]:
        """Build an OpenAI-style error payload for ``HTTPException.detail``."""
        return ErrorResponse(
            error=ErrorObject(type=error_type, message=message),
        ).model_dump()

    def _register_routes(self) -> None:  # noqa: C901
        if self.enable_completions_api:

            @self.post(
                '/v1/chat/completions',
                response_model=ChatCompletion,
            )
            async def chat_completions(  # type: ignore
                request: ChatCompletionRequest,
            ) -> ChatCompletion | StreamingResponse:
                if not request.messages:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        return StreamingResponse(
                            self.completions_api.create_streaming_completion(request),
                            media_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                                'Content-Type': 'text/plain; charset=utf-8',
                            },
                        )
                    else:
                        return await self.completions_api.create_completion(request)
                except HTTPException:
                    # Preserve deliberate HTTP errors (e.g. the 404 raised by
                    # agent lookup) instead of masking them as 500s below.
                    raise
                except Exception as e:
                    logger.error(f'Error in chat completion: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        if self.enable_responses_api:

            @self.post(
                '/v1/responses',
                response_model=Response,
            )
            async def responses(  # type: ignore
                request: ResponsesRequest,
            ) -> Response:
                if not request.input:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # TODO: add streaming support for responses api
                        raise HTTPException(status_code=501)
                    else:
                        return await self.responses_api.create_response(request)
                except HTTPException:
                    # Let the 501 above (and any other deliberate HTTP error)
                    # propagate with its intended status code.
                    raise
                except Exception as e:
                    logger.error(f'Error in responses: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        @self.get('/v1/models', response_model=ModelsResponse)
        async def get_models() -> ModelsResponse:  # type: ignore
            try:
                return await self.models_api.list_models()
            except Exception as e:
                logger.error(f'Error listing models: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving models: {str(e)}'),
                )

        @self.get('/v1/models/{model_id}', response_model=Model)
        async def get_model(model_id: str) -> Model:  # type: ignore
            try:
                return await self.models_api.get_model(model_id)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f'Error fetching model info: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving model: {str(e)}'),
                )
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Endpoint implementations for the OpenAI-compatible FastAPI router.

Re-exports the chat-completions, models, and responses API adapters.
"""

from pydantic_ai.fastapi.api.completions import AgentChatCompletionsAPI
from pydantic_ai.fastapi.api.models import AgentModelsAPI
from pydantic_ai.fastapi.api.responses import AgentResponsesAPI

__all__ = [
    'AgentChatCompletionsAPI',
    'AgentModelsAPI',
    'AgentResponsesAPI',
]
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import json
2+
import logging
3+
import time
4+
from collections.abc import AsyncGenerator
5+
from typing import Any
6+
7+
try:
8+
from fastapi import HTTPException
9+
except ImportError as _import_error:
10+
raise ImportError(
11+
'Please install the `fastapi` package to use the fastapi router. '
12+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
13+
)
14+
try:
15+
from openai.types import ErrorObject
16+
from openai.types.chat.chat_completion import ChatCompletion
17+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice as Chunkhoice, ChoiceDelta
18+
except ImportError as _import_error: # pragma: no cover
19+
raise ImportError(
20+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
21+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
22+
) from _import_error
23+
24+
from pydantic import TypeAdapter
25+
26+
from pydantic_ai import Agent, _utils
27+
from pydantic_ai.fastapi.convert import (
28+
openai_chat_completions_2pai,
29+
pai_result_to_openai_completions,
30+
)
31+
from pydantic_ai.fastapi.data_models import ChatCompletionRequest, ErrorResponse
32+
from pydantic_ai.fastapi.registry import AgentRegistry
33+
from pydantic_ai.settings import ModelSettings
34+
35+
logger = logging.getLogger(__name__)
36+
37+
38+
class AgentChatCompletionsAPI:
    """Chat completions API openai <-> pydantic-ai conversion."""

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    def get_agent(self, name: str) -> Agent:
        """Retrieve the completions agent registered under *name*.

        Raises:
            HTTPException: 404 with an OpenAI-style error body when no agent
                is registered under *name*.
        """
        try:
            return self.registry.get_completions_agent(name)
        except KeyError:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        message=f'Model {name} is not available as chat completions API',
                        type='not_found_error',
                    ),
                ).model_dump(),
            )

    async def create_completion(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Create a non-streaming chat completion."""
        model_name = request.model
        agent = self.get_agent(model_name)

        model_settings_ta = TypeAdapter(ModelSettings)
        messages = openai_chat_completions_2pai(messages=request.messages)

        try:
            async with agent:
                result = await agent.run(
                    message_history=messages,
                    # Forward only the request fields that were actually set;
                    # assumes ModelSettings validation ignores unrelated
                    # request keys — TODO confirm against ModelSettings.
                    model_settings=model_settings_ta.validate_python(
                        {k: v for k, v in request.model_dump().items() if v is not None},
                    ),
                )

            return pai_result_to_openai_completions(
                result=result,
                model=model_name,
            )

        except Exception as e:
            logger.error(f'Error creating completion: {e}')
            raise

    async def create_streaming_completion(self, request: ChatCompletionRequest) -> AsyncGenerator[str]:
        """Create a streaming chat completion as OpenAI-style SSE ``data:`` lines."""
        model_name = request.model
        agent = self.get_agent(model_name)
        messages = openai_chat_completions_2pai(messages=request.messages)

        # One id/created pair for the whole stream: OpenAI requires every
        # chunk of a single completion to share them.
        completion_id = f'chatcmpl-{_utils.now_utc().isoformat()}'
        created = int(_utils.now_utc().timestamp())
        role_sent = False

        async with (
            agent,
            agent.run_stream(
                message_history=messages,
            ) as result,
        ):
            async for text_delta in result.stream_text(delta=True):
                delta = ChoiceDelta(
                    # The assistant role is only sent on the first chunk.
                    role=None if role_sent else 'assistant',
                    content=text_delta,
                )
                role_sent = True

                stream_response = ChatCompletionChunk(
                    id=completion_id,
                    created=created,
                    model=model_name,
                    object='chat.completion.chunk',
                    choices=[
                        Chunkhoice(
                            index=0,
                            delta=delta,
                        ),
                    ],
                )

                yield f'data: {stream_response.model_dump_json()}\n\n'

        # Terminal chunk: empty delta with finish_reason='stop', carrying the
        # same id/created as the content chunks, followed by the [DONE]
        # sentinel OpenAI clients use to detect end-of-stream.
        final_chunk = ChatCompletionChunk(
            id=completion_id,
            created=created,
            model=model_name,
            object='chat.completion.chunk',
            choices=[
                Chunkhoice(
                    index=0,
                    delta=ChoiceDelta(),
                    finish_reason='stop',
                ),
            ],
        )
        yield f'data: {final_chunk.model_dump_json()}\n\n'
        yield 'data: [DONE]\n\n'
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import logging
2+
import time
3+
4+
try:
5+
from fastapi import HTTPException
6+
except ImportError as _import_error:
7+
raise ImportError(
8+
'Please install the `fastapi` package to use the fastapi router. '
9+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
10+
)
11+
try:
12+
from openai.types import ErrorObject
13+
from openai.types.model import Model
14+
except ImportError as _import_error: # pragma: no cover
15+
raise ImportError(
16+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
17+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
18+
) from _import_error
19+
20+
from pydantic_ai.fastapi.data_models import (
21+
ErrorResponse,
22+
ModelsResponse,
23+
)
24+
from pydantic_ai.fastapi.registry import AgentRegistry
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
class AgentModelsAPI:
    """Models API for pydantic-ai agents."""

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    @staticmethod
    def _model_info(name: str) -> Model:
        """Build the OpenAI ``Model`` record for a registered agent.

        Single source of truth so ``list_models`` and ``get_model`` report
        identical metadata for the same model.
        """
        return Model(
            id=name,
            object='model',
            created=int(time.time()),
            owned_by='model_owner',
        )

    async def list_models(self) -> ModelsResponse:
        """List available models (OpenAI-compatible endpoint)."""
        return ModelsResponse(
            data=[self._model_info(name) for name in self.registry.all_agents],
        )

    async def get_model(self, name: str) -> Model:
        """Get information about a specific model (OpenAI-compatible endpoint).

        Raises:
            HTTPException: 404 with an OpenAI-style error body when *name*
                is not a registered agent.
        """
        if name not in self.registry.all_agents:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        type='not_found_error',
                        message=f"Model '{name}' not found",
                    ),
                ).model_dump(),
            )
        return self._model_info(name)

0 commit comments

Comments
 (0)