Skip to content

Commit 50430ae

Browse files
feat: Agent to OpenAI endpoint router
Co-authored-by: Ion Koutsouris <[email protected]>
1 parent a58dd47 commit 50430ae

File tree

20 files changed

+6790
-4216
lines changed

20 files changed

+6790
-4216
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
2+
from pydantic_ai.fastapi.registry import AgentRegistry
3+
4+
__all__ = [
5+
'AgentRegistry',
6+
'AgentAPIRouter',
7+
]
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import logging
2+
from typing import Any
3+
4+
try:
5+
from fastapi import APIRouter, HTTPException
6+
from fastapi.responses import StreamingResponse
7+
from openai.types import ErrorObject
8+
from openai.types.chat.chat_completion import ChatCompletion
9+
from openai.types.model import Model
10+
from openai.types.responses import Response
11+
except ImportError as _import_error: # pragma: no cover
12+
raise ImportError(
13+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
14+
'you can use the `openai` and `fastapi` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
15+
) from _import_error
16+
17+
from pydantic_ai.fastapi.api import AgentChatCompletionsAPI, AgentModelsAPI, AgentResponsesAPI
18+
from pydantic_ai.fastapi.data_models import (
19+
ChatCompletionRequest,
20+
ErrorResponse,
21+
ModelsResponse,
22+
ResponsesRequest,
23+
)
24+
from pydantic_ai.fastapi.registry import AgentRegistry
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
class AgentAPIRouter(APIRouter):
    """FastAPI router exposing pydantic-ai agents behind OpenAI-compatible endpoints.

    Registers `/v1/chat/completions`, `/v1/responses`, `/v1/models` and
    `/v1/models/{model_id}` routes that translate OpenAI wire-format requests
    into pydantic-ai agent runs through the registry-backed API adapters.
    """

    def __init__(
        self,
        agent_registry: AgentRegistry,
        disable_response_api: bool = False,
        disable_completions_api: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the router and register the OpenAI-compatible routes.

        Args:
            agent_registry: Registry resolving model names to pydantic-ai agents.
            disable_response_api: If True, do not register `/v1/responses`.
            disable_completions_api: If True, do not register `/v1/chat/completions`.
            *args: Positional arguments forwarded to `APIRouter`.
            **kwargs: Keyword arguments forwarded to `APIRouter`.
        """
        super().__init__(*args, **kwargs)
        self.registry = agent_registry
        self.responses_api = AgentResponsesAPI(self.registry)
        self.completions_api = AgentChatCompletionsAPI(self.registry)
        self.models_api = AgentModelsAPI(self.registry)
        self.enable_responses_api = not disable_response_api
        self.enable_completions_api = not disable_completions_api

        # Registers OpenAI/v1 API routes.
        self._register_routes()

    @staticmethod
    def _error_detail(error_type: str, message: str) -> dict[str, Any]:
        """Build an OpenAI-style error payload for an `HTTPException` detail."""
        return ErrorResponse(
            error=ErrorObject(
                type=error_type,
                message=message,
            ),
        ).model_dump()

    def _register_routes(self) -> None:  # noqa: C901
        if self.enable_completions_api:

            @self.post(
                '/v1/chat/completions',
                response_model=ChatCompletion,
            )
            async def chat_completions(  # type: ignore
                request: ChatCompletionRequest,
            ) -> ChatCompletion | StreamingResponse:
                if not request.messages:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        return StreamingResponse(
                            self.completions_api.create_streaming_completion(request),
                            media_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                                'Content-Type': 'text/plain; charset=utf-8',
                            },
                        )
                    else:
                        return await self.completions_api.create_completion(request)
                except HTTPException:
                    # Deliberate HTTP errors (e.g. a 404 from the registry) must
                    # keep their status code, not be rewrapped as 500s.
                    raise
                except Exception as e:
                    logger.error(f'Error in chat completion: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        if self.enable_responses_api:

            @self.post(
                '/v1/responses',
                response_model=Response,
            )
            async def responses(  # type: ignore
                request: ResponsesRequest,
            ) -> Response:
                if not request.input:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # TODO: add streaming support for responses api
                        raise HTTPException(status_code=501)
                    else:
                        return await self.responses_api.create_response(request)
                except HTTPException:
                    # Without this, the 501 above would be swallowed by the
                    # generic handler below and surface as a 500.
                    raise
                except Exception as e:
                    logger.error(f'Error in responses: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        @self.get('/v1/models', response_model=ModelsResponse)
        async def get_models() -> ModelsResponse:  # type: ignore
            try:
                return await self.models_api.list_models()
            except Exception as e:
                logger.error(f'Error listing models: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving models: {str(e)}'),
                )

        @self.get('/v1/models/{model_id}', response_model=Model)
        async def get_model(model_id: str) -> Model:  # type: ignore
            try:
                return await self.models_api.get_model(model_id)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f'Error fetching model info: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving model: {str(e)}'),
                )
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pydantic_ai.fastapi.api.completions import AgentChatCompletionsAPI
2+
from pydantic_ai.fastapi.api.models import AgentModelsAPI
3+
from pydantic_ai.fastapi.api.responses import AgentResponsesAPI
4+
5+
__all__ = [
6+
'AgentChatCompletionsAPI',
7+
'AgentModelsAPI',
8+
'AgentResponsesAPI',
9+
]
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import json
2+
import logging
3+
import time
4+
from collections.abc import AsyncGenerator
5+
from typing import Any
6+
7+
try:
8+
from fastapi import HTTPException
9+
from openai.types import ErrorObject
10+
from openai.types.chat.chat_completion import ChatCompletion
11+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice as Chunkhoice, ChoiceDelta
12+
except ImportError as _import_error: # pragma: no cover
13+
raise ImportError(
14+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
15+
'you can use the `openai` and `fastapi` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
16+
) from _import_error
17+
18+
from pydantic import TypeAdapter
19+
20+
from pydantic_ai import Agent, _utils
21+
from pydantic_ai.fastapi.convert import (
22+
openai_chat_completions_2pai,
23+
pai_result_to_openai_completions,
24+
)
25+
from pydantic_ai.fastapi.data_models import ChatCompletionRequest, ErrorResponse
26+
from pydantic_ai.fastapi.registry import AgentRegistry
27+
from pydantic_ai.settings import ModelSettings
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class AgentChatCompletionsAPI:
    """Chat completions API openai <-> pydantic-ai conversion."""

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    def get_agent(self, name: str) -> Agent:
        """Retrieve the completions agent registered under `name`.

        Raises:
            HTTPException: 404 if no completions agent is registered under `name`.
        """
        try:
            agent = self.registry.get_completions_agent(name)
        except KeyError:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        message=f'Model {name} is not available as chat completions API',
                        type='not_found_error',
                    ),
                ).model_dump(),
            )

        return agent

    async def create_completion(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Create a non-streaming chat completion."""
        model_name = request.model
        agent = self.get_agent(model_name)

        model_settings_ta = TypeAdapter(ModelSettings)
        messages = openai_chat_completions_2pai(messages=request.messages)

        try:
            async with agent:
                result = await agent.run(
                    message_history=messages,
                    # Only forward fields the client actually set; unset
                    # (None) values would otherwise shadow agent defaults.
                    model_settings=model_settings_ta.validate_python(
                        {k: v for k, v in request.model_dump().items() if v is not None},
                    ),
                )

            return pai_result_to_openai_completions(
                result=result,
                model=model_name,
            )

        except Exception as e:
            logger.error(f'Error creating completion: {e}')
            raise

    async def create_streaming_completion(self, request: ChatCompletionRequest) -> AsyncGenerator[str]:
        """Yield an OpenAI-compatible SSE stream for a chat completion.

        Emits one `data: ...` event per text delta, then a final chunk with
        `finish_reason='stop'`, then the `data: [DONE]` terminator that OpenAI
        streaming clients wait for before closing the stream.
        """
        model_name = request.model
        agent = self.get_agent(model_name)
        messages = openai_chat_completions_2pai(messages=request.messages)

        # Per the OpenAI spec, every chunk of one completion shares a single
        # id/created pair, so generate them once up front.
        completion_id = f'chatcmpl-{_utils.now_utc().isoformat()}'
        created = int(_utils.now_utc().timestamp())

        role_sent = False

        async with (
            agent,
            agent.run_stream(
                message_history=messages,
            ) as result,
        ):
            async for chunk in result.stream_text(delta=True):
                # The assistant role is only present on the first delta.
                delta = ChoiceDelta(
                    role='assistant' if not role_sent else None,
                    content=chunk,
                )
                role_sent = True

                stream_response = ChatCompletionChunk(
                    id=completion_id,
                    created=created,
                    model=model_name,
                    object='chat.completion.chunk',
                    choices=[
                        Chunkhoice(
                            index=0,
                            delta=delta,
                        ),
                    ],
                )

                yield f'data: {stream_response.model_dump_json()}\n\n'

        # Final chunk signalling completion, followed by the [DONE] sentinel.
        final_chunk: dict[str, Any] = {
            'id': completion_id,
            'object': 'chat.completion.chunk',
            'created': created,
            'model': model_name,
            'choices': [
                {
                    'index': 0,
                    'delta': {},
                    'finish_reason': 'stop',
                },
            ],
        }
        yield f'data: {json.dumps(final_chunk)}\n\n'
        yield 'data: [DONE]\n\n'
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
import time
3+
4+
try:
5+
from fastapi import HTTPException
6+
from openai.types import ErrorObject
7+
from openai.types.model import Model
8+
except ImportError as _import_error: # pragma: no cover
9+
raise ImportError(
10+
'Please install the `openai` package to enable the fastapi openai compatible endpoint, '
11+
'you can use the `openai` and `fastapi` optional group — `pip install "pydantic-ai-slim[openai,fastapi]"`'
12+
) from _import_error
13+
14+
from pydantic_ai.fastapi.data_models import (
15+
ErrorResponse,
16+
ModelsResponse,
17+
)
18+
from pydantic_ai.fastapi.registry import AgentRegistry
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class AgentModelsAPI:
    """Models API for pydantic-ai agents."""

    # Single owner string so that `/v1/models` and `/v1/models/{id}` report
    # consistent metadata (previously 'model_owner' vs 'NDIA' for the same model).
    _OWNED_BY = 'model_owner'

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    def _as_model(self, name: str) -> Model:
        """Build the OpenAI `Model` record for a registered agent name."""
        return Model(
            id=name,
            object='model',
            created=int(time.time()),
            owned_by=self._OWNED_BY,
        )

    async def list_models(self) -> ModelsResponse:
        """List available models (OpenAI-compatible endpoint)."""
        return ModelsResponse(
            data=[self._as_model(name) for name in self.registry.all_agents],
        )

    async def get_model(self, name: str) -> Model:
        """Get information about a specific model (OpenAI-compatible endpoint).

        Raises:
            HTTPException: 404 if `name` is not a registered agent.
        """
        if name in self.registry.all_agents:
            return self._as_model(name)
        raise HTTPException(
            status_code=404,
            detail=ErrorResponse(
                error=ErrorObject(
                    type='not_found_error',
                    message=f"Model '{name}' not found",
                ),
            ).model_dump(),
        )

0 commit comments

Comments
 (0)