Skip to content

Commit b81d7ab

Browse files
feat: Agent to OpenAI endpoint router
Co-authored-by: Ion Koutsouris <[email protected]>
1 parent 063278e commit b81d7ab

File tree

18 files changed

+6698
-4212
lines changed

18 files changed

+6698
-4212
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
2+
from pydantic_ai.fastapi.registry import AgentRegistry
3+
4+
__all__ = [
5+
'AgentRegistry',
6+
'AgentAPIRouter',
7+
]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import logging
2+
from typing import Any
3+
4+
from fastapi import APIRouter, HTTPException
5+
from fastapi.responses import StreamingResponse
6+
from openai.types import ErrorObject
7+
from openai.types.chat.chat_completion import ChatCompletion
8+
from openai.types.model import Model
9+
from openai.types.responses import Response
10+
11+
from pydantic_ai.fastapi.api import AgentChatCompletionsAPI, AgentModelsAPI, AgentResponsesAPI
12+
from pydantic_ai.fastapi.data_models import (
13+
ChatCompletionRequest,
14+
ErrorResponse,
15+
ModelsResponse,
16+
ResponsesRequest,
17+
)
18+
from pydantic_ai.fastapi.registry import AgentRegistry
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class AgentAPIRouter(APIRouter):
    """FastAPI Router exposing pydantic-ai agents via OpenAI-compatible endpoints.

    Registers ``/v1/chat/completions`` and ``/v1/responses`` (each can be
    disabled) plus the ``/v1/models`` listing/lookup routes, all backed by the
    agents held in ``agent_registry``.
    """

    def __init__(
        self,
        agent_registry: AgentRegistry,
        disable_response_api: bool = False,
        disable_completions_api: bool = False,
        # Per PEP 484, *args/**kwargs are annotated with the per-element type,
        # not tuple[...]; these are forwarded verbatim to APIRouter.__init__.
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.registry = agent_registry
        self.responses_api = AgentResponsesAPI(self.registry)
        self.completions_api = AgentChatCompletionsAPI(self.registry)
        self.models_api = AgentModelsAPI(self.registry)
        self.enable_responses_api = not disable_response_api
        self.enable_completions_api = not disable_completions_api

        # Registers OpenAI/v1 API routes
        self._register_routes()

    @staticmethod
    def _error_detail(error_type: str, message: str) -> dict[str, Any]:
        """Build an OpenAI-style error payload for an HTTPException ``detail``."""
        return ErrorResponse(
            error=ErrorObject(type=error_type, message=message),
        ).model_dump()

    def _register_routes(self) -> None:  # noqa: C901
        if self.enable_completions_api:

            @self.post(
                '/v1/chat/completions',
                response_model=ChatCompletion,
            )
            async def chat_completions(  # type: ignore
                request: ChatCompletionRequest,
            ) -> ChatCompletion | StreamingResponse:
                if not request.messages:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail(
                            'invalid_request_error',
                            'Messages cannot be empty',
                        ),
                    )
                try:
                    if getattr(request, 'stream', False):
                        return StreamingResponse(
                            self.completions_api.create_streaming_completion(request),
                            media_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                                'Content-Type': 'text/plain; charset=utf-8',
                            },
                        )
                    else:
                        return await self.completions_api.create_completion(request)
                except HTTPException:
                    # Preserve deliberate HTTP errors (e.g. the 404 raised by the
                    # completions API for unknown models) instead of masking them
                    # as 500s in the generic handler below.
                    raise
                except Exception as e:
                    logger.error(f'Error in chat completion: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        if self.enable_responses_api:

            @self.post(
                '/v1/responses',
                response_model=Response,
            )
            async def responses(  # type: ignore
                request: ResponsesRequest,
            ) -> Response:
                if not request.input:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail(
                            'invalid_request_error',
                            'Messages cannot be empty',
                        ),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # TODO: add streaming support for responses api
                        raise HTTPException(status_code=501)
                    else:
                        return await self.responses_api.create_response(request)
                except HTTPException:
                    # Without this guard the 501 above (and any other deliberate
                    # HTTP error) would be swallowed and re-raised as a 500.
                    raise
                except Exception as e:
                    logger.error(f'Error in responses: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        @self.get('/v1/models', response_model=ModelsResponse)
        async def get_models() -> ModelsResponse:  # type: ignore
            try:
                return await self.models_api.list_models()
            except Exception as e:
                logger.error(f'Error listing models: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail(
                        'internal_server_error',
                        f'Error retrieving models: {str(e)}',
                    ),
                )

        @self.get('/v1/models/{model_id}', response_model=Model)
        async def get_model(model_id: str) -> Model:  # type: ignore
            try:
                return await self.models_api.get_model(model_id)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f'Error fetching model info: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail(
                        'internal_server_error',
                        f'Error retrieving model: {str(e)}',
                    ),
                )
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pydantic_ai.fastapi.api.completions import AgentChatCompletionsAPI
2+
from pydantic_ai.fastapi.api.models import AgentModelsAPI
3+
from pydantic_ai.fastapi.api.responses import AgentResponsesAPI
4+
5+
__all__ = [
6+
'AgentChatCompletionsAPI',
7+
'AgentModelsAPI',
8+
'AgentResponsesAPI',
9+
]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import json
2+
import logging
3+
import time
4+
from collections.abc import AsyncGenerator
5+
from typing import Any
6+
7+
from fastapi import HTTPException
8+
from openai.types import ErrorObject
9+
from openai.types.chat.chat_completion import ChatCompletion
10+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice as Chunkhoice, ChoiceDelta
11+
from pydantic import TypeAdapter
12+
13+
from pydantic_ai import Agent, _utils
14+
from pydantic_ai.fastapi.convert import (
15+
openai_chat_completions_2pai,
16+
pai_result_to_openai_completions,
17+
)
18+
from pydantic_ai.fastapi.data_models import ChatCompletionRequest, ErrorResponse
19+
from pydantic_ai.fastapi.registry import AgentRegistry
20+
from pydantic_ai.settings import ModelSettings
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class AgentChatCompletionsAPI:
    """Chat completions API openai <-> pydantic-ai conversion."""

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    def get_agent(self, name: str) -> Agent:
        """Retrieve the completions agent registered under ``name``.

        Raises:
            HTTPException: 404 when no completions agent with that name exists.
        """
        try:
            agent = self.registry.get_completions_agent(name)
        except KeyError:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        message=f'Model {name} is not available as chat completions API',
                        type='not_found_error',
                    ),
                ).model_dump(),
            )

        return agent

    async def create_completion(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Create a non-streaming chat completion."""
        model_name = request.model
        agent = self.get_agent(model_name)

        model_settings_ta = TypeAdapter(ModelSettings)
        messages = openai_chat_completions_2pai(messages=request.messages)

        try:
            async with agent:
                result = await agent.run(
                    message_history=messages,
                    # Forward only the request fields that were actually set.
                    model_settings=model_settings_ta.validate_python(
                        {k: v for k, v in request.model_dump().items() if v is not None},
                    ),
                )

            return pai_result_to_openai_completions(
                result=result,
                model=model_name,
            )

        except Exception as e:
            # exc_info=True keeps the traceback; the error is re-raised for the
            # router's handler to convert into an HTTP response.
            logger.error(f'Error creating completion: {e}', exc_info=True)
            raise

    async def create_streaming_completion(self, request: ChatCompletionRequest) -> AsyncGenerator[str]:
        """Create a streaming chat completion as server-sent events.

        Yields ``data: <chunk-json>\\n\\n`` frames followed by the terminal
        ``data: [DONE]`` sentinel required by the OpenAI streaming protocol.
        """
        model_name = request.model
        agent = self.get_agent(model_name)
        messages = openai_chat_completions_2pai(messages=request.messages)

        # All chunks of one completion must share the same id per the OpenAI
        # spec (the original code gave the final chunk a different id).
        completion_id = f'chatcmpl-{_utils.now_utc().isoformat()}'
        role_sent = False

        async with (
            agent,
            agent.run_stream(
                message_history=messages,
            ) as result,
        ):
            async for chunk in result.stream_text(delta=True):
                # The role is only sent on the first delta, as OpenAI does.
                delta = ChoiceDelta(
                    role='assistant' if not role_sent else None,
                    content=chunk,
                )
                role_sent = True

                stream_response = ChatCompletionChunk(
                    id=completion_id,
                    created=int(_utils.now_utc().timestamp()),
                    model=model_name,
                    object='chat.completion.chunk',
                    choices=[
                        Chunkhoice(
                            index=0,
                            delta=delta,
                        ),
                    ],
                )

                yield f'data: {stream_response.model_dump_json()}\n\n'

        # Terminal chunk built with the same schema as the streamed chunks;
        # the previous raw dict omitted the required `created` field.
        final_chunk = ChatCompletionChunk(
            id=completion_id,
            created=int(_utils.now_utc().timestamp()),
            model=model_name,
            object='chat.completion.chunk',
            choices=[
                Chunkhoice(
                    index=0,
                    delta=ChoiceDelta(),
                    finish_reason='stop',
                ),
            ],
        )
        yield f'data: {final_chunk.model_dump_json()}\n\n'
        # End-of-stream sentinel expected by OpenAI streaming clients.
        yield 'data: [DONE]\n\n'
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
import time
3+
4+
from fastapi import HTTPException
5+
from openai.types import ErrorObject
6+
from openai.types.model import Model
7+
8+
from pydantic_ai.fastapi.data_models import (
9+
ErrorResponse,
10+
ModelsResponse,
11+
)
12+
from pydantic_ai.fastapi.registry import AgentRegistry
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class AgentModelsAPI:
    """Models API for pydantic-ai agents."""

    # Single owner string so /v1/models and /v1/models/{id} agree: the original
    # returned 'model_owner' from the listing but 'NDIA' from the lookup.
    _OWNED_BY = 'model_owner'

    def __init__(self, registry: AgentRegistry) -> None:
        self.registry = registry

    def _as_model(self, name: str) -> Model:
        """Build an OpenAI ``Model`` record for a registered agent."""
        return Model(
            id=name,
            object='model',
            created=int(time.time()),
            owned_by=self._OWNED_BY,
        )

    async def list_models(self) -> ModelsResponse:
        """List available models (OpenAI-compatible endpoint)."""
        models = [self._as_model(name) for name in self.registry.all_agents]
        return ModelsResponse(data=models)

    async def get_model(self, name: str) -> Model:
        """Get information about a specific model (OpenAI-compatible endpoint).

        Raises:
            HTTPException: 404 when ``name`` is not a registered agent.
        """
        if name not in self.registry.all_agents:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        type='not_found_error',
                        message=f"Model '{name}' not found",
                    ),
                ).model_dump(),
            )
        return self._as_model(name)

0 commit comments

Comments
 (0)