diff --git a/apps/agentstack-sdk-py/pyproject.toml b/apps/agentstack-sdk-py/pyproject.toml index fc75136df..fa99bc349 100644 --- a/apps/agentstack-sdk-py/pyproject.toml +++ b/apps/agentstack-sdk-py/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" authors = [{ name = "IBM Corp." }] requires-python = ">=3.11,<3.14" dependencies = [ - "a2a-sdk==0.3.7", + "a2a-sdk==0.3.9", "objprint>=0.3.0", "uvicorn>=0.35.0", "asyncclick>=8.1.8", diff --git a/apps/agentstack-sdk-py/uv.lock b/apps/agentstack-sdk-py/uv.lock index 0a161a43b..bfcafebfa 100644 --- a/apps/agentstack-sdk-py/uv.lock +++ b/apps/agentstack-sdk-py/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "a2a-sdk" -version = "0.3.7" +version = "0.3.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -17,9 +17,9 @@ dependencies = [ { name = "protobuf" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/ad/b6ecb58f44459a24f1c260e91304e1ddbb7a8e213f1f82cc4c074f66e9bb/a2a_sdk-0.3.7.tar.gz", hash = "sha256:795aa2bd2cfb3c9e8654a1352bf5f75d6cf1205b262b1bf8f4003b5308267ea2", size = 223426, upload-time = "2025-09-23T16:27:29.585Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/0b/80671e784f61b55ac4c340d125d121ba91eba58ad7ba0f03b53b3831cd32/a2a_sdk-0.3.9.tar.gz", hash = "sha256:1dff7b5b1cab0b221519d0faed50176e200a1a87a8de8b64308d876505cc7c77", size = 224528, upload-time = "2025-10-15T17:35:28.299Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/27/9cf8c6de4ae71e9c98ec96b3304449d5d0cd36ec3b95e66b6e7f58a9e571/a2a_sdk-0.3.7-py3-none-any.whl", hash = "sha256:0813b8fd7add427b2b56895cf28cae705303cf6d671b305c0aac69987816e03e", size = 137957, upload-time = "2025-09-23T16:27:27.546Z" }, + { url = "https://files.pythonhosted.org/packages/34/ee/53b2da6d2768b136f996b8c6ab00ebcc44852f9a33816a64deaca6b279fe/a2a_sdk-0.3.9-py3-none-any.whl", hash = "sha256:7ed03a915bae98def46ea0313786da0a7a488346c3dc8af88407bb0b2a763926", size = 139027, upload-time = "2025-10-15T17:35:26.628Z" }, ] [[package]] @@ -238,7 +238,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", specifier = "==0.3.7" }, + { name = "a2a-sdk", specifier = "==0.3.9" }, { name = "anyio", specifier = ">=4.9.0" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "fastapi", specifier = ">=0.116.1" }, diff --git a/apps/beeai-cli/src/beeai_cli/api.py b/apps/beeai-cli/src/beeai_cli/api.py index 27cfb8779..6ab5a259e 100644 --- a/apps/beeai-cli/src/beeai_cli/api.py +++ b/apps/beeai-cli/src/beeai_cli/api.py @@ -112,7 +112,7 @@ async def a2a_client(agent_card: AgentCard, use_auth: bool = True) -> AsyncItera follow_redirects=True, timeout=timedelta(hours=1).total_seconds(), ) as httpx_client: - yield ClientFactory(ClientConfig(httpx_client=httpx_client)).create(card=agent_card) + yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(card=agent_card) @asynccontextmanager diff --git a/apps/beeai-cli/uv.lock b/apps/beeai-cli/uv.lock index 0553c63f5..28ad97de3 100644 --- a/apps/beeai-cli/uv.lock +++ b/apps/beeai-cli/uv.lock @@ -4,7 +4,7 @@ requires-python = "==3.13.*" [[package]] name = "a2a-sdk" -version = "0.3.7" +version = "0.3.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -13,9 +13,9 @@ dependencies = [ { name = "protobuf" }, { name = "pydantic" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/8d/ad/b6ecb58f44459a24f1c260e91304e1ddbb7a8e213f1f82cc4c074f66e9bb/a2a_sdk-0.3.7.tar.gz", hash = "sha256:795aa2bd2cfb3c9e8654a1352bf5f75d6cf1205b262b1bf8f4003b5308267ea2", size = 223426, upload-time = "2025-09-23T16:27:29.585Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/0b/80671e784f61b55ac4c340d125d121ba91eba58ad7ba0f03b53b3831cd32/a2a_sdk-0.3.9.tar.gz", hash = "sha256:1dff7b5b1cab0b221519d0faed50176e200a1a87a8de8b64308d876505cc7c77", size = 224528, upload-time = "2025-10-15T17:35:28.299Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/27/9cf8c6de4ae71e9c98ec96b3304449d5d0cd36ec3b95e66b6e7f58a9e571/a2a_sdk-0.3.7-py3-none-any.whl", hash = "sha256:0813b8fd7add427b2b56895cf28cae705303cf6d671b305c0aac69987816e03e", size = 137957, upload-time = "2025-09-23T16:27:27.546Z" }, + { url = "https://files.pythonhosted.org/packages/34/ee/53b2da6d2768b136f996b8c6ab00ebcc44852f9a33816a64deaca6b279fe/a2a_sdk-0.3.9-py3-none-any.whl", hash = "sha256:7ed03a915bae98def46ea0313786da0a7a488346c3dc8af88407bb0b2a763926", size = 139027, upload-time = "2025-10-15T17:35:26.628Z" }, ] [[package]] @@ -168,7 +168,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", specifier = "==0.3.7" }, + { name = "a2a-sdk", specifier = "==0.3.9" }, { name = "anyio", specifier = ">=4.9.0" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "fastapi", specifier = ">=0.116.1" }, diff --git a/apps/beeai-server/pyproject.toml b/apps/beeai-server/pyproject.toml index bb88e7f09..f2829ee70 100644 --- a/apps/beeai-server/pyproject.toml +++ b/apps/beeai-server/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" authors = [{ name = "IBM Corp." }] requires-python = "==3.12.*" dependencies = [ - "a2a-sdk~=0.3.5", + "a2a-sdk~=0.3.9", "aiohttp>=3.11.11", "anyio>=4.9.0", "asgiref>=3.8.1", @@ -42,6 +42,8 @@ dependencies = [ "openai>=1.97.0", "authlib>=1.6.4", "async-lru>=2.0.5", + "starlette>=0.48.0", + "sse-starlette>=3.0.2", "exceptiongroup>=1.3.0", ] diff --git a/apps/beeai-server/src/beeai_server/api/routes/a2a.py b/apps/beeai-server/src/beeai_server/api/routes/a2a.py index 4dd3b3221..1bb440928 100644 --- a/apps/beeai-server/src/beeai_server/api/routes/a2a.py +++ b/apps/beeai-server/src/beeai_server/api/routes/a2a.py @@ -2,14 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Annotated -from urllib.parse import urljoin, urlparse +from urllib.parse import urljoin from uuid import UUID import fastapi import fastapi.responses -from a2a.types import AgentCard, TransportProtocol +from a2a.server.apps import A2AFastAPIApplication +from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.types import AgentCard, AgentInterface, TransportProtocol from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH -from fastapi import Depends, Request +from fastapi import Depends, HTTPException, Request from beeai_server.api.dependencies import ( A2AProxyServiceDependency, @@ -19,45 +21,21 @@ from beeai_server.domain.models.permissions import AuthorizedUser from beeai_server.service_layer.services.a2a import A2AServerResponse -_SUPPORTED_TRANSPORTS = [TransportProtocol.jsonrpc, TransportProtocol.http_json] - - router = fastapi.APIRouter() -def _create_proxy_url(url: str, *, proxy_base: str) -> str: - return urljoin(proxy_base, urlparse(url).path.lstrip("/")) - - def create_proxy_agent_card(agent_card: AgentCard, *, provider_id: UUID, request: Request) -> AgentCard: - proxy_base = str(request.url_for(proxy_request.__name__, 
provider_id=provider_id, path="")) - proxy_interfaces = ( - [ - interface.model_copy(update={"url": _create_proxy_url(interface.url, proxy_base=proxy_base)}) - for interface in agent_card.additional_interfaces - if interface.transport in _SUPPORTED_TRANSPORTS - ] - if agent_card.additional_interfaces is not None - else None + proxy_base = str(request.url_for(a2a_proxy_jsonrpc_transport.__name__, provider_id=provider_id)) + return agent_card.model_copy( + update={ + "preferred_transport": TransportProtocol.jsonrpc, + "url": proxy_base, + "additional_interfaces": [ + AgentInterface(transport=TransportProtocol.http_json, url=urljoin(proxy_base, "http")), + AgentInterface(transport=TransportProtocol.jsonrpc, url=proxy_base), + ], + } ) - if agent_card.preferred_transport in _SUPPORTED_TRANSPORTS: - return agent_card.model_copy( - update={ - "url": _create_proxy_url(agent_card.url, proxy_base=proxy_base), - "additional_interfaces": proxy_interfaces, - } - ) - elif proxy_interfaces: - interface = proxy_interfaces[0] - return agent_card.model_copy( - update={ - "url": interface.url, - "preferred_transport": interface.transport, - "additional_interfaces": proxy_interfaces, - } - ) - else: - raise RuntimeError("Provider doesn't have any transport supported by the proxy.") def _to_fastapi(response: A2AServerResponse): @@ -79,15 +57,45 @@ async def get_agent_card( return create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) -@router.api_route("/{provider_id}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) -@router.api_route("/{provider_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) -async def proxy_request( +@router.post("/{provider_id}") +@router.post("/{provider_id}/") +async def a2a_proxy_jsonrpc_transport( + provider_id: UUID, + request: fastapi.requests.Request, + a2a_proxy: A2AProxyServiceDependency, + provider_service: ProviderServiceDependency, + user: Annotated[AuthorizedUser, Depends(RequiresPermissions(a2a_proxy={"*"}))], +): + provider = await provider_service.get_provider(provider_id=provider_id) + agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) + + handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user) + app = A2AFastAPIApplication(agent_card=agent_card, http_handler=handler) + return await app._handle_requests(request) + + +@router.api_route("/{provider_id}/http", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) +@router.api_route( + "/{provider_id}/http/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"] +) +async def a2a_proxy_http_transport( provider_id: UUID, request: fastapi.requests.Request, a2a_proxy: A2AProxyServiceDependency, - _: Annotated[AuthorizedUser, Depends(RequiresPermissions(a2a_proxy={"*"}))], + provider_service: ProviderServiceDependency, + user: Annotated[AuthorizedUser, Depends(RequiresPermissions(a2a_proxy={"*"}))], path: str = "", ): - client = await a2a_proxy.get_proxy_client(provider_id=provider_id) - response = await client.send_request(method=request.method, url=f"/{path}", content=request.stream()) - return _to_fastapi(response) + provider = await provider_service.get_provider(provider_id=provider_id) + agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) + + handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user) + adapter = RESTAdapter(agent_card=agent_card, 
http_handler=handler) + + if not (handler := adapter.routes().get((f"/{path.rstrip('/')}", request.method), None)): + raise HTTPException(status_code=404, detail="Not found") + + return await handler(request) + + +# TODO: extra a2a routes are not supported diff --git a/apps/beeai-server/src/beeai_server/configuration.py b/apps/beeai-server/src/beeai_server/configuration.py index 574555f52..1a3e886d3 100644 --- a/apps/beeai-server/src/beeai_server/configuration.py +++ b/apps/beeai-server/src/beeai_server/configuration.py @@ -230,6 +230,11 @@ class ContextConfiguration(BaseModel): resource_expire_after_days: int = 7 # Expires files and vector_stores attached to a context +class A2AProxyConfiguration(BaseModel): + # Expires a2a_request_tasks and a2a_request_contexts (WARNING: has security implications!) + requests_expire_after_days: int = 14 + + class FeatureConfiguration(BaseModel): generate_conversation_title: bool = True provider_builds: bool = True @@ -265,6 +270,7 @@ class Configuration(BaseSettings): vector_stores: VectorStoresConfiguration = Field(default_factory=VectorStoresConfiguration) text_extraction: DoclingExtractionConfiguration = Field(default_factory=DoclingExtractionConfiguration) context: ContextConfiguration = Field(default_factory=ContextConfiguration) + a2a_proxy: A2AProxyConfiguration = Field(default_factory=A2AProxyConfiguration) k8s_namespace: str | None = None k8s_kubeconfig: Path | None = None diff --git a/apps/beeai-server/src/beeai_server/domain/models/a2a_request.py b/apps/beeai-server/src/beeai_server/domain/models/a2a_request.py new file mode 100644 index 000000000..87df06bd6 --- /dev/null +++ b/apps/beeai-server/src/beeai_server/domain/models/a2a_request.py @@ -0,0 +1,14 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from uuid import UUID + +from pydantic import AwareDatetime, BaseModel + + +class A2ARequestTask(BaseModel): + task_id: str + created_by: UUID + provider_id: UUID + created_at: AwareDatetime + last_accessed_at: AwareDatetime diff --git a/apps/beeai-server/src/beeai_server/domain/repositories/a2a_request.py b/apps/beeai-server/src/beeai_server/domain/repositories/a2a_request.py new file mode 100644 index 000000000..c547661af --- /dev/null +++ b/apps/beeai-server/src/beeai_server/domain/repositories/a2a_request.py @@ -0,0 +1,24 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 +from datetime import timedelta +from typing import Protocol, runtime_checkable +from uuid import UUID + +from beeai_server.domain.models.a2a_request import A2ARequestTask + + +@runtime_checkable +class IA2ARequestRepository(Protocol): + async def track_request_ids_ownership( + self, + user_id: UUID, + provider_id: UUID, + task_id: str | None = None, + context_id: str | None = None, + allow_task_creation: bool = False, + ) -> None: ... + + async def get_task(self, *, task_id: str, user_id: UUID) -> A2ARequestTask: ... + + async def delete_tasks(self, *, older_than: timedelta) -> int: ... + async def delete_contexts(self, *, older_than: timedelta) -> int: ... 
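For reference, a minimal sketch of how the new IA2ARequestRepository.track_request_ids_ownership contract is consumed in a single transaction; uow_factory and the argument values here are illustrative placeholders, and the real call site is ProxyRequestHandler._check_and_record_request in the service layer further below.

async def record_incoming_request(uow_factory, user_id, provider_id, task_id, context_id):
    # One transaction: verify ownership of task_id/context_id and refresh last_accessed_at.
    async with uow_factory() as uow:
        # Raises EntityNotFoundError when task_id is not tracked for this user (and
        # allow_task_creation is False), and ForbiddenUpdateError when context_id is
        # already owned by another user.
        await uow.a2a_requests.track_request_ids_ownership(
            user_id=user_id,
            provider_id=provider_id,
            task_id=task_id,            # client-supplied task must already be tracked
            context_id=context_id,      # created on first use, then ownership is enforced
            allow_task_creation=False,  # set True only when recording server-created tasks
        )
        await uow.commit()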
diff --git a/apps/beeai-server/src/beeai_server/exceptions.py b/apps/beeai-server/src/beeai_server/exceptions.py index bdb952b76..ae168aecf 100644 --- a/apps/beeai-server/src/beeai_server/exceptions.py +++ b/apps/beeai-server/src/beeai_server/exceptions.py @@ -55,6 +55,20 @@ def __init__( super().__init__(f"{entity} with {attribute} {id} not found", status_code) +class ForbiddenUpdateError(PlatformError): + entity: str + id: UUID | str + attribute: str + + def __init__( + self, entity: str, id: UUID | str, status_code: int = status.HTTP_404_NOT_FOUND, attribute: str = "id" + ): + self.entity = entity + self.id = id + self.attribute = attribute + super().__init__("Insufficient permissions", status_code) + + class InvalidVectorDimensionError(PlatformError): ... diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/29762474b358_.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/29762474b358_.py new file mode 100644 index 000000000..4d8308219 --- /dev/null +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/29762474b358_.py @@ -0,0 +1,53 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +"""add a2a request tracking tables + +Revision ID: 29762474b358 +Revises: 28725d931ca5 +Create Date: 2025-10-20 09:44:21.015314 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "29762474b358" +down_revision: str | None = "28725d931ca5" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "a2a_request_contexts", + sa.Column("context_id", sa.String(length=256), nullable=False), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("provider_id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("context_id"), + ) + op.create_table( + "a2a_request_tasks", + sa.Column("task_id", sa.String(length=256), nullable=False), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("provider_id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("task_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("a2a_request_tasks") + op.drop_table("a2a_request_contexts") + # ### end Alembic commands ### diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/46ec8881ac4c_.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/46ec8881ac4c_.py index 2d5f55663..688583ec5 100644 --- a/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/46ec8881ac4c_.py +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/46ec8881ac4c_.py @@ -1,7 +1,7 @@ # Copyright 2025 © BeeAI a Series of LF Projects, LLC # SPDX-License-Identifier: Apache-2.0 -"""empty message +"""add model providers and configurations Revision ID: 46ec8881ac4c Revises: 5dec926744d0 diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/requests.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/requests.py new file mode 100644 index 000000000..b5183ea20 --- /dev/null +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/requests.py @@ -0,0 +1,166 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from datetime import timedelta +from uuid import UUID + +from kink import inject +from sqlalchemy import UUID as SQL_UUID +from sqlalchemy import Boolean, Column, DateTime, Row, String, Table, bindparam, text +from sqlalchemy.ext.asyncio import AsyncConnection + +from beeai_server.domain.models.a2a_request import A2ARequestTask +from beeai_server.domain.repositories.a2a_request import IA2ARequestRepository +from beeai_server.exceptions import EntityNotFoundError, ForbiddenUpdateError +from beeai_server.infrastructure.persistence.repositories.db_metadata import metadata +from beeai_server.utils.utils import utc_now + +a2a_request_tasks_table = Table( + "a2a_request_tasks", + metadata, + Column("task_id", String(256), primary_key=True), + Column("created_by", SQL_UUID, nullable=False), # not using reference integrity for performance + Column("provider_id", SQL_UUID, nullable=False), # not using reference integrity for performance + Column("created_at", DateTime(timezone=True), nullable=False), + Column("last_accessed_at", DateTime(timezone=True), nullable=False), +) + +a2a_request_contexts_table = Table( + "a2a_request_contexts", + metadata, + Column("context_id", String(256), primary_key=True), + Column("created_by", SQL_UUID, nullable=False), # not using reference integrity for performance + Column("provider_id", SQL_UUID, nullable=False), # not using reference integrity for performance + Column("created_at", DateTime(timezone=True), nullable=False), + Column("last_accessed_at", DateTime(timezone=True), nullable=False), +) + + +@inject +class SqlAlchemyA2ARequestRepository(IA2ARequestRepository): + def __init__(self, connection: AsyncConnection): + self._connection = connection + + def _to_task(self, row: Row) -> A2ARequestTask: + return A2ARequestTask.model_validate( + { + "task_id": row.task_id, + "created_by": row.created_by, + "provider_id": row.provider_id, + "created_at": row.created_at, + "last_accessed_at": row.last_accessed_at, + } + ) + + async def track_request_ids_ownership( + self, + user_id: UUID, + provider_id: UUID, + task_id: str | None = None, + context_id: str | None = None, + allow_task_creation: bool = False, + ) -> None: + """ + Verify ownership and record/update identifiers in a SINGLE query. 
+ + Args: + allow_task_creation: If False, task_id must already exist in DB (client->server requests). + If True, task_id can be created (server responses). + """ + + # This handles all cases: + # - New task_id/context_id: Creates ownership record (if allowed) + # - Existing owned: Updates last_accessed_at and returns true + # - Existing owned by OTHER user: ON CONFLICT WHERE clause prevents update, returns false + + now = utc_now() + + query = text(""" + WITH task_insert AS ( + INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) + SELECT :task_id, :user_id, :provider_id, :now, :now + WHERE :task_id IS NOT NULL AND :allow_task_creation = true + ON CONFLICT (task_id) DO NOTHING + RETURNING true as inserted), + task_update AS ( + UPDATE a2a_request_tasks + SET last_accessed_at = :now + WHERE task_id = :task_id AND created_by = :user_id + RETURNING true as updated), + context_insert AS ( + INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) + SELECT :context_id, :user_id, :provider_id, :now, :now + WHERE :context_id IS NOT NULL + ON CONFLICT (context_id) DO NOTHING + RETURNING true as inserted), + context_update AS ( + UPDATE a2a_request_contexts + SET last_accessed_at = :now + WHERE context_id = :context_id AND created_by = :user_id + RETURNING true as updated) + SELECT CASE + WHEN :task_id IS NULL THEN true + WHEN EXISTS (SELECT 1 FROM task_insert) THEN true + WHEN EXISTS (SELECT 1 FROM task_update) THEN true + ELSE false + END as task_authorized, + CASE + WHEN :context_id IS NULL THEN true + WHEN EXISTS (SELECT 1 FROM context_insert) THEN true + WHEN EXISTS (SELECT 1 FROM context_update) THEN true + ELSE false + END as context_authorized + """).bindparams( + bindparam("task_id", type_=String), + bindparam("context_id", type_=String), + bindparam("user_id", type_=SQL_UUID()), + bindparam("provider_id", type_=SQL_UUID()), + bindparam("allow_task_creation", type_=Boolean), + bindparam("now", type_=DateTime(timezone=True)), + ) + + result = await self._connection.execute( + query, + { + "task_id": task_id, + "context_id": context_id, + "user_id": user_id, + "provider_id": provider_id, + "allow_task_creation": allow_task_creation, + "now": now, + }, + execution_options={"compiled_cache": None}, + ) + + if not (row := result.first()): + raise RuntimeError("Unexpected query result") + if not row.task_authorized: + assert task_id + raise EntityNotFoundError(entity="a2a_request_task", id=task_id) + if not row.context_authorized: + assert context_id + raise ForbiddenUpdateError(entity="a2a_request_context", id=context_id) + + async def get_task(self, *, task_id: str, user_id: UUID) -> A2ARequestTask: + """Get a task by task_id if owned by the user.""" + query = a2a_request_tasks_table.select().where( + a2a_request_tasks_table.c.task_id == task_id, a2a_request_tasks_table.c.created_by == user_id + ) + result = await self._connection.execute(query) + if not (row := result.fetchone()): + raise EntityNotFoundError(entity="a2a_request_task", id=task_id) + return self._to_task(row) + + async def delete_tasks(self, *, older_than: timedelta) -> int: + query = a2a_request_tasks_table.delete().where( + a2a_request_tasks_table.c.last_accessed_at < utc_now() - older_than + ) + result = await self._connection.execute(query) + return result.rowcount + + async def delete_contexts(self, *, older_than: timedelta) -> int: + query = a2a_request_contexts_table.delete().where( + a2a_request_contexts_table.c.last_accessed_at < 
utc_now() - older_than + ) + result = await self._connection.execute(query) + return result.rowcount diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/unit_of_work.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/unit_of_work.py index 39c6cae95..12bf794f6 100644 --- a/apps/beeai-server/src/beeai_server/infrastructure/persistence/unit_of_work.py +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/unit_of_work.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncTransaction from beeai_server.configuration import Configuration +from beeai_server.domain.repositories.a2a_request import IA2ARequestRepository from beeai_server.domain.repositories.configurations import IConfigurationsRepository from beeai_server.domain.repositories.context import IContextRepository from beeai_server.domain.repositories.env import IEnvVariableRepository @@ -23,6 +24,7 @@ from beeai_server.infrastructure.persistence.repositories.model_provider import SqlAlchemyModelProviderRepository from beeai_server.infrastructure.persistence.repositories.provider import SqlAlchemyProviderRepository from beeai_server.infrastructure.persistence.repositories.provider_build import SqlAlchemyProviderBuildRepository +from beeai_server.infrastructure.persistence.repositories.requests import SqlAlchemyA2ARequestRepository from beeai_server.infrastructure.persistence.repositories.user import SqlAlchemyUserRepository from beeai_server.infrastructure.persistence.repositories.user_feedback import SqlAlchemyUserFeedbackRepository from beeai_server.infrastructure.persistence.repositories.vector_store import SqlAlchemyVectorStoreRepository @@ -36,6 +38,7 @@ class SQLAlchemyUnitOfWork(IUnitOfWork): Works purely with SQLAlchemy Core objects (insert(), update(), text(), …). 
""" + a2a_requests: IA2ARequestRepository providers: IProviderRepository model_providers: IModelProviderRepository contexts: IContextRepository @@ -60,6 +63,7 @@ async def __aenter__(self) -> Self: self._connection = await self._engine.connect() self._transaction = await self._connection.begin() + self.a2a_requests = SqlAlchemyA2ARequestRepository(self._connection) self.providers = SqlAlchemyProviderRepository(self._connection) self.model_providers = SqlAlchemyModelProviderRepository(self._connection) self.provider_builds = SqlAlchemyProviderBuildRepository(self._connection) diff --git a/apps/beeai-server/src/beeai_server/jobs/crons/cleanup.py b/apps/beeai-server/src/beeai_server/jobs/crons/cleanup.py index 7dd1dc993..e57228c34 100644 --- a/apps/beeai-server/src/beeai_server/jobs/crons/cleanup.py +++ b/apps/beeai-server/src/beeai_server/jobs/crons/cleanup.py @@ -6,6 +6,7 @@ from procrastinate import Blueprint, JobContext, builtin_tasks from beeai_server.jobs.queues import Queues +from beeai_server.service_layer.services.a2a import A2AProxyService from beeai_server.service_layer.services.contexts import ContextService blueprint = Blueprint() @@ -14,7 +15,7 @@ @blueprint.periodic(cron="5 * * * *") -@blueprint.task(queueing_lock="cleanup_expired_vector_stores", queue=str(Queues.CRON_CLEANUP)) +@blueprint.task(queueing_lock="cleanup_expired_context_resources", queue=str(Queues.CRON_CLEANUP)) @inject async def cleanup_expired_context_resources(timestamp: int, context: ContextService) -> None: """Delete resources of contexts that haven't been used for several days.""" @@ -22,6 +23,15 @@ async def cleanup_expired_context_resources(timestamp: int, context: ContextServ logger.info(f"Deleted: {deleted_stats}") +@blueprint.periodic(cron="10 * * * *") +@blueprint.task(queueing_lock="cleanup_expired_a2a_requests", queue=str(Queues.CRON_CLEANUP)) +@inject +async def cleanup_expired_a2a_tasks(timestamp: int, a2a_proxy: A2AProxyService) -> None: + """Delete tracked request objects that haven't been used for several days.""" + deleted_stats = await a2a_proxy.expire_requests() + logger.info(f"Deleted: {deleted_stats}") + + @blueprint.periodic(cron="*/10 * * * *") @blueprint.task(queueing_lock="remove_old_jobs", queue=str(Queues.CRON_CLEANUP), pass_context=True) async def remove_old_jobs(context: JobContext, timestamp: int): diff --git a/apps/beeai-server/src/beeai_server/service_layer/services/a2a.py b/apps/beeai-server/src/beeai_server/service_layer/services/a2a.py index f377ef480..b7a731753 100644 --- a/apps/beeai-server/src/beeai_server/service_layer/services/a2a.py +++ b/apps/beeai-server/src/beeai_server/service_layer/services/a2a.py @@ -1,29 +1,96 @@ # Copyright 2025 © BeeAI a Series of LF Projects, LLC # SPDX-License-Identifier: Apache-2.0 import functools +import inspect import logging -from collections.abc import AsyncIterable -from contextlib import AsyncExitStack +import uuid +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable +from contextlib import asynccontextmanager from datetime import timedelta -from typing import Any, NamedTuple, cast +from typing import NamedTuple, cast +from urllib.parse import urljoin, urlparse from uuid import UUID import httpx -from httpx import AsyncByteStream +from a2a.client import ClientCallContext, ClientConfig, ClientFactory +from a2a.client.base_client import BaseClient +from a2a.client.errors import A2AClientJSONRPCError +from a2a.client.transports.base import ClientTransport +from a2a.server.context import ServerCallContext 
+from a2a.server.events import Event +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, + InvalidRequestError, + ListTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, + TransportProtocol, +) +from a2a.utils.errors import ServerError from kink import inject +from pydantic import HttpUrl from structlog.contextvars import bind_contextvars, unbind_contextvars from beeai_server.configuration import Configuration from beeai_server.domain.models.provider import ( NetworkProviderLocation, + Provider, ProviderDeploymentState, ) +from beeai_server.domain.models.user import User +from beeai_server.exceptions import EntityNotFoundError, ForbiddenUpdateError from beeai_server.service_layer.deployment_manager import IProviderDeploymentManager from beeai_server.service_layer.services.users import UserService from beeai_server.service_layer.unit_of_work import IUnitOfWorkFactory logger = logging.getLogger(__name__) +_SUPPORTED_TRANSPORTS = {TransportProtocol.http_json, TransportProtocol.jsonrpc} + + +def _create_deploy_a2a_url(url: str, *, deployment_base: str) -> str: + return urljoin(deployment_base, urlparse(url).path.lstrip("/")) + + +def create_deployment_agent_card(agent_card: AgentCard, *, deployment_base: str) -> AgentCard: + proxy_interfaces = ( + [ + interface.model_copy(update={"url": _create_deploy_a2a_url(interface.url, deployment_base=deployment_base)}) + for interface in agent_card.additional_interfaces + if interface.transport in _SUPPORTED_TRANSPORTS + ] + if agent_card.additional_interfaces is not None + else None + ) + if agent_card.preferred_transport in _SUPPORTED_TRANSPORTS: + return agent_card.model_copy( + update={ + "url": _create_deploy_a2a_url(agent_card.url, deployment_base=deployment_base), + "additional_interfaces": proxy_interfaces, + } + ) + elif proxy_interfaces: + interface = proxy_interfaces[0] + return agent_card.model_copy( + update={ + "url": interface.url, + "preferred_transport": interface.transport, + "additional_interfaces": proxy_interfaces, + } + ) + else: + raise RuntimeError("Provider doesn't have any transport supported by the proxy.") + class A2AServerResponse(NamedTuple): content: bytes | None @@ -33,49 +100,170 @@ class A2AServerResponse(NamedTuple): media_type: str -class ProxyClient: - def __init__(self, client: httpx.AsyncClient): - self._client = client - - @functools.wraps(httpx.AsyncClient.stream) - async def send_request(*args, **kwargs) -> A2AServerResponse: - self = args[0] # extract self for type checking - rest_args: tuple[Any, ...] 
= args[1:] - exit_stack = AsyncExitStack() +def _handle_exception[T: Callable](fn: T) -> T: + @functools.wraps(fn) + async def _fn(*args, **kwargs): try: - client = await exit_stack.enter_async_context(self._client) - resp: httpx.Response = await exit_stack.enter_async_context(client.stream(*rest_args, **kwargs)) - - try: - content_type = resp.headers["content-type"] - is_stream = content_type.startswith("text/event-stream") - except KeyError: - content_type = None - is_stream = False - - async def stream_fn(): - try: - async for event in cast(AsyncByteStream, resp.stream): - yield event - finally: - await exit_stack.pop_all().aclose() - - common = { - "status_code": resp.status_code, - "headers": resp.headers, - "media_type": content_type, - } - if is_stream: - return A2AServerResponse(content=None, stream=stream_fn(), **common) - else: - try: - await resp.aread() - return A2AServerResponse(stream=None, content=resp.content, **common) - finally: - await exit_stack.pop_all().aclose() - except BaseException: - await exit_stack.pop_all().aclose() + return await fn(*args, **kwargs) + except EntityNotFoundError as e: + if "task" in e.entity: + raise ServerError(error=TaskNotFoundError()) from e raise + except ForbiddenUpdateError as e: + raise ServerError(error=InvalidRequestError(message=str(e))) from e + except A2AClientJSONRPCError as e: + raise ServerError(error=e.error) from e + + @functools.wraps(fn) + async def _fn_iter(*args, **kwargs): + try: + async for item in fn(*args, **kwargs): + yield item + except A2AClientJSONRPCError as e: + raise ServerError(error=e.error) from e + + return _fn_iter if inspect.isasyncgenfunction(fn) else _fn # pyright: ignore [reportReturnType] + + +class ProxyRequestHandler(RequestHandler): + def __init__(self, agent_card: AgentCard, provider_id: UUID, uow: IUnitOfWorkFactory, user: User): + self._agent_card = agent_card + self._provider_id = provider_id + self._user = user + self._uow = uow + + @asynccontextmanager + async def _client_transport(self) -> AsyncIterator[ClientTransport]: + async with httpx.AsyncClient(follow_redirects=True, timeout=timedelta(hours=1).total_seconds()) as httpx_client: + client: BaseClient = cast( + BaseClient, + ClientFactory(config=ClientConfig(httpx_client=httpx_client)).create(card=self._agent_card), + ) + yield client._transport + + async def _check_task(self, task_id: str): + async with self._uow() as uow: + await uow.a2a_requests.get_task(task_id=task_id, user_id=self._user.id) + + async def _check_and_record_request( + self, task_id: str | None = None, context_id: str | None = None, allow_task_creation: bool = False + ): + async with self._uow() as uow: + # Consider: a bit paranoid check + # if context_id: + # with suppress(ValueError, EntityNotFoundError): + # context_uuid = UUID(context_id) + # context = await uow.contexts.get(context_id=context_uuid) + # if context.created_by != self._user.id: + # # attempt to claim context owned by another user + # raise ForbiddenUpdateError(entity="a2a_request_context", id=context_id) + await uow.a2a_requests.track_request_ids_ownership( + user_id=self._user.id, + provider_id=self._provider_id, + task_id=task_id, + context_id=context_id, + allow_task_creation=allow_task_creation, + ) + await uow.commit() + + def _forward_context(self, context: ServerCallContext | None = None) -> ClientCallContext: + return ClientCallContext(state={**(context.state if context else {}), "user_id": self._user.id}) + + @_handle_exception + async def on_get_task(self, params: TaskQueryParams, context: 
ServerCallContext | None = None) -> Task | None: + await self._check_task(params.id) + async with self._client_transport() as transport: + return await transport.get_task(params, context=self._forward_context(context)) + + @_handle_exception + async def on_cancel_task(self, params: TaskIdParams, context: ServerCallContext | None = None) -> Task | None: + await self._check_task(params.id) + async with self._client_transport() as transport: + return await transport.cancel_task(params, context=self._forward_context(context)) + + @_handle_exception + async def on_message_send( + self, params: MessageSendParams, context: ServerCallContext | None = None + ) -> Task | Message: + # we set task_id and context_id if not configured + params.message.context_id = params.message.context_id or str(uuid.uuid4()) + await self._check_and_record_request(params.message.task_id, params.message.context_id) + + async with self._client_transport() as transport: + response = await transport.send_message(params, context=self._forward_context(context)) + match response: + case Task(id=task_id) | Message(task_id=task_id): + if params.message.task_id is None and task_id: + await self._check_and_record_request( + task_id, params.message.context_id, allow_task_creation=True + ) + return response + + @_handle_exception + async def on_message_send_stream( + self, params: MessageSendParams, context: ServerCallContext | None = None + ) -> AsyncGenerator[Event]: + # we set task_id and context_id if not configured + params.message.context_id = params.message.context_id or str(uuid.uuid4()) + await self._check_and_record_request(params.message.task_id, params.message.context_id) + seen_tasks = {params.message.task_id} if params.message.task_id else set() + + async with self._client_transport() as transport: + async for event in transport.send_message_streaming(params, context=self._forward_context(context)): + match event: + case ( + TaskStatusUpdateEvent(task_id=task_id, context_id=context_id) + | TaskArtifactUpdateEvent(task_id=task_id, context_id=context_id) + | Task(id=task_id, context_id=context_id) + | Message(task_id=task_id, context_id=context_id) + ): + if context_id != params.message.context_id: + raise RuntimeError(f"Unexpected context_id returned from the agent: {context_id}") + if task_id and task_id not in seen_tasks: + await self._check_and_record_request( + task_id=task_id, context_id=context_id, allow_task_creation=True + ) + seen_tasks.add(task_id) + yield event + + @_handle_exception + async def on_set_task_push_notification_config( + self, params: TaskPushNotificationConfig, context: ServerCallContext | None = None + ) -> TaskPushNotificationConfig: + await self._check_task(params.task_id) + async with self._client_transport() as transport: + return await transport.set_task_callback(params) + + @_handle_exception + async def on_get_task_push_notification_config( + self, params: TaskIdParams | GetTaskPushNotificationConfigParams, context: ServerCallContext | None = None + ) -> TaskPushNotificationConfig: + await self._check_task(params.id) + async with self._client_transport() as transport: + if isinstance(params, TaskIdParams): + params = GetTaskPushNotificationConfigParams(id=params.id, metadata=params.metadata) + return await transport.get_task_callback(params, context=self._forward_context(context)) + + @_handle_exception + async def on_resubscribe_to_task( + self, params: TaskIdParams, context: ServerCallContext | None = None + ) -> AsyncGenerator[Event]: + await self._check_task(params.id) + async with 
self._client_transport() as transport: + async for event in transport.resubscribe(params): + yield event + + @_handle_exception + async def on_list_task_push_notification_config( + self, params: ListTaskPushNotificationConfigParams, context: ServerCallContext | None = None + ) -> list[TaskPushNotificationConfig]: + raise NotImplementedError("This is not supported by the client transport yet") + + @_handle_exception + async def on_delete_task_push_notification_config( + self, params: DeleteTaskPushNotificationConfigParams, context: ServerCallContext | None = None + ) -> None: + raise NotImplementedError("This is not supported by the client transport yet") @inject @@ -93,8 +281,21 @@ def __init__( self._uow = uow self._user_service = user_service self._config = configuration + self._expire_requests_after = timedelta(days=configuration.a2a_proxy.requests_expire_after_days) + + async def get_request_handler(self, *, provider: Provider, user: User) -> RequestHandler: + url = await self.ensure_agent(provider_id=provider.id) + agent_card = create_deployment_agent_card(provider.agent_card, deployment_base=str(url)) + return ProxyRequestHandler(agent_card=agent_card, provider_id=provider.id, uow=self._uow, user=user) + + async def expire_requests(self) -> dict[str, int]: + async with self._uow() as uow: + n_tasks = await uow.a2a_requests.delete_tasks(older_than=self._expire_requests_after) + n_ctx = await uow.a2a_requests.delete_contexts(older_than=self._expire_requests_after) + await uow.commit() + return {"tasks": n_tasks, "contexts": n_ctx} - async def get_proxy_client(self, *, provider_id: UUID) -> ProxyClient: + async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl: try: bind_contextvars(provider=provider_id) @@ -104,9 +305,8 @@ async def get_proxy_client(self, *, provider_id: UUID) -> ProxyClient: await uow.commit() if not provider.managed: - if not isinstance(provider.source, NetworkProviderLocation): - raise ValueError(f"Unmanaged provider location type is not supported: {type(provider.source)}") - return ProxyClient(httpx.AsyncClient(base_url=str(provider.source.a2a_url), timeout=None)) + assert isinstance(provider.source, NetworkProviderLocation) + return provider.source.a2a_url provider_url = await self._deploy_manager.get_provider_url(provider_id=provider.id) [state] = await self._deploy_manager.state(provider_ids=[provider.id]) @@ -134,6 +334,6 @@ async def get_proxy_client(self, *, provider_id: UUID) -> ProxyClient: logger.info("Waiting for provider to start up...") await self._deploy_manager.wait_for_startup(provider_id=provider.id, timeout=self.STARTUP_TIMEOUT) logger.info("Provider is ready...") - return ProxyClient(httpx.AsyncClient(base_url=str(provider_url), timeout=None)) + return provider_url finally: unbind_contextvars("provider") diff --git a/apps/beeai-server/src/beeai_server/service_layer/unit_of_work.py b/apps/beeai-server/src/beeai_server/service_layer/unit_of_work.py index ba2e9f4cb..5c3f4ff2d 100644 --- a/apps/beeai-server/src/beeai_server/service_layer/unit_of_work.py +++ b/apps/beeai-server/src/beeai_server/service_layer/unit_of_work.py @@ -3,6 +3,7 @@ from typing import Protocol, Self +from beeai_server.domain.repositories.a2a_request import IA2ARequestRepository from beeai_server.domain.repositories.configurations import IConfigurationsRepository from beeai_server.domain.repositories.context import IContextRepository from beeai_server.domain.repositories.env import IEnvVariableRepository @@ -18,6 +19,7 @@ class IUnitOfWork(Protocol): providers: 
IProviderRepository provider_builds: IProviderBuildRepository + a2a_requests: IA2ARequestRepository contexts: IContextRepository files: IFileRepository env: IEnvVariableRepository diff --git a/apps/beeai-server/tasks.toml b/apps/beeai-server/tasks.toml index 3f808657f..6fcb8bb8e 100644 --- a/apps/beeai-server/tasks.toml +++ b/apps/beeai-server/tasks.toml @@ -167,7 +167,7 @@ run = """ #!/bin/bash VM_NAME='{{option(name="vm-name", default="beeai-local-dev")}}' {{ mise_bin }} run beeai-server:dev:disconnect --vm-name="$VM_NAME" -{{ mise_bin }} run beeai-platform:stop --vm-name=beeai-local-dev +{{ mise_bin }} run beeai-platform:stop --vm-name="$VM_NAME" """ ["beeai-server:dev:delete"] @@ -200,7 +200,9 @@ VM_NAME='{{option(name="vm-name", default="beeai-local-dev")}}' eval "$( {{ mise_bin }} run beeai-platform:shell --vm-name="$VM_NAME" )" tele="telepresence --use .*${NAMESPACE}.*" -$tele uninstall --all-agents || true +if ! ($tele status --output json | jq '.user_daemon.status' | grep -q "Not connected"); then + $tele uninstall --all-agents || true +fi $tele quit """ @@ -220,9 +222,14 @@ run = """ --vm-name=beeai-local-test \ --set externalRegistries=null \ --set providerBuilds.enabled=true \ - --set localDockerRegistry.enabled=true + --set localDockerRegistry.enabled=true \ + {{arg(name="cli-args", var=true, default="--")}} """ +["beeai-server:dev:test:stop"] +dir = "{{config_root}}/apps/beeai-server" +run = "{{ mise_bin }} run beeai-server:dev:stop --vm-name=beeai-local-test" + ["beeai-server:dev:test:delete"] dir = "{{config_root}}/apps/beeai-server" run = "{{ mise_bin }} run beeai-server:dev:delete --vm-name=beeai-local-test" diff --git a/apps/beeai-server/tests/e2e/agents/conftest.py b/apps/beeai-server/tests/e2e/agents/conftest.py index 71c2e3283..d2492e01c 100644 --- a/apps/beeai-server/tests/e2e/agents/conftest.py +++ b/apps/beeai-server/tests/e2e/agents/conftest.py @@ -2,15 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator, Callable from contextlib import asynccontextmanager +from typing import Any -import httpx import pytest -from a2a.client import Client, ClientFactory +from a2a.client import Client from a2a.types import AgentCard -from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH -from beeai_sdk.platform import PlatformClient +from beeai_sdk.platform import PlatformClient, Provider from beeai_sdk.server import Server from beeai_sdk.server.store.context_store import ContextStore from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed @@ -20,7 +19,10 @@ @asynccontextmanager async def run_server( - server: Server, port: int, context_store: ContextStore | None = None + server: Server, + port: int, + a2a_client_factory: Callable[[AgentCard | dict[str, Any]], AsyncIterator[Client]], + context_store: ContextStore | None = None, ) -> AsyncGenerator[tuple[Server, Client]]: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -33,32 +35,30 @@ async def run_server( ) try: - async for attempt in AsyncRetrying(stop=stop_after_attempt(10), wait=wait_fixed(0.1), reraise=True): + async for attempt in AsyncRetrying(stop=stop_after_attempt(20), wait=wait_fixed(0.3), reraise=True): with attempt: if not server.server or not server.server.started: raise ConnectionError("Server hasn't started yet") - base_url = f"http://localhost:{port}" - async with httpx.AsyncClient(timeout=None) as httpx_client: - from a2a.client import ClientConfig - - card_resp = await 
httpx_client.get(base_url + AGENT_CARD_WELL_KNOWN_PATH) - card_resp.raise_for_status() - card = AgentCard.model_validate(card_resp.json()) - client = ClientFactory(ClientConfig(httpx_client=httpx_client)).create(card) + providers = [p for p in await Provider.list() if f":{port}" in p.source] + assert len(providers) == 1, "Provider not registered" + async with a2a_client_factory(providers[0].agent_card) as client: yield server, client finally: server.should_exit = True @pytest.fixture -def create_server_with_agent(free_port, test_configuration: TestConfiguration): +@pytest.mark.usefixtures("setup_platform_client") +def create_server_with_agent(free_port, test_configuration: TestConfiguration, a2a_client_factory): """Factory fixture that creates a server with the given agent function.""" @asynccontextmanager async def _create_server(agent_fn, context_store: ContextStore | None = None): server = Server() server.agent()(agent_fn) - async with run_server(server, free_port, context_store=context_store) as (server, client): + async with run_server( + server, free_port, a2a_client_factory=a2a_client_factory, context_store=context_store + ) as (server, client): yield server, client return _create_server diff --git a/apps/beeai-server/tests/e2e/agents/test_agent_starts.py b/apps/beeai-server/tests/e2e/agents/test_agent_starts.py index c5d2f258b..8764d4a81 100644 --- a/apps/beeai-server/tests/e2e/agents/test_agent_starts.py +++ b/apps/beeai-server/tests/e2e/agents/test_agent_starts.py @@ -6,15 +6,28 @@ import pytest from a2a.client.helpers import create_text_message_object from a2a.types import ( + Role, + Task, TaskState, ) -from beeai_sdk.a2a.extensions import LLMFulfillment, LLMServiceExtensionClient, LLMServiceExtensionSpec +from beeai_sdk.a2a.extensions import ( + LLMFulfillment, + LLMServiceExtensionClient, + LLMServiceExtensionSpec, + PlatformApiExtensionClient, + PlatformApiExtensionSpec, +) from beeai_sdk.platform import ModelProvider, Provider -from beeai_sdk.platform.context import Context, Permissions +from beeai_sdk.platform.context import Context, ContextPermissions, Permissions pytestmark = pytest.mark.e2e +def extract_agent_text_from_stream(task: Task) -> str: + assert task.history + return "".join(item.parts[0].root.text for item in task.history if item.role == Role.agent if item.parts) + + @pytest.mark.usefixtures("clean_up", "setup_real_llm", "setup_platform_client") async def test_remote_agent(subtests, a2a_client_factory, get_final_task_from_stream, test_configuration): agent_image = test_configuration.test_agent_image @@ -22,7 +35,10 @@ async def test_remote_agent(subtests, a2a_client_factory, get_final_task_from_st _ = await Provider.create(location=agent_image) providers = await Provider.list() context = await Context.create() - context_token = await context.generate_token(grant_global_permissions=Permissions(llm={"*"})) + context_token = await context.generate_token( + grant_global_permissions=Permissions(llm={"*"}), + grant_context_permissions=ContextPermissions(context_data={"*"}), + ) assert len(providers) == 1 assert providers[0].source == agent_image assert providers[0].agent_card @@ -30,8 +46,11 @@ async def test_remote_agent(subtests, a2a_client_factory, get_final_task_from_st async with a2a_client_factory(providers[0].agent_card) as a2a_client: with subtests.test("run chat agent for the first time"): num_parallel = 3 - message = create_text_message_object(content="Repeat this exactly: 'hello world'") + message = create_text_message_object( + content="Repeat this 
exactly, use the act tool with final_answer: 'hello world'" + ) spec = LLMServiceExtensionSpec.from_agent_card(providers[0].agent_card) + platform_api_spec = PlatformApiExtensionSpec.from_agent_card(providers[0].agent_card) message.metadata = LLMServiceExtensionClient(spec).fulfillment_metadata( llm_fulfillments={ "default": LLMFulfillment( @@ -40,13 +59,16 @@ async def test_remote_agent(subtests, a2a_client_factory, get_final_task_from_st api_base="{platform_url}/api/v1/openai/", ) } + ) | PlatformApiExtensionClient(platform_api_spec).api_auth_metadata( + auth_token=context_token.token, + expires_at=context_token.expires_at, ) message.context_id = context.id task = await get_final_task_from_stream(a2a_client.send_message(message)) # Verify response assert task.status.state == TaskState.completed, f"Fail: {task.status.message.parts[0].root.text}" - assert "hello world" in task.history[-1].parts[0].root.text + assert "hello world" in extract_agent_text_from_stream(task) # Run 3 requests in parallel (test that each request waits) run_results = await asyncio.gather( @@ -55,9 +77,9 @@ async def test_remote_agent(subtests, a2a_client_factory, get_final_task_from_st for task in run_results: assert task.status.state == TaskState.completed, f"Fail: {task.status.message.parts[0].root.text}" - assert "hello world" in task.history[-1].parts[0].root.text + assert "hello world" in extract_agent_text_from_stream(task) with subtests.test("run chat agent for the second time"): task = await get_final_task_from_stream(a2a_client.send_message(message)) assert task.status.state == TaskState.completed, f"Fail: {task.status.message.parts[0].root.text}" - assert "hello world" in task.history[-1].parts[0].root.text + assert "hello world" in extract_agent_text_from_stream(task) diff --git a/apps/beeai-server/tests/e2e/routes/test_a2a_proxy.py b/apps/beeai-server/tests/e2e/routes/test_a2a_proxy.py index 17d31486b..889e67063 100644 --- a/apps/beeai-server/tests/e2e/routes/test_a2a_proxy.py +++ b/apps/beeai-server/tests/e2e/routes/test_a2a_proxy.py @@ -7,6 +7,7 @@ import contextlib import socket import time +import uuid from contextlib import closing from threading import Thread from typing import Any @@ -36,6 +37,7 @@ SendMessageSuccessResponse, Task, TaskArtifactUpdateEvent, + TaskNotFoundError, TaskPushNotificationConfig, TaskState, TaskStatus, @@ -46,6 +48,7 @@ from a2a.utils.errors import MethodNotImplementedError from fastapi import FastAPI from httpx import Client +from sqlalchemy import text from starlette.applications import Starlette from starlette.authentication import ( AuthCredentials, @@ -59,6 +62,8 @@ from starlette.responses import JSONResponse from starlette.routing import Route +from beeai_server.infrastructure.persistence.repositories.user import users_table + pytestmark = pytest.mark.e2e # === TEST SETUP === @@ -159,6 +164,7 @@ def free_port() -> int: def create_test_server(free_port: int, app: A2AStarletteApplication, test_configuration, clean_up_fn): server_instance: uvicorn.Server | None = None thread: Thread | None = None + app.agent_card.url = f"http://host.docker.internal:{free_port}" def _create_test_server(custom_app: Starlette | FastAPI | None = None) -> Client: custom_app = custom_app or app.build() @@ -192,7 +198,10 @@ def run_server(): error = resp.json() raise RuntimeError(f"Server did not start or register itself correctly: {error}") - return Client(base_url=f"{test_configuration.server_url}/api/v1/a2a/{provider_id}") + return Client( + 
base_url=f"{test_configuration.server_url}/api/v1/a2a/{provider_id}", + auth=("admin", "test-password"), + ) try: yield _create_test_server @@ -206,12 +215,28 @@ def run_server(): raise RuntimeError("Server did not exit after 5 seconds") +@pytest.fixture +@pytest.mark.usefixtures("clean_up") +async def ensure_mock_task(db_transaction): + res = await db_transaction.execute(users_table.select().where(users_table.c.email == "admin@beeai.dev")) + admin_user = res.fetchone().id + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, NOW(), NOW())" + ), + {"task_id": "task1", "created_by": admin_user, "provider_id": uuid.uuid4()}, + ) + await db_transaction.commit() + + @pytest.fixture def client(create_test_server, test_configuration): """Create a test client with the Starlette app.""" return create_test_server() +# --------------------------------------- TESTS PORTED FROM A2A TEST SUITE --------------------------------------------- # === BASIC FUNCTIONALITY TESTS === @@ -253,6 +278,7 @@ def test_authenticated_extended_agent_card_endpoint_not_supported_fastapi( assert response.status_code == 404 # FastAPI's default for no route +@pytest.mark.skip(reason="Extended agent card is not supported at the moment. # TODO") def test_authenticated_extended_agent_card_endpoint_supported_with_specific_extended_card_starlette( create_test_server, agent_card: AgentCard, @@ -275,6 +301,7 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte assert any(skill["id"] == "skill-extended" for skill in data["skills"]), "Extended skill not found in served card" +@pytest.mark.skip(reason="Extended agent card is not supported at the moment. 
# TODO") def test_authenticated_extended_agent_card_endpoint_supported_with_specific_extended_card_fastapi( create_test_server, agent_card: AgentCard, @@ -306,11 +333,14 @@ def test_agent_card_custom_url(create_test_server, app: A2AStarletteApplication, assert data["name"] == agent_card.name -def test_starlette_rpc_endpoint_custom_url(create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock): +@pytest.mark.skip(reason="Custom RPC urls are not supported at the moment.") +def test_starlette_rpc_endpoint_custom_url( + create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock, ensure_mock_task +): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + task = Task(id="task1", context_id="ctx1", state="completed", status=task_status) handler.on_get_task.return_value = task client = create_test_server(app.build(rpc_url="/api/rpc")) response = client.post( @@ -327,11 +357,12 @@ def test_starlette_rpc_endpoint_custom_url(create_test_server, app: A2AStarlette assert data["result"]["id"] == "task1" +@pytest.mark.skip(reason="Custom RPC urls are not supported at the moment.") def test_fastapi_rpc_endpoint_custom_url(create_test_server, app: A2AFastAPIApplication, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + task = Task(id="task1", context_id="ctx1", state="completed", status=task_status) handler.on_get_task.return_value = task client = create_test_server(app.build(rpc_url="/api/rpc")) response = client.post( @@ -348,6 +379,7 @@ def test_fastapi_rpc_endpoint_custom_url(create_test_server, app: A2AFastAPIAppl assert data["result"]["id"] == "task1" +@pytest.mark.skip(reason="Custom routes are not supported by the proxy.") def test_starlette_build_with_extra_routes(create_test_server, app: A2AStarletteApplication, agent_card: AgentCard): """Test building the app with additional routes.""" @@ -370,6 +402,7 @@ def custom_handler(request): assert data["name"] == agent_card.name +@pytest.mark.skip(reason="Custom routes are not supported by the proxy.") def test_fastapi_build_with_extra_routes(create_test_server, app: A2AFastAPIApplication, agent_card: AgentCard): """Test building the app with additional routes.""" @@ -395,13 +428,13 @@ def custom_handler(request): # === REQUEST METHODS TESTS === -def test_send_message(create_test_server, handler: mock.AsyncMock, agent_card): +def test_send_message(create_test_server, handler: mock.AsyncMock, agent_card, ensure_mock_task): """Test sending a message.""" # Prepare mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) mock_task = Task( id="task1", - contextId="session-xyz", + context_id="session-xyz", status=task_status, ) handler.on_message_send.return_value = mock_task @@ -439,12 +472,12 @@ def test_send_message(create_test_server, handler: mock.AsyncMock, agent_card): handler.on_message_send.assert_awaited_once() -def test_cancel_task(client: Client, handler: mock.AsyncMock): +async def test_cancel_task(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test cancelling a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) task_status.state = TaskState.canceled # 'cancelled' # - task = Task(id="task1", 
contextId="ctx1", state="cancelled", status=task_status) + task = Task(id="task1", context_id="ctx1", state="cancelled", status=task_status) handler.on_cancel_task.return_value = task # Send request @@ -468,11 +501,11 @@ def test_cancel_task(client: Client, handler: mock.AsyncMock): handler.on_cancel_task.assert_awaited_once() -def test_get_task(client: Client, handler: mock.AsyncMock): +async def test_get_task(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test getting a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + task = Task(id="task1", context_id="ctx1", state="completed", status=task_status) handler.on_get_task.return_value = task # JSONRPCResponse(root=task) # Send request @@ -495,12 +528,12 @@ def test_get_task(client: Client, handler: mock.AsyncMock): handler.on_get_task.assert_awaited_once() -def test_set_push_notification_config(client: Client, handler: mock.AsyncMock): +def test_set_push_notification_config(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test setting push notification configuration.""" # Setup mock response task_push_config = TaskPushNotificationConfig( - taskId="t2", - pushNotificationConfig=PushNotificationConfig(url="https://example.com", token="secret-token"), + task_id="task1", + push_notification_config=PushNotificationConfig(url="https://example.com", token="secret-token"), ) handler.on_set_task_push_notification_config.return_value = task_push_config @@ -512,7 +545,7 @@ def test_set_push_notification_config(client: Client, handler: mock.AsyncMock): "id": "123", "method": "tasks/pushNotificationConfig/set", "params": { - "taskId": "t2", + "taskId": "task1", "pushNotificationConfig": { "url": "https://example.com", "token": "secret-token", @@ -530,12 +563,12 @@ def test_set_push_notification_config(client: Client, handler: mock.AsyncMock): handler.on_set_task_push_notification_config.assert_awaited_once() -def test_get_push_notification_config(client: Client, handler: mock.AsyncMock): +def test_get_push_notification_config(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test getting push notification configuration.""" # Setup mock response task_push_config = TaskPushNotificationConfig( - taskId="task1", - pushNotificationConfig=PushNotificationConfig(url="https://example.com", token="secret-token"), + task_id="task1", + push_notification_config=PushNotificationConfig(url="https://example.com", token="secret-token"), ) handler.on_get_task_push_notification_config.return_value = task_push_config @@ -560,7 +593,7 @@ def test_get_push_notification_config(client: Client, handler: mock.AsyncMock): handler.on_get_task_push_notification_config.assert_awaited_once() -def test_server_auth(create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock): +def test_server_auth(create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock, ensure_mock_task): class TestAuthMiddleware(AuthenticationBackend): async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: # For the purposes of this test, all requests are authenticated! 
@@ -572,8 +605,8 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas # Set the output message to be the authenticated user name handler.on_message_send.side_effect = lambda params, context: Message( - contextId="session-xyz", - messageId="112", + context_id="session-xyz", + message_id="112", role=Role.agent, parts=[ Part(TextPart(text=context.user.user_name)), @@ -616,7 +649,9 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas # === STREAMING TESTS === -async def test_message_send_stream(create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock) -> None: +async def test_message_send_stream( + create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock, ensure_mock_task +) -> None: """Test streaming message sending.""" # Setup mock streaming response @@ -625,14 +660,14 @@ async def stream_generator(): text_part = TextPart(**TEXT_PART_DATA) data_part = DataPart(**DATA_PART_DATA) artifact = Artifact( - artifactId=f"artifact-{i}", + artifact_id=f"artifact-{i}", name="result_data", parts=[Part(root=text_part), Part(root=data_part)], ) last = [False, False, True] task_artifact_update_event_data: dict[str, Any] = { "artifact": artifact, - "taskId": "task_id", + "taskId": "task1", "contextId": "session-xyz", "append": False, "lastChunk": last[i], @@ -661,7 +696,7 @@ async def stream_generator(): "parts": [{"kind": "text", "text": "Hello"}], "messageId": "111", "kind": "message", - "taskId": "taskId", + "taskId": "task1", "contextId": "session-xyz", } }, @@ -693,7 +728,9 @@ async def stream_generator(): await asyncio.sleep(0.1) -async def test_task_resubscription(create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock) -> None: +async def test_task_resubscription( + create_test_server, app: A2AStarletteApplication, handler: mock.AsyncMock, ensure_mock_task +) -> None: """Test task resubscription streaming.""" # Setup mock streaming response @@ -702,14 +739,14 @@ async def stream_generator(): text_part = TextPart(**TEXT_PART_DATA) data_part = DataPart(**DATA_PART_DATA) artifact = Artifact( - artifactId=f"artifact-{i}", + artifact_id=f"artifact-{i}", name="result_data", parts=[Part(root=text_part), Part(root=data_part)], ) last = [False, False, True] task_artifact_update_event_data: dict[str, Any] = { "artifact": artifact, - "taskId": "task_id", + "taskId": "task1", "contextId": "session-xyz", "append": False, "lastChunk": last[i], @@ -792,7 +829,7 @@ def test_invalid_request_structure(client: Client): assert data["error"]["code"] == InvalidRequestError().code -def test_method_not_implemented(client: Client, handler: mock.AsyncMock): +def test_method_not_implemented(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test handling MethodNotImplementedError.""" handler.on_get_task.side_effect = MethodNotImplementedError() @@ -852,7 +889,7 @@ def test_validation_error(client: Client): assert data["error"]["code"] == InvalidParamsError().code -def test_unhandled_exception(client: Client, handler: mock.AsyncMock): +def test_unhandled_exception(client: Client, handler: mock.AsyncMock, ensure_mock_task): """Test handling unhandled exception.""" handler.on_get_task.side_effect = Exception("Unexpected error") @@ -886,3 +923,275 @@ def test_non_dict_json(client: Client): data = response.json() assert "error" in data assert data["error"]["code"] == InvalidRequestError().code + + +# ------------------------------------- TESTS SPECIFIC TO PLATFORM PERMISSIONS 
----------------------------------------- + + +def test_task_ownership_different_user_cannot_access_task(client: Client, handler: mock.AsyncMock, ensure_mock_task): + """Test that a task owned by admin cannot be accessed by default user.""" + # Task is already created by ensure_mock_task for admin user + + # Setup mock response + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + task = Task(id="task1", context_id="ctx1", state="completed", status=task_status) + handler.on_get_task.return_value = task + + # Try to access as default user (without auth) + client.auth = None + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "123", "method": "tasks/get", "params": {"id": "task1"}}, + ) + + # Should fail with error (forbidden or not found) + assert response.status_code == 200 + data = response.json() + assert data["error"]["code"] in [TaskNotFoundError().code] + + # Now try as admin user (who owns it) + client.auth = ("admin", "test-password") + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "123", "method": "tasks/get", "params": {"id": "task1"}}, + ) + + # Should succeed + assert response.status_code == 200 + data = response.json() + assert "result" in data + assert data["result"]["id"] == "task1" + + +async def test_task_ownership_new_task_creation_via_message_send( + client: Client, handler: mock.AsyncMock, db_transaction +): + """Test that sending a message creates a new task owned by the user.""" + # Setup mock response - server returns a new task + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + mock_task = Task( + id="new-task-123", + context_id="session-xyz", + status=task_status, + ) + handler.on_message_send.return_value = mock_task + + # Send message as admin which should create new task ownership + client.auth = ("admin", "test-password") + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "123", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "111", + "kind": "message", + "contextId": "session-xyz", + } + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == "new-task-123" + + # Verify task was recorded in database for admin user + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "new-task-123"}, + ) + row = result.fetchone() + assert row is not None + assert row.task_id == "new-task-123" + + # Verify we can access it as admin + task = Task(id="new-task-123", context_id="ctx1", state="completed", status=task_status) + handler.on_get_task.return_value = task + + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "124", "method": "tasks/get", "params": {"id": "new-task-123"}}, + ) + + assert response.status_code == 200 + assert response.json()["result"]["id"] == "new-task-123" + + # Verify default user cannot access it + client.auth = None + response = client.post( + "/", + json={"jsonrpc": "2.0", "id": "125", "method": "tasks/get", "params": {"id": "new-task-123"}}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["error"]["code"] in [TaskNotFoundError().code] + + +async def test_context_ownership_cannot_be_claimed_by_different_user( + client: Client, handler: mock.AsyncMock, db_transaction +): + """Test that a context_id owned by one user cannot be used by another.""" + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + + # Admin creates a message with a specific context + 
client.auth = ("admin", "test-password") + mock_task = Task(id="task-ctx-1", context_id="shared-context-789", status=task_status) + handler.on_message_send.return_value = mock_task + + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "123", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "111", + "kind": "message", + "contextId": "shared-context-789", + } + }, + }, + ) + + assert response.status_code == 200 + + # Verify context was recorded for admin + context_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "shared-context-789"}, + ) + context_row = context_result.fetchone() + assert context_row is not None + + # Now default user tries to use the same context - should fail + client.auth = None + mock_task2 = Task( + id="task-ctx-2", + context_id="shared-context-789", # Same context! + status=task_status, + ) + handler.on_message_send.return_value = mock_task2 + + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "124", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "112", + "kind": "message", + "contextId": "shared-context-789", + } + }, + }, + ) + + # Should fail + assert response.status_code == 200 + data = response.json() + assert data["error"]["code"] == InvalidRequestError().code + assert "insufficient permissions" in data["error"]["message"].lower() + + +async def test_task_update_last_accessed_at(client: Client, handler: mock.AsyncMock, db_transaction): + """Test that accessing a task updates last_accessed_at timestamp.""" + client.auth = ("admin", "test-password") + + mock_task = Task(id="task1", context_id="shared-context-789", status=TaskStatus(state=TaskState.submitted)) + handler.on_message_send.return_value = mock_task + message_data = { + "jsonrpc": "2.0", + "id": "123", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "111", + "kind": "message", + "contextId": "shared-context-789", + } + }, + } + + response = client.post("/", json=message_data) + # Get initial timestamp + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_tasks WHERE task_id = :task_id"), {"task_id": "task1"} + ) + initial_timestamp = result.fetchone().last_accessed_at + + # Wait a bit to ensure timestamp difference + await asyncio.sleep(0.1) + + # Access the task + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + task = Task(id="task1", context_id="ctx1", state="completed", status=task_status) + handler.on_get_task.return_value = task + + response = client.post("/", json=message_data) + assert response.status_code == 200 + + # Check that timestamp was updated + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "task1"}, + ) + new_timestamp = result.fetchone().last_accessed_at + assert new_timestamp > initial_timestamp + + +async def test_task_and_context_both_specified_single_query(client: Client, handler: mock.AsyncMock, db_transaction): + """Test that both task_id and context_id are tracked in a single query when both are specified.""" + client.auth = ("admin", "test-password") + + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + mock_task = Task(id="dual-task-123", 
context_id="dual-context-456", status=task_status) + handler.on_message_send.return_value = mock_task + + message_data = { + "jsonrpc": "2.0", + "id": "123", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "111", + "kind": "message", + "contextId": "dual-context-456", + } + }, + } + response = client.post("/", json=message_data) + assert response.status_code == 200 + message_data["params"]["message"]["taskId"] = "dual-task-123" + + response = client.post("/", json=message_data) + assert response.status_code == 200 + + # Verify both were recorded in database + task_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "dual-task-123"}, + ) + assert task_result.fetchone() is not None + + context_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "dual-context-456"}, + ) + assert context_result.fetchone() is not None diff --git a/apps/beeai-server/tests/integration/persistence/repositories/test_a2a_request.py b/apps/beeai-server/tests/integration/persistence/repositories/test_a2a_request.py new file mode 100644 index 000000000..474a343a0 --- /dev/null +++ b/apps/beeai-server/tests/integration/persistence/repositories/test_a2a_request.py @@ -0,0 +1,591 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +import uuid +from datetime import timedelta +from uuid import UUID + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection + +from beeai_server.domain.models.a2a_request import A2ARequestTask +from beeai_server.exceptions import EntityNotFoundError, ForbiddenUpdateError +from beeai_server.infrastructure.persistence.repositories.requests import SqlAlchemyA2ARequestRepository +from beeai_server.utils.utils import utc_now + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def user1_id() -> UUID: + """First test user ID.""" + return uuid.uuid4() + + +@pytest.fixture +def user2_id() -> UUID: + """Second test user ID.""" + return uuid.uuid4() + + +@pytest.fixture +def provider_id() -> UUID: + """Test provider ID.""" + return uuid.uuid4() + + +# ================================ track_request_ids_ownership tests ================================ + + +async def test_track_new_task_with_creation_allowed(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test creating a new task when allow_task_creation=True.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Track a new task with creation allowed + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id="new-task-1", + allow_task_creation=True, + ) + + # Verify task was created in database + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "new-task-1"}, + ) + row = result.fetchone() + assert row is not None + assert row.task_id == "new-task-1" + assert row.created_by == user1_id + assert row.provider_id == provider_id + assert row.created_at is not None + assert row.last_accessed_at is not None + + +async def test_track_new_task_with_creation_not_allowed( + db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID +): + """Test that new task creation fails when allow_task_creation=False.""" + repository = 
SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Attempt to track a new task with creation not allowed + with pytest.raises(EntityNotFoundError) as exc_info: + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id="nonexistent-task", + allow_task_creation=False, + ) + + assert exc_info.value.entity == "a2a_request_task" + assert exc_info.value.id == "nonexistent-task" + + # Verify task was NOT created in database + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "nonexistent-task"}, + ) + assert result.fetchone() is None + + +async def test_track_existing_task_owned_by_same_user( + db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID +): + """Test accessing an existing task owned by the same user updates last_accessed_at.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a task with an old timestamp + old_time = utc_now() - timedelta(hours=1) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :old_time, :old_time)" + ), + {"task_id": "existing-task", "created_by": user1_id, "provider_id": provider_id, "old_time": old_time}, + ) + + # Get initial timestamp + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "existing-task"}, + ) + initial_timestamp = result.fetchone().last_accessed_at + + # Track the existing task (this should update last_accessed_at to NOW()) + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id="existing-task", + allow_task_creation=False, + ) + + # Verify last_accessed_at was updated + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "existing-task"}, + ) + new_timestamp = result.fetchone().last_accessed_at + assert new_timestamp > initial_timestamp + + +async def test_track_existing_task_owned_by_different_user( + db_transaction: AsyncConnection, user1_id: UUID, user2_id: UUID, provider_id: UUID +): + """Test that accessing a task owned by a different user fails.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a task owned by user1 + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :now, :now)" + ), + {"task_id": "user1-task", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Try to access as user2 + with pytest.raises(EntityNotFoundError) as exc_info: + await repository.track_request_ids_ownership( + user_id=user2_id, + provider_id=provider_id, + task_id="user1-task", + allow_task_creation=False, + ) + + assert exc_info.value.entity == "a2a_request_task" + assert exc_info.value.id == "user1-task" + + +async def test_track_new_context(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test creating a new context.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Track a new context + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + context_id="new-context-1", + ) + + # Verify context was created in database + 
result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "new-context-1"}, + ) + row = result.fetchone() + assert row is not None + assert row.context_id == "new-context-1" + assert row.created_by == user1_id + assert row.provider_id == provider_id + assert row.created_at is not None + assert row.last_accessed_at is not None + + +async def test_track_existing_context_owned_by_same_user( + db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID +): + """Test accessing an existing context owned by the same user updates last_accessed_at.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a context with an old timestamp + old_time = utc_now() - timedelta(hours=1) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :old_time, :old_time)" + ), + {"context_id": "existing-context", "created_by": user1_id, "provider_id": provider_id, "old_time": old_time}, + ) + + # Get initial timestamp + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "existing-context"}, + ) + initial_timestamp = result.fetchone().last_accessed_at + + # Track the existing context (this should update last_accessed_at to NOW()) + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + context_id="existing-context", + ) + + # Verify last_accessed_at was updated + result = await db_transaction.execute( + text("SELECT last_accessed_at FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "existing-context"}, + ) + new_timestamp = result.fetchone().last_accessed_at + assert new_timestamp > initial_timestamp + + +async def test_track_existing_context_owned_by_different_user( + db_transaction: AsyncConnection, user1_id: UUID, user2_id: UUID, provider_id: UUID +): + """Test that accessing a context owned by a different user fails.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a context owned by user1 + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :now, :now)" + ), + {"context_id": "user1-context", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Try to access as user2 + with pytest.raises(ForbiddenUpdateError) as exc_info: + await repository.track_request_ids_ownership( + user_id=user2_id, + provider_id=provider_id, + context_id="user1-context", + ) + + assert exc_info.value.entity == "a2a_request_context" + assert exc_info.value.id == "user1-context" + + +async def test_track_both_task_and_context(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test tracking both task_id and context_id in a single call.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Track both new task and context + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id="both-task-1", + context_id="both-context-1", + allow_task_creation=True, + ) + + # Verify both were created + task_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": 
"both-task-1"}, + ) + assert task_result.fetchone() is not None + + context_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "both-context-1"}, + ) + assert context_result.fetchone() is not None + + +async def test_track_both_task_and_context_task_owned_by_different_user( + db_transaction: AsyncConnection, user1_id: UUID, user2_id: UUID, provider_id: UUID +): + """Test that when task is owned by different user, the operation fails even if context would succeed.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a task owned by user1 + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :now, :now)" + ), + {"task_id": "user1-task-2", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Try to track with user2 (task fails, context would succeed) + with pytest.raises(EntityNotFoundError) as exc_info: + await repository.track_request_ids_ownership( + user_id=user2_id, + provider_id=provider_id, + task_id="user1-task-2", + context_id="new-context-2", + allow_task_creation=False, + ) + + assert exc_info.value.entity == "a2a_request_task" + assert exc_info.value.id == "user1-task-2" + + +async def test_track_both_task_and_context_context_owned_by_different_user( + db_transaction: AsyncConnection, user1_id: UUID, user2_id: UUID, provider_id: UUID +): + """Test that when context is owned by different user, the operation fails even if task would succeed.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a context owned by user1 + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :now, :now)" + ), + {"context_id": "user1-context-2", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Try to track with user2 (context fails, task would succeed) + with pytest.raises(ForbiddenUpdateError) as exc_info: + await repository.track_request_ids_ownership( + user_id=user2_id, + provider_id=provider_id, + task_id="new-task-3", + context_id="user1-context-2", + allow_task_creation=True, + ) + + assert exc_info.value.entity == "a2a_request_context" + assert exc_info.value.id == "user1-context-2" + + +async def test_track_null_task_id_succeeds(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test that passing None as task_id succeeds (returns task_authorized=true).""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Should not raise any exception + await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id=None, + context_id="context-only", + ) + + # Verify only context was created + context_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "context-only"}, + ) + assert context_result.fetchone() is not None + + +async def test_track_null_context_id_succeeds(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test that passing None as context_id succeeds (returns context_authorized=true).""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Should not raise any exception + 
await repository.track_request_ids_ownership( + user_id=user1_id, + provider_id=provider_id, + task_id="task-only", + context_id=None, + allow_task_creation=True, + ) + + # Verify only task was created + task_result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "task-only"}, + ) + assert task_result.fetchone() is not None + + +# ================================ get_task tests ================================ + + +async def test_get_task_success(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test getting a task that exists and is owned by the user.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a task + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :now, :now)" + ), + {"task_id": "get-task-1", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Get the task + task = await repository.get_task(task_id="get-task-1", user_id=user1_id) + + # Verify task + assert isinstance(task, A2ARequestTask) + assert task.task_id == "get-task-1" + assert task.created_by == user1_id + assert task.provider_id == provider_id + assert task.created_at is not None + assert task.last_accessed_at is not None + + +async def test_get_task_not_found(db_transaction: AsyncConnection, user1_id: UUID): + """Test getting a task that doesn't exist.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Try to get non-existent task + with pytest.raises(EntityNotFoundError) as exc_info: + await repository.get_task(task_id="nonexistent-task", user_id=user1_id) + + assert exc_info.value.entity == "a2a_request_task" + assert exc_info.value.id == "nonexistent-task" + + +async def test_get_task_owned_by_different_user( + db_transaction: AsyncConnection, user1_id: UUID, user2_id: UUID, provider_id: UUID +): + """Test getting a task that exists but is owned by a different user.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a task owned by user1 + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :now, :now)" + ), + {"task_id": "user1-get-task", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Try to get as user2 + with pytest.raises(EntityNotFoundError) as exc_info: + await repository.get_task(task_id="user1-get-task", user_id=user2_id) + + assert exc_info.value.entity == "a2a_request_task" + assert exc_info.value.id == "user1-get-task" + + +# ================================ delete_tasks tests ================================ + + +async def test_delete_tasks_older_than_timedelta(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test deleting tasks older than a specified timedelta.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create an old task (5 days old) + old_time = utc_now() - timedelta(days=5) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :old_time, :old_time)" + ), + {"task_id": "old-task", "created_by": user1_id, "provider_id": provider_id, "old_time": old_time}, + ) + + 
# Create a recent task (1 hour old) + recent_time = utc_now() - timedelta(hours=1) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :recent_time, :recent_time)" + ), + {"task_id": "recent-task", "created_by": user1_id, "provider_id": provider_id, "recent_time": recent_time}, + ) + + # Delete tasks older than 3 days + await repository.delete_tasks(older_than=timedelta(days=3)) + + # Verify old task was deleted + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "old-task"}, + ) + assert result.fetchone() is None + + # Verify recent task still exists + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "recent-task"}, + ) + assert result.fetchone() is not None + + +async def test_delete_tasks_no_tasks_to_delete(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test deleting tasks when no tasks match the criteria.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a recent task + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_tasks (task_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:task_id, :created_by, :provider_id, :now, :now)" + ), + {"task_id": "recent-task-2", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Delete tasks older than 1 day (should delete nothing) + await repository.delete_tasks(older_than=timedelta(days=1)) + + # Verify task still exists + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_tasks WHERE task_id = :task_id"), + {"task_id": "recent-task-2"}, + ) + assert result.fetchone() is not None + + +# ================================ delete_contexts tests ================================ + + +async def test_delete_contexts_older_than_timedelta(db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID): + """Test deleting contexts older than a specified timedelta.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create an old context (5 days old) + old_time = utc_now() - timedelta(days=5) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :old_time, :old_time)" + ), + {"context_id": "old-context", "created_by": user1_id, "provider_id": provider_id, "old_time": old_time}, + ) + + # Create a recent context (1 hour old) + recent_time = utc_now() - timedelta(hours=1) + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :recent_time, :recent_time)" + ), + { + "context_id": "recent-context", + "created_by": user1_id, + "provider_id": provider_id, + "recent_time": recent_time, + }, + ) + + # Delete contexts older than 3 days + await repository.delete_contexts(older_than=timedelta(days=3)) + + # Verify old context was deleted + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "old-context"}, + ) + assert result.fetchone() is None + + # Verify recent context still exists + result = await db_transaction.execute( + text("SELECT * 
FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "recent-context"}, + ) + assert result.fetchone() is not None + + +async def test_delete_contexts_no_contexts_to_delete( + db_transaction: AsyncConnection, user1_id: UUID, provider_id: UUID +): + """Test deleting contexts when no contexts match the criteria.""" + repository = SqlAlchemyA2ARequestRepository(connection=db_transaction) + + # Create a recent context + now = utc_now() + await db_transaction.execute( + text( + "INSERT INTO a2a_request_contexts (context_id, created_by, provider_id, created_at, last_accessed_at) " + "VALUES (:context_id, :created_by, :provider_id, :now, :now)" + ), + {"context_id": "recent-context-2", "created_by": user1_id, "provider_id": provider_id, "now": now}, + ) + + # Delete contexts older than 1 day (should delete nothing) + await repository.delete_contexts(older_than=timedelta(days=1)) + + # Verify context still exists + result = await db_transaction.execute( + text("SELECT * FROM a2a_request_contexts WHERE context_id = :context_id"), + {"context_id": "recent-context-2"}, + ) + assert result.fetchone() is not None diff --git a/apps/beeai-server/uv.lock b/apps/beeai-server/uv.lock index 2709b556f..41fa28c3a 100644 --- a/apps/beeai-server/uv.lock +++ b/apps/beeai-server/uv.lock @@ -4,7 +4,7 @@ requires-python = "==3.12.*" [[package]] name = "a2a-sdk" -version = "0.3.7" +version = "0.3.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -13,9 +13,9 @@ dependencies = [ { name = "protobuf" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/ad/b6ecb58f44459a24f1c260e91304e1ddbb7a8e213f1f82cc4c074f66e9bb/a2a_sdk-0.3.7.tar.gz", hash = "sha256:795aa2bd2cfb3c9e8654a1352bf5f75d6cf1205b262b1bf8f4003b5308267ea2", size = 223426, upload-time = "2025-09-23T16:27:29.585Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/0b/80671e784f61b55ac4c340d125d121ba91eba58ad7ba0f03b53b3831cd32/a2a_sdk-0.3.9.tar.gz", hash = "sha256:1dff7b5b1cab0b221519d0faed50176e200a1a87a8de8b64308d876505cc7c77", size = 224528, upload-time = "2025-10-15T17:35:28.299Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/27/9cf8c6de4ae71e9c98ec96b3304449d5d0cd36ec3b95e66b6e7f58a9e571/a2a_sdk-0.3.7-py3-none-any.whl", hash = "sha256:0813b8fd7add427b2b56895cf28cae705303cf6d671b305c0aac69987816e03e", size = 137957, upload-time = "2025-09-23T16:27:27.546Z" }, + { url = "https://files.pythonhosted.org/packages/34/ee/53b2da6d2768b136f996b8c6ab00ebcc44852f9a33816a64deaca6b279fe/a2a_sdk-0.3.9-py3-none-any.whl", hash = "sha256:7ed03a915bae98def46ea0313786da0a7a488346c3dc8af88407bb0b2a763926", size = 139027, upload-time = "2025-10-15T17:35:26.628Z" }, ] [[package]] @@ -281,7 +281,7 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", specifier = "==0.3.7" }, + { name = "a2a-sdk", specifier = "==0.3.9" }, { name = "anyio", specifier = ">=4.9.0" }, { name = "asyncclick", specifier = ">=8.1.8" }, { name = "fastapi", specifier = ">=0.116.1" }, @@ -344,6 +344,8 @@ dependencies = [ { name = "python-multipart" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlparse" }, + { name = "sse-starlette" }, + { name = "starlette" }, { name = "structlog" }, { name = "tenacity" }, { name = "uvicorn" }, @@ -364,7 +366,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", specifier = "~=0.3.5" }, + { name = "a2a-sdk", specifier = "~=0.3.9" }, { name = "aioboto3", specifier = ">=14.3.0" 
}, { name = "aiodocker", specifier = ">=0.24.0" }, { name = "aiohttp", specifier = ">=3.11.11" }, @@ -394,6 +396,8 @@ requires-dist = [ { name = "python-multipart", specifier = ">=0.0.20" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlparse", specifier = ">=0.5.3" }, + { name = "sse-starlette", specifier = ">=3.0.2" }, + { name = "starlette", specifier = ">=0.48.0" }, { name = "structlog", specifier = ">=25.1.0" }, { name = "tenacity", specifier = ">=9.0.0" }, { name = "uvicorn", specifier = ">=0.34.0" }, @@ -615,16 +619,16 @@ wheels = [ [[package]] name = "fastapi" -version = "0.116.1" +version = "0.119.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/f4/152127681182e6413e7a89684c434e19e7414ed7ac0c632999c3c6980640/fastapi-0.119.1.tar.gz", hash = "sha256:a5e3426edce3fe221af4e1992c6d79011b247e3b03cc57999d697fe76cbf8ae0", size = 338616, upload-time = "2025-10-20T11:30:27.734Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, + { url = "https://files.pythonhosted.org/packages/b1/26/e6d959b4ac959fdb3e9c4154656fc160794db6af8e64673d52759456bf07/fastapi-0.119.1-py3-none-any.whl", hash = "sha256:0b8c2a2cce853216e150e9bd4faaed88227f8eb37de21cb200771f491586a27f", size = 108123, upload-time = "2025-10-20T11:30:26.185Z" }, ] [package.optional-dependencies] @@ -2061,15 +2065,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.47.3" +version = "0.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/15/b9/cc3017f9a9c9b6e27c5106cc10cc7904653c3eec0729793aec10479dd669/starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9", size = 2584144, upload-time = "2025-08-24T13:36:42.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" }, + { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, ] [[package]] diff --git a/helm/templates/deployment.yaml b/helm/templates/deployment.yaml index a4bd32e54..798091709 100644 --- 
a/helm/templates/deployment.yaml +++ b/helm/templates/deployment.yaml @@ -181,6 +181,8 @@ spec: {{/* see https://fastapi.tiangolo.com/deployment/https/?h=forwarded+allow+ips#proxy-forwarded-headers */}} - name: TRUST_PROXY_HEADERS value: {{ .Values.trustProxyHeaders | quote }} + - name: A2A_PROXY__REQUESTS_EXPIRE_AFTER_DAYS + value: {{ .Values.a2aProxyRequestsExpireAfterDays | quote }} - name: TEXT_EXTRACTION__ENABLED value: {{ .Values.docling.enabled | quote }} - name: AGENT_REGISTRY__SYNC_PERIOD_SEC diff --git a/helm/values.yaml b/helm/values.yaml index 6f3af4c84..a6d32336d 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -40,6 +40,13 @@ externalRegistries: { } # ------- SECURITY --------- +# How long (in days) task_id and context_id ownership is tracked +# Performance considerations: all context IDs are recorded for this period +# Security considerations: +# Agents should clear all local state stored under context_id within this time period +# No security implication for task_id (task IDs can be generated only server-side in the agent) +a2aProxyRequestsExpireAfterDays: 14 + encryptionKey: "" auth: enabled: false # Warning, disable only for local deployments diff --git a/tasks.toml b/tasks.toml index fa4e423be..13998db6a 100644 --- a/tasks.toml +++ b/tasks.toml @@ -132,11 +132,11 @@ run = """ EXCEPT='{{option(name="except", default="")}}' {% raw %} -TO_DELETE="$(LIMA_HOME=~/.beeai/lima limactl list -f '{{.Name}}' 2>/dev/null | sed '/^[^a-z]*$/d' | sed "/^$EXCEPT$/d")" +TO_STOP="$(LIMA_HOME=~/.beeai/lima limactl list -f '{{.Name}};{{.Status}}' | grep -v "Stopped" | cut -d';' -f1 2>/dev/null | sed '/^[^a-z]*$/d' | sed "/^$EXCEPT$/d")" {% endraw %} {% raw %} -echo "$TO_DELETE" | xargs -rn 1 -I"{}" mise run beeai-cli:run -- platform stop --vm-name="{}" +echo "$TO_STOP" | xargs -rn 1 -I"{}" mise run beeai-cli:run -- platform stop --vm-name="{}" {% endraw %} """
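
The new a2aProxyRequestsExpireAfterDays value is surfaced to the server as the A2A_PROXY__REQUESTS_EXPIRE_AFTER_DAYS environment variable. A minimal sketch of a periodic cleanup that honours this retention window, written against the repository API exercised in test_a2a_request.py; the connection factory, settings plumbing, and scheduling interval are assumptions and not part of this change:

import asyncio
from datetime import timedelta

from beeai_server.infrastructure.persistence.repositories.requests import SqlAlchemyA2ARequestRepository


async def expire_a2a_request_ownership(connection_factory, expire_after_days: int) -> None:
    # Hypothetical cleanup loop; connection_factory is an assumed async context manager
    # yielding a SQLAlchemy AsyncConnection, like the db_transaction fixture used in the tests.
    while True:
        async with connection_factory() as connection:
            repository = SqlAlchemyA2ARequestRepository(connection=connection)
            # Both methods take older_than: timedelta, per the delete_tasks / delete_contexts tests.
            await repository.delete_tasks(older_than=timedelta(days=expire_after_days))
            await repository.delete_contexts(older_than=timedelta(days=expire_after_days))
            await connection.commit()
        await asyncio.sleep(timedelta(hours=1).total_seconds())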
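
For reference, the ownership rules that the proxy permission tests assert can also be exercised directly against the repository. This is a sketch under the assumption that message/send maps to allow_task_creation=True and tasks/get to allow_task_creation=False; that mapping is inferred from the behaviour the tests above check, not from a documented contract:

import uuid

from beeai_server.exceptions import EntityNotFoundError, ForbiddenUpdateError
from beeai_server.infrastructure.persistence.repositories.requests import SqlAlchemyA2ARequestRepository


async def ownership_example(connection) -> None:
    repository = SqlAlchemyA2ARequestRepository(connection=connection)
    owner, other_user, provider = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()

    # message/send path: the caller may create ownership records for new ids.
    await repository.track_request_ids_ownership(
        user_id=owner, provider_id=provider, task_id="task1", context_id="ctx1", allow_task_creation=True
    )

    # tasks/get path: unknown or foreign task ids surface as EntityNotFoundError,
    # so another user can neither read nor claim "task1".
    try:
        await repository.track_request_ids_ownership(
            user_id=other_user, provider_id=provider, task_id="task1", allow_task_creation=False
        )
    except EntityNotFoundError:
        pass  # expected, see test_track_existing_task_owned_by_different_user

    # Reusing a context owned by someone else is rejected with ForbiddenUpdateError.
    try:
        await repository.track_request_ids_ownership(
            user_id=other_user, provider_id=provider, context_id="ctx1"
        )
    except ForbiddenUpdateError:
        pass  # expected, see test_track_existing_context_owned_by_different_user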