|
| 1 | +from __future__ import annotations as _annotations |
| 2 | + |
| 3 | +from collections.abc import AsyncIterator, Sequence |
| 4 | +from contextlib import asynccontextmanager |
| 5 | +from typing import Any |
| 6 | + |
| 7 | +from starlette.applications import Starlette |
| 8 | +from starlette.middleware import Middleware |
| 9 | +from starlette.requests import Request |
| 10 | +from starlette.responses import Response |
| 11 | +from starlette.routing import Route |
| 12 | +from starlette.types import ExceptionHandler, Lifespan, Receive, Scope, Send |
| 13 | + |
| 14 | +from .broker import Broker |
| 15 | +from .schema import AgentCard, Provider, Skill, a2a_request_ta, a2a_response_ta, agent_card_ta |
| 16 | +from .storage import Storage |
| 17 | +from .task_manager import TaskManager |
| 18 | + |
| 19 | + |
| 20 | +class FastA2A(Starlette): |
| 21 | + """The main class for the FastA2A library.""" |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + *, |
| 26 | + storage: Storage, |
| 27 | + broker: Broker, |
| 28 | + # Agent card |
| 29 | + name: str | None = None, |
| 30 | + url: str = 'http://localhost:8000', |
| 31 | + version: str = '1.0.0', |
| 32 | + description: str | None = None, |
| 33 | + provider: Provider | None = None, |
| 34 | + skills: list[Skill] | None = None, |
| 35 | + # Starlette |
| 36 | + debug: bool = False, |
| 37 | + routes: Sequence[Route] | None = None, |
| 38 | + middleware: Sequence[Middleware] | None = None, |
| 39 | + exception_handlers: dict[Any, ExceptionHandler] | None = None, |
| 40 | + lifespan: Lifespan[FastA2A] | None = None, |
| 41 | + ): |
| 42 | + if lifespan is None: |
| 43 | + lifespan = _default_lifespan |
| 44 | + |
| 45 | + super().__init__( |
| 46 | + debug=debug, |
| 47 | + routes=routes, |
| 48 | + middleware=middleware, |
| 49 | + exception_handlers=exception_handlers, |
| 50 | + lifespan=lifespan, |
| 51 | + ) |
| 52 | + |
| 53 | + self.name = name or 'Agent' |
| 54 | + self.url = url |
| 55 | + self.version = version |
| 56 | + self.description = description |
| 57 | + self.provider = provider |
| 58 | + self.skills = skills or [] |
| 59 | + # NOTE: For now, I don't think there's any reason to support any other input/output modes. |
| 60 | + self.default_input_modes = ['application/json'] |
| 61 | + self.default_output_modes = ['application/json'] |
| 62 | + |
| 63 | + self.task_manager = TaskManager(broker=broker, storage=storage) |
| 64 | + |
| 65 | + # Setup |
| 66 | + self._agent_card_json_schema: bytes | None = None |
| 67 | + self.router.add_route('/.well-known/agent.json', self._agent_card_endpoint, methods=['HEAD', 'GET', 'OPTIONS']) |
| 68 | + self.router.add_route('/', self._agent_run_endpoint, methods=['POST']) |
| 69 | + |
| 70 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 71 | + if scope['type'] == 'http' and not self.task_manager.is_running: |
| 72 | + raise RuntimeError('TaskManager was not properly initialized.') |
| 73 | + await super().__call__(scope, receive, send) |
| 74 | + |
| 75 | + async def _agent_card_endpoint(self, request: Request) -> Response: |
| 76 | + if self._agent_card_json_schema is None: |
| 77 | + agent_card = AgentCard( |
| 78 | + name=self.name, |
| 79 | + url=self.url, |
| 80 | + version=self.version, |
| 81 | + skills=self.skills, |
| 82 | + default_input_modes=self.default_input_modes, |
| 83 | + default_output_modes=self.default_output_modes, |
| 84 | + ) |
| 85 | + if self.description is not None: |
| 86 | + agent_card['description'] = self.description |
| 87 | + if self.provider is not None: |
| 88 | + agent_card['provider'] = self.provider |
| 89 | + self._agent_card_json_schema = agent_card_ta.dump_json(agent_card) |
| 90 | + return Response(content=self._agent_card_json_schema, media_type='application/json') |
| 91 | + |
| 92 | + async def _agent_run_endpoint(self, request: Request) -> Response: |
| 93 | + """This is the main endpoint for the A2A server. |
| 94 | +
|
| 95 | + Although the specification allows freedom of choice and implementation, I'm pretty sure about some decisions. |
| 96 | +
|
| 97 | + 1. The server will always either send a "submitted" or a "failed" on `tasks/send`. |
| 98 | + Never a "completed" on the first message. |
| 99 | + 2. There are three possible ends for the task: |
| 100 | + 2.1. The task was "completed" successfully. |
| 101 | + 2.2. The task was "canceled". |
| 102 | + 2.3. The task "failed". |
| 103 | + 3. The server will send a "working" on the first chunk on `tasks/pushNotification/get`. |
| 104 | + """ |
| 105 | + data = await request.body() |
| 106 | + a2a_request = a2a_request_ta.validate_json(data) |
| 107 | + |
| 108 | + if a2a_request['method'] == 'tasks/send': |
| 109 | + jsonrpc_response = await self.task_manager.send_task(a2a_request) |
| 110 | + elif a2a_request['method'] == 'tasks/get': |
| 111 | + jsonrpc_response = await self.task_manager.get_task(a2a_request) |
| 112 | + elif a2a_request['method'] == 'tasks/cancel': |
| 113 | + jsonrpc_response = await self.task_manager.cancel_task(a2a_request) |
| 114 | + else: |
| 115 | + raise NotImplementedError(f'Method {a2a_request["method"]} not implemented.') |
| 116 | + return Response( |
| 117 | + content=a2a_response_ta.dump_json(jsonrpc_response, by_alias=True), media_type='application/json' |
| 118 | + ) |
| 119 | + |
| 120 | + |
| 121 | +@asynccontextmanager |
| 122 | +async def _default_lifespan(app: FastA2A) -> AsyncIterator[None]: |
| 123 | + async with app.task_manager: |
| 124 | + yield |
0 commit comments