Skip to content

Commit 3e13909

Browse files
authored
Python: Add Pydantic request model and OpenAPI tags support to AG-UI FastAPI endpoint (#2522)
* feat(ag-ui): Add Pydantic request model and OpenAPI tags support - Add AGUIRequest Pydantic model in _types.py with field descriptions - Update add_agent_framework_fastapi_endpoint() to accept tags parameter - Use AGUIRequest model for automatic validation and OpenAPI schema generation - Export AGUIRequest and DEFAULT_TAGS in __init__.py - Update test_endpoint.py to expect 422 for invalid requests - Add tests for OpenAPI schema, default tags, custom tags, and validation Benefits: - Better API documentation with complete request schema in Swagger UI - Automatic request validation with Pydantic - Organized endpoints under 'AG-UI' tag instead of 'default' - Improved developer experience and type safety Fixes #<issue-number> * test(ag-ui): Add test for internal error handling to achieve 100% coverage - Add test_endpoint_internal_error_handling() to cover exception handling code - Mock copy.deepcopy to simulate internal error during default_state processing - Add type: ignore for FastAPI tags parameter (known pyright compatibility issue) - Achieves 100% test coverage for _endpoint.py (previously missing lines 103-105)
1 parent 3a5fe31 commit 3e13909

File tree

4 files changed

+156
-9
lines changed

4 files changed

+156
-9
lines changed

python/packages/ag-ui/agent_framework_ag_ui/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,28 @@
1616
from ._endpoint import add_agent_framework_fastapi_endpoint
1717
from ._event_converters import AGUIEventConverter
1818
from ._http_service import AGUIHttpService
19+
from ._types import AGUIRequest
1920

2021
try:
2122
__version__ = importlib.metadata.version(__name__)
2223
except importlib.metadata.PackageNotFoundError:
2324
__version__ = "0.0.0"
2425

26+
# Default OpenAPI tags for AG-UI endpoints
27+
DEFAULT_TAGS = ["AG-UI"]
28+
2529
__all__ = [
2630
"AgentFrameworkAgent",
2731
"add_agent_framework_fastapi_endpoint",
2832
"AGUIChatClient",
2933
"AGUIEventConverter",
3034
"AGUIHttpService",
35+
"AGUIRequest",
3136
"ConfirmationStrategy",
3237
"DefaultConfirmationStrategy",
3338
"TaskPlannerConfirmationStrategy",
3439
"RecipeConfirmationStrategy",
3540
"DocumentWriterConfirmationStrategy",
41+
"DEFAULT_TAGS",
3642
"__version__",
3743
]

python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
from ag_ui.encoder import EventEncoder
1010
from agent_framework import AgentProtocol
11-
from fastapi import FastAPI, Request
11+
from fastapi import FastAPI
1212
from fastapi.responses import StreamingResponse
1313

1414
from ._agent import AgentFrameworkAgent
15+
from ._types import AGUIRequest
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -24,6 +25,7 @@ def add_agent_framework_fastapi_endpoint(
2425
predict_state_config: dict[str, dict[str, str]] | None = None,
2526
allow_origins: list[str] | None = None,
2627
default_state: dict[str, Any] | None = None,
28+
tags: list[str] | None = None,
2729
) -> None:
2830
"""Add an AG-UI endpoint to a FastAPI app.
2931
@@ -36,6 +38,7 @@ def add_agent_framework_fastapi_endpoint(
3638
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
3739
allow_origins: CORS origins (not yet implemented)
3840
default_state: Optional initial state to seed when the client does not provide state keys
41+
tags: OpenAPI tags for endpoint categorization (defaults to ["AG-UI"])
3942
"""
4043
if isinstance(agent, AgentProtocol):
4144
wrapped_agent = AgentFrameworkAgent(
@@ -46,15 +49,15 @@ def add_agent_framework_fastapi_endpoint(
4649
else:
4750
wrapped_agent = agent
4851

49-
@app.post(path)
50-
async def agent_endpoint(request: Request): # type: ignore[misc]
52+
@app.post(path, tags=tags or ["AG-UI"]) # type: ignore[arg-type]
53+
async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc]
5154
"""Handle AG-UI agent requests.
5255
5356
Note: Function is accessed via FastAPI's decorator registration,
5457
despite appearing unused to static analysis.
5558
"""
5659
try:
57-
input_data = await request.json()
60+
input_data = request_body.model_dump(exclude_none=True)
5861
if default_state:
5962
state = input_data.setdefault("state", {})
6063
for key, value in default_state.items():

python/packages/ag-ui/agent_framework_ag_ui/_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from typing import Any, TypedDict
66

7+
from pydantic import BaseModel, Field
8+
79

810
class PredictStateConfig(TypedDict):
911
"""Configuration for predictive state updates."""
@@ -25,3 +27,24 @@ class AgentState(TypedDict):
2527
"""Base state for AG-UI agents."""
2628

2729
messages: list[Any] | None
30+
31+
32+
class AGUIRequest(BaseModel):
33+
"""Request model for AG-UI endpoints."""
34+
35+
messages: list[dict[str, Any]] = Field(
36+
...,
37+
description="AG-UI format messages array",
38+
)
39+
run_id: str | None = Field(
40+
None,
41+
description="Optional run identifier for tracking",
42+
)
43+
thread_id: str | None = Field(
44+
None,
45+
description="Optional thread identifier for conversation context",
46+
)
47+
state: dict[str, Any] | None = Field(
48+
None,
49+
description="Optional shared state for agentic generative UI",
50+
)

python/packages/ag-ui/tests/test_endpoint.py

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,8 @@ async def test_endpoint_error_handling():
176176
# Send invalid JSON to trigger parsing error before streaming
177177
response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore
178178

179-
# The exception handler catches it and returns JSON error
180-
assert response.status_code == 200
181-
content = json.loads(response.content)
182-
assert "error" in content
183-
assert content["error"] == "An internal error has occurred."
179+
# Pydantic validation now returns 422 for invalid request body
180+
assert response.status_code == 422
184181

185182

186183
async def test_endpoint_multiple_paths():
@@ -266,3 +263,121 @@ async def test_endpoint_complex_input():
266263
)
267264

268265
assert response.status_code == 200
266+
267+
268+
async def test_endpoint_openapi_schema():
269+
"""Test that endpoint generates proper OpenAPI schema with request model."""
270+
app = FastAPI()
271+
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
272+
273+
add_agent_framework_fastapi_endpoint(app, agent, path="/schema-test")
274+
275+
client = TestClient(app)
276+
response = client.get("/openapi.json")
277+
278+
assert response.status_code == 200
279+
openapi_spec = response.json()
280+
281+
# Verify the endpoint exists in the schema
282+
assert "/schema-test" in openapi_spec["paths"]
283+
endpoint_spec = openapi_spec["paths"]["/schema-test"]["post"]
284+
285+
# Verify request body schema is defined
286+
assert "requestBody" in endpoint_spec
287+
request_body = endpoint_spec["requestBody"]
288+
assert "content" in request_body
289+
assert "application/json" in request_body["content"]
290+
291+
# Verify schema references AGUIRequest model
292+
schema_ref = request_body["content"]["application/json"]["schema"]
293+
assert "$ref" in schema_ref
294+
assert "AGUIRequest" in schema_ref["$ref"]
295+
296+
# Verify AGUIRequest model is in components
297+
assert "components" in openapi_spec
298+
assert "schemas" in openapi_spec["components"]
299+
assert "AGUIRequest" in openapi_spec["components"]["schemas"]
300+
301+
# Verify AGUIRequest has required fields
302+
agui_request_schema = openapi_spec["components"]["schemas"]["AGUIRequest"]
303+
assert "properties" in agui_request_schema
304+
assert "messages" in agui_request_schema["properties"]
305+
assert "run_id" in agui_request_schema["properties"]
306+
assert "thread_id" in agui_request_schema["properties"]
307+
assert "state" in agui_request_schema["properties"]
308+
assert "required" in agui_request_schema
309+
assert "messages" in agui_request_schema["required"]
310+
311+
312+
async def test_endpoint_default_tags():
313+
"""Test that endpoint uses default 'AG-UI' tag."""
314+
app = FastAPI()
315+
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
316+
317+
add_agent_framework_fastapi_endpoint(app, agent, path="/default-tags")
318+
319+
client = TestClient(app)
320+
response = client.get("/openapi.json")
321+
322+
assert response.status_code == 200
323+
openapi_spec = response.json()
324+
325+
endpoint_spec = openapi_spec["paths"]["/default-tags"]["post"]
326+
assert "tags" in endpoint_spec
327+
assert endpoint_spec["tags"] == ["AG-UI"]
328+
329+
330+
async def test_endpoint_custom_tags():
331+
"""Test that endpoint accepts custom tags."""
332+
app = FastAPI()
333+
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
334+
335+
add_agent_framework_fastapi_endpoint(app, agent, path="/custom-tags", tags=["Custom", "Agent"])
336+
337+
client = TestClient(app)
338+
response = client.get("/openapi.json")
339+
340+
assert response.status_code == 200
341+
openapi_spec = response.json()
342+
343+
endpoint_spec = openapi_spec["paths"]["/custom-tags"]["post"]
344+
assert "tags" in endpoint_spec
345+
assert endpoint_spec["tags"] == ["Custom", "Agent"]
346+
347+
348+
async def test_endpoint_missing_required_field():
349+
"""Test that endpoint validates required fields with Pydantic."""
350+
app = FastAPI()
351+
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
352+
353+
add_agent_framework_fastapi_endpoint(app, agent, path="/validation")
354+
355+
client = TestClient(app)
356+
357+
# Missing required 'messages' field should trigger validation error
358+
response = client.post("/validation", json={"run_id": "test-123"})
359+
360+
assert response.status_code == 422
361+
error_detail = response.json()
362+
assert "detail" in error_detail
363+
364+
365+
async def test_endpoint_internal_error_handling():
366+
"""Test endpoint error handling when an exception occurs before streaming starts."""
367+
from unittest.mock import patch
368+
369+
app = FastAPI()
370+
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
371+
372+
# Use default_state to trigger the code path that can raise an exception
373+
add_agent_framework_fastapi_endpoint(app, agent, path="/error-test", default_state={"key": "value"})
374+
375+
client = TestClient(app)
376+
377+
# Mock copy.deepcopy to raise an exception during default_state processing
378+
with patch("agent_framework_ag_ui._endpoint.copy.deepcopy") as mock_deepcopy:
379+
mock_deepcopy.side_effect = Exception("Simulated internal error")
380+
response = client.post("/error-test", json={"messages": [{"role": "user", "content": "Hello"}]})
381+
382+
assert response.status_code == 200
383+
assert response.json() == {"error": "An internal error has occurred."}

0 commit comments

Comments
 (0)