|
1 | 1 | import logging |
| 2 | +import sys |
2 | 3 |
|
3 | | -from collections.abc import AsyncIterator |
4 | | -from contextlib import asynccontextmanager |
5 | 4 | from typing import Any |
6 | 5 |
|
7 | 6 | from fastapi import FastAPI |
8 | 7 |
|
| 8 | + |
| 9 | +if sys.version_info < (3, 12): # pragma: no cover |
| 10 | + from typing_extensions import override |
| 11 | +else: # pragma: no cover |
| 12 | + from typing import override |
| 13 | + |
9 | 14 | from a2a.server.apps.jsonrpc.jsonrpc_app import ( |
10 | 15 | CallContextBuilder, |
11 | 16 | JSONRPCApplication, |
|
22 | 27 | logger = logging.getLogger(__name__) |
23 | 28 |
|
24 | 29 |
|
| 30 | +class A2AFastAPI(FastAPI): |
| 31 | + """A FastAPI application that adds A2A-specific OpenAPI components.""" |
| 32 | + |
| 33 | + a2a_components_added: bool = False |
| 34 | + |
| 35 | + @override |
| 36 | + def openapi(self) -> dict[str, Any]: |
| 37 | + openapi_schema = super().openapi() |
| 38 | + if not self.a2a_components_added: |
| 39 | + a2a_request_schema = A2ARequest.model_json_schema( |
| 40 | + ref_template='#/components/schemas/{model}' |
| 41 | + ) |
| 42 | + defs = a2a_request_schema.pop('$defs', {}) |
| 43 | + component_schemas = openapi_schema.setdefault( |
| 44 | + 'components', {} |
| 45 | + ).setdefault('schemas', {}) |
| 46 | + component_schemas.update(defs) |
| 47 | + component_schemas['A2ARequest'] = a2a_request_schema |
| 48 | + self.a2a_components_added = True |
| 49 | + return openapi_schema |
| 50 | + |
| 51 | + |
25 | 52 | class A2AFastAPIApplication(JSONRPCApplication): |
26 | 53 | """A FastAPI application implementing the A2A protocol server endpoints. |
27 | 54 |
|
@@ -112,23 +139,7 @@ def build( |
112 | 139 | Returns: |
113 | 140 | A configured FastAPI application instance. |
114 | 141 | """ |
115 | | - |
116 | | - @asynccontextmanager |
117 | | - async def lifespan(app: FastAPI) -> AsyncIterator[None]: |
118 | | - a2a_request_schema = A2ARequest.model_json_schema( |
119 | | - ref_template='#/components/schemas/{model}' |
120 | | - ) |
121 | | - defs = a2a_request_schema.pop('$defs', {}) |
122 | | - openapi_schema = app.openapi() |
123 | | - component_schemas = openapi_schema.setdefault( |
124 | | - 'components', {} |
125 | | - ).setdefault('schemas', {}) |
126 | | - component_schemas.update(defs) |
127 | | - component_schemas['A2ARequest'] = a2a_request_schema |
128 | | - |
129 | | - yield |
130 | | - |
131 | | - app = FastAPI(lifespan=lifespan, **kwargs) |
| 142 | + app = A2AFastAPI(**kwargs) |
132 | 143 |
|
133 | 144 | self.add_routes_to_app( |
134 | 145 | app, agent_card_url, rpc_url, extended_agent_card_url |
|
0 commit comments