Skip to content

Commit 4705ba1

Browse files
feat: Agent to OpenAI endpoint router
Co-authored-by: Ion Koutsouris <[email protected]>
1 parent 063278e commit 4705ba1

File tree

19 files changed

+6724
-4213
lines changed

19 files changed

+6724
-4213
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic_ai.fastapi.agent_router import AgentAPIRouter
2+
from pydantic_ai.fastapi.registry import AgentRegistry
3+
4+
__all__ = [
5+
'AgentRegistry',
6+
'AgentAPIRouter',
7+
]
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import logging
2+
from typing import Any
3+
4+
try:
5+
from fastapi import APIRouter, HTTPException
6+
from fastapi.responses import StreamingResponse
7+
except ImportError as _import_error:
8+
raise ImportError(
9+
'Please install the `fastapi` package to use the fastapi router. '
10+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
11+
)
12+
from openai.types import ErrorObject
13+
from openai.types.chat.chat_completion import ChatCompletion
14+
from openai.types.model import Model
15+
from openai.types.responses import Response
16+
17+
from pydantic_ai.fastapi.api import AgentChatCompletionsAPI, AgentModelsAPI, AgentResponsesAPI
18+
from pydantic_ai.fastapi.data_models import (
19+
ChatCompletionRequest,
20+
ErrorResponse,
21+
ModelsResponse,
22+
ResponsesRequest,
23+
)
24+
from pydantic_ai.fastapi.registry import AgentRegistry
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
class AgentAPIRouter(APIRouter):
    """FastAPI router exposing pydantic-ai agents through OpenAI-compatible endpoints.

    Depending on the ``disable_*`` flags, registers:

    * ``POST /v1/chat/completions`` — OpenAI Chat Completions API
    * ``POST /v1/responses`` — OpenAI Responses API (streaming not yet supported)
    * ``GET /v1/models`` and ``GET /v1/models/{model_id}`` — model discovery
    """

    def __init__(
        self,
        agent_registry: AgentRegistry,
        disable_response_api: bool = False,
        disable_completions_api: bool = False,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the router and register the OpenAI ``/v1`` routes.

        Args:
            agent_registry: Registry resolving model names to pydantic-ai agents.
            disable_response_api: When True, do not expose ``/v1/responses``.
            disable_completions_api: When True, do not expose ``/v1/chat/completions``.
            *args: Positional arguments forwarded to ``APIRouter``.
            **kwargs: Keyword arguments forwarded to ``APIRouter``.
        """
        super().__init__(*args, **kwargs)
        self.registry = agent_registry
        self.responses_api = AgentResponsesAPI(self.registry)
        self.completions_api = AgentChatCompletionsAPI(self.registry)
        self.models_api = AgentModelsAPI(self.registry)
        self.enable_responses_api = not disable_response_api
        self.enable_completions_api = not disable_completions_api

        # Registers OpenAI/v1 API routes
        self._register_routes()

    @staticmethod
    def _error_detail(error_type: str, message: str) -> dict[str, Any]:
        """Build an OpenAI-style error payload usable as an ``HTTPException`` detail."""
        return ErrorResponse(
            error=ErrorObject(type=error_type, message=message),
        ).model_dump()

    def _register_routes(self) -> None:  # noqa: C901
        if self.enable_completions_api:

            @self.post(
                '/v1/chat/completions',
                response_model=ChatCompletion,
            )
            async def chat_completions(  # type: ignore
                request: ChatCompletionRequest,
            ) -> ChatCompletion | StreamingResponse:
                if not request.messages:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        return StreamingResponse(
                            self.completions_api.create_streaming_completion(request),
                            media_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                                'Content-Type': 'text/plain; charset=utf-8',
                            },
                        )
                    else:
                        return await self.completions_api.create_completion(request)
                except HTTPException:
                    # Deliberate HTTP errors (e.g. 404 from the registry lookup) must
                    # pass through instead of being re-wrapped as 500s below.
                    raise
                except Exception as e:
                    logger.error(f'Error in chat completion: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        if self.enable_responses_api:

            @self.post(
                '/v1/responses',
                response_model=Response,
            )
            async def responses(  # type: ignore
                request: ResponsesRequest,
            ) -> Response:
                if not request.input:
                    raise HTTPException(
                        status_code=400,
                        detail=self._error_detail('invalid_request_error', 'Messages cannot be empty'),
                    )
                try:
                    if getattr(request, 'stream', False):
                        # TODO: add streaming support for responses api
                        raise HTTPException(status_code=501)
                    else:
                        return await self.responses_api.create_response(request)
                except HTTPException:
                    # Bug fix: the 501 raised above was previously caught by the
                    # generic handler and re-raised as a 500.
                    raise
                except Exception as e:
                    logger.error(f'Error in responses: {e}', exc_info=True)
                    raise HTTPException(
                        status_code=500,
                        detail=self._error_detail('internal_server_error', str(e)),
                    )

        @self.get('/v1/models', response_model=ModelsResponse)
        async def get_models() -> ModelsResponse:  # type: ignore
            try:
                return await self.models_api.list_models()
            except Exception as e:
                logger.error(f'Error listing models: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving models: {str(e)}'),
                )

        @self.get('/v1/models/{model_id}', response_model=Model)
        async def get_model(model_id: str) -> Model:  # type: ignore
            try:
                return await self.models_api.get_model(model_id)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f'Error fetching model info: {e}', exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=self._error_detail('internal_server_error', f'Error retrieving model: {str(e)}'),
                )
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pydantic_ai.fastapi.api.completions import AgentChatCompletionsAPI
2+
from pydantic_ai.fastapi.api.models import AgentModelsAPI
3+
from pydantic_ai.fastapi.api.responses import AgentResponsesAPI
4+
5+
__all__ = [
6+
'AgentChatCompletionsAPI',
7+
'AgentModelsAPI',
8+
'AgentResponsesAPI',
9+
]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import json
2+
import logging
3+
import time
4+
from collections.abc import AsyncGenerator
5+
from typing import Any
6+
7+
try:
8+
from fastapi import HTTPException
9+
except ImportError as _import_error:
10+
raise ImportError(
11+
'Please install the `fastapi` package to use the fastapi router. '
12+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
13+
)
14+
from openai.types import ErrorObject
15+
from openai.types.chat.chat_completion import ChatCompletion
16+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice as Chunkhoice, ChoiceDelta
17+
from pydantic import TypeAdapter
18+
19+
from pydantic_ai import Agent, _utils
20+
from pydantic_ai.fastapi.convert import (
21+
openai_chat_completions_2pai,
22+
pai_result_to_openai_completions,
23+
)
24+
from pydantic_ai.fastapi.data_models import ChatCompletionRequest, ErrorResponse
25+
from pydantic_ai.fastapi.registry import AgentRegistry
26+
from pydantic_ai.settings import ModelSettings
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class AgentChatCompletionsAPI:
    """Chat completions API openai <-> pydantic-ai conversion."""

    def __init__(self, registry: AgentRegistry) -> None:
        """Store the registry used to resolve model names to agents."""
        self.registry = registry

    def get_agent(self, name: str) -> Agent:
        """Retrieves agent.

        Raises:
            HTTPException: 404 when no completions agent is registered under ``name``.
        """
        try:
            agent = self.registry.get_completions_agent(name)
        except KeyError:
            raise HTTPException(
                status_code=404,
                detail=ErrorResponse(
                    error=ErrorObject(
                        message=f'Model {name} is not available as chat completions API',
                        type='not_found_error',
                    ),
                ).model_dump(),
            )

        return agent

    async def create_completion(self, request: ChatCompletionRequest) -> ChatCompletion:
        """Create a non-streaming chat completion."""
        model_name = request.model
        agent = self.get_agent(model_name)

        model_settings_ta = TypeAdapter(ModelSettings)
        messages = openai_chat_completions_2pai(messages=request.messages)

        try:
            async with agent:
                result = await agent.run(
                    message_history=messages,
                    # Forward only the request fields the client actually set.
                    model_settings=model_settings_ta.validate_python(
                        {k: v for k, v in request.model_dump().items() if v is not None},
                    ),
                )

            return pai_result_to_openai_completions(
                result=result,
                model=model_name,
            )

        except Exception as e:
            logger.error(f'Error creating completion: {e}')
            raise

    async def create_streaming_completion(self, request: ChatCompletionRequest) -> AsyncGenerator[str]:
        """Create a streaming chat completion.

        Yields Server-Sent-Event frames (``data: <json>\\n\\n``) and terminates
        with the OpenAI ``data: [DONE]`` sentinel.
        """
        model_name = request.model
        agent = self.get_agent(model_name)
        messages = openai_chat_completions_2pai(messages=request.messages)

        # All chunks of a single stream must share one completion id
        # (previously the final chunk used a different id scheme).
        stream_id = f'chatcmpl-{_utils.now_utc().isoformat()}'
        role_sent = False

        async with (
            agent,
            agent.run_stream(
                message_history=messages,
            ) as result,
        ):
            async for chunk in result.stream_text(delta=True):
                delta = ChoiceDelta(
                    # The assistant role is only sent on the first chunk.
                    role='assistant' if not role_sent else None,
                    content=chunk,
                )
                role_sent = True

                stream_response = ChatCompletionChunk(
                    id=stream_id,
                    created=int(_utils.now_utc().timestamp()),
                    model=model_name,
                    object='chat.completion.chunk',
                    choices=[
                        Chunkhoice(
                            index=0,
                            delta=delta,
                        ),
                    ],
                )

                yield f'data: {stream_response.model_dump_json()}\n\n'

        # Final chunk carries the finish_reason, then the stream is terminated.
        final_chunk: dict[str, Any] = {
            'id': stream_id,
            'object': 'chat.completion.chunk',
            'model': model_name,
            'choices': [
                {
                    'index': 0,
                    'delta': {},
                    'finish_reason': 'stop',
                },
            ],
        }
        yield f'data: {json.dumps(final_chunk)}\n\n'
        # OpenAI-compatible clients stop reading on this sentinel; without it
        # many SDKs hang waiting for more events.
        yield 'data: [DONE]\n\n'
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
import time
3+
4+
try:
5+
from fastapi import HTTPException
6+
except ImportError as _import_error:
7+
raise ImportError(
8+
'Please install the `fastapi` package to use the fastapi router. '
9+
'you can use the `fastapi` optional group — `pip install "pydantic-ai-slim[fastapi]"`'
10+
)
11+
from openai.types import ErrorObject
12+
from openai.types.model import Model
13+
14+
from pydantic_ai.fastapi.data_models import (
15+
ErrorResponse,
16+
ModelsResponse,
17+
)
18+
from pydantic_ai.fastapi.registry import AgentRegistry
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class AgentModelsAPI:
    """Models API for pydantic-ai agents."""

    # Single owner string shared by both endpoints so listing and lookup agree
    # (previously 'model_owner' in list_models vs 'NDIA' in get_model).
    _OWNED_BY = 'model_owner'

    def __init__(self, registry: AgentRegistry) -> None:
        """Store the registry whose agents are exposed as models."""
        self.registry = registry

    async def list_models(self) -> ModelsResponse:
        """List available models (OpenAI-compatible endpoint)."""
        agents = self.registry.all_agents

        models = [
            Model(
                id=name,
                object='model',
                created=int(time.time()),
                owned_by=self._OWNED_BY,
            )
            for name in agents
        ]
        return ModelsResponse(data=models)

    async def get_model(self, name: str) -> Model:
        """Get information about a specific model (OpenAI-compatible endpoint).

        Raises:
            HTTPException: 404 when ``name`` is not a registered agent.
        """
        if name in self.registry.all_agents:
            return Model(id=name, object='model', created=int(time.time()), owned_by=self._OWNED_BY)
        raise HTTPException(
            status_code=404,
            detail=ErrorResponse(
                error=ErrorObject(
                    type='not_found_error',
                    message=f"Model '{name}' not found",
                ),
            ).model_dump(),
        )

0 commit comments

Comments
 (0)