Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions apps/agentstack-cli/src/agentstack_cli/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import json
import re
import urllib
import urllib.parse
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import timedelta
from textwrap import indent
from typing import Any

import httpx
import openai
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
from a2a.types import AgentCard
from httpx import HTTPStatusError
from httpx._types import RequestFiles
Expand Down Expand Up @@ -103,16 +104,29 @@ async def api_stream(

@asynccontextmanager
async def a2a_client(agent_card: AgentCard, use_auth: bool = True) -> AsyncIterator[Client]:
async with httpx.AsyncClient(
headers=(
{"Authorization": f"Bearer {token}"}
if use_auth and (token := config.auth_manager.load_auth_token())
else {}
),
follow_redirects=True,
timeout=timedelta(hours=1).total_seconds(),
) as httpx_client:
yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(card=agent_card)
try:
async with httpx.AsyncClient(
headers=(
{"Authorization": f"Bearer {token}"}
if use_auth and (token := config.auth_manager.load_auth_token())
else {}
),
follow_redirects=True,
timeout=timedelta(hours=1).total_seconds(),
) as httpx_client:
yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(
card=agent_card
)
except A2AClientHTTPError as ex:
card_data = json.dumps(
agent_card.model_dump(include={"url", "additional_interfaces", "preferred_transport"}), indent=2
)
raise RuntimeError(
f"The agent is not reachable, please check that the agent card is configured properly.\n"
f"Agent connection info:\n{indent(card_data, prefix=' ')}\n"
"Full Error:\n"
f"{indent(str(ex), prefix=' ')}"
) from ex


@asynccontextmanager
Expand Down
3 changes: 3 additions & 0 deletions apps/agentstack-server/src/agentstack_server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(
super().__init__("Insufficient permissions", status_code)


class InvalidProviderCallError(PlatformError): ...


class InvalidVectorDimensionError(PlatformError): ...


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import inspect
import logging
import uuid
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable
from contextlib import asynccontextmanager
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
from contextlib import asynccontextmanager, contextmanager
from datetime import timedelta
from typing import NamedTuple, cast
from urllib.parse import urljoin, urlparse
Expand All @@ -24,6 +24,7 @@
AgentCard,
DeleteTaskPushNotificationConfigParams,
GetTaskPushNotificationConfigParams,
InternalError,
InvalidRequestError,
ListTaskPushNotificationConfigParams,
Message,
Expand All @@ -47,9 +48,10 @@
NetworkProviderLocation,
Provider,
ProviderDeploymentState,
UnmanagedState,
)
from agentstack_server.domain.models.user import User
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError, InvalidProviderCallError
from agentstack_server.service_layer.deployment_manager import (
IProviderDeploymentManager,
)
Expand Down Expand Up @@ -104,10 +106,10 @@ class A2AServerResponse(NamedTuple):


def _handle_exception[T: Callable](fn: T) -> T:
@functools.wraps(fn)
async def _fn(*args, **kwargs):
@contextmanager
def _handle_exception_impl() -> Iterator[None]:
try:
return await fn(*args, **kwargs)
yield
except EntityNotFoundError as e:
if "task" in e.entity:
raise ServerError(error=TaskNotFoundError()) from e
Expand All @@ -116,33 +118,50 @@ async def _fn(*args, **kwargs):
raise ServerError(error=InvalidRequestError(message=str(e))) from e
except A2AClientJSONRPCError as e:
raise ServerError(error=e.error) from e
except InvalidProviderCallError as e:
raise ServerError(error=InvalidRequestError(message=f"Invalid request to agent: {e!r}")) from e
except Exception as e:
raise ServerError(error=InternalError(message=f"Internal error: {e!r}")) from e

@functools.wraps(fn)
async def _fn(*args, **kwargs):
with _handle_exception_impl():
return await fn(*args, **kwargs)

@functools.wraps(fn)
async def _fn_iter(*args, **kwargs):
try:
with _handle_exception_impl():
async for item in fn(*args, **kwargs):
yield item
except A2AClientJSONRPCError as e:
raise ServerError(error=e.error) from e

return _fn_iter if inspect.isasyncgenfunction(fn) else _fn # pyright: ignore [reportReturnType]


class ProxyRequestHandler(RequestHandler):
def __init__(
self,
agent_card: AgentCard,
*,
provider_id: UUID,
uow: IUnitOfWorkFactory,
user: User,
# Calling the factory have side-effects, such as rotating the agent
agent_card_factory: Callable[[], Awaitable[AgentCard]] | None = None,
agent_card: AgentCard | None = None,
):
if agent_card_factory is None and agent_card is None:
raise ValueError("One of agent_card_factory or agent_card must be provided")
self._agent_card_factory = agent_card_factory
self._agent_card = agent_card
self._provider_id = provider_id
self._user = user
self._uow = uow

@asynccontextmanager
async def _client_transport(self) -> AsyncIterator[ClientTransport]:
if self._agent_card is None:
assert self._agent_card_factory is not None
self._agent_card = await self._agent_card_factory()

async with httpx.AsyncClient(follow_redirects=True, timeout=timedelta(hours=1).total_seconds()) as httpx_client:
client: BaseClient = cast(
BaseClient,
Expand Down Expand Up @@ -306,10 +325,13 @@ def __init__(
self._expire_requests_after = timedelta(days=configuration.a2a_proxy.requests_expire_after_days)

async def get_request_handler(self, *, provider: Provider, user: User) -> RequestHandler:
url = await self.ensure_agent(provider_id=provider.id)
agent_card = create_deployment_agent_card(provider.agent_card, deployment_base=str(url))
async def agent_card_factory() -> AgentCard:
# Delay ensure_agent to the handler so that errors are wrapped properly
url = await self.ensure_agent(provider_id=provider.id)
return create_deployment_agent_card(provider.agent_card, deployment_base=str(url))

return ProxyRequestHandler(
agent_card=agent_card,
agent_card_factory=agent_card_factory,
provider_id=provider.id,
uow=self._uow,
user=user,
Expand All @@ -332,6 +354,11 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
await uow.commit()

if not provider.managed:
if provider.unmanaged_state is UnmanagedState.OFFLINE:
raise InvalidProviderCallError(
f"Cannot send message to provider {provider_id}: provider is offline"
)

assert isinstance(provider.source, NetworkProviderLocation)
return provider.source.a2a_url

Expand All @@ -340,7 +367,9 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
should_wait = False
match state:
case ProviderDeploymentState.ERROR:
raise RuntimeError("Provider is in an error state")
raise InvalidProviderCallError(
f"Cannot send message to provider {provider_id}: provider is in an error state"
)
case (
ProviderDeploymentState.MISSING
| ProviderDeploymentState.RUNNING
Expand Down
2 changes: 1 addition & 1 deletion apps/agentstack-server/template.env
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ OCI_REGISTRY__AGENTSTACK-REGISTRY-SVC.DEFAULT:5001__INSECURE=true
# AUTH__BASIC__ENABLED=true
# AUTH__BASIC__ADMIN_PASSWORD=test-password
# AUTH__BASIC__ADMIN_PASSWORD=test-password
# AUTH__BASIC__JWT_SECRET_KEY=test-key
# AUTH__JWT_SECRET_KEY=test-key

# AUTH__OIDC__ENABLED=true
# AUTH__OIDC__ADMIN_EMAILS='["[email protected]"]'
Expand Down
61 changes: 61 additions & 0 deletions apps/agentstack-server/tests/e2e/routes/test_a2a_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,34 @@ def test_task_ownership_different_user_cannot_access_task(client: Client, handle
assert data["result"]["id"] == "task1"


async def test_unknown_task_raises_error(client: Client, handler: mock.AsyncMock, db_transaction):
"""Test that sending a message creates a new task owned by the user."""
# Send message with non-existing task
client.auth = ("admin", "test-password")
response = client.post(
"/",
json={
"jsonrpc": "2.0",
"id": "123",
"method": "message/send",
"params": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": "Hello"}],
"taskId": "unknown-task",
"messageId": "111",
"kind": "message",
"contextId": "session-xyz",
}
},
},
)

assert response.status_code == 200
data = response.json()
assert data["error"]["code"] in [TaskNotFoundError().code]


async def test_task_ownership_new_task_creation_via_message_send(
client: Client, handler: mock.AsyncMock, db_transaction
):
Expand Down Expand Up @@ -1195,3 +1223,36 @@ async def test_task_and_context_both_specified_single_query(client: Client, hand
{"context_id": "dual-context-456"},
)
assert context_result.fetchone() is not None


async def test_invalid_request_raises_a2a_error(client: Client, handler: mock.AsyncMock, db_transaction):
"""Test that an invalid request to an offline provider returns an A2A error."""

# set provider as offline
provider_id = str(client.base_url).rstrip("/").split("/")[-1]
await db_transaction.execute(
text("UPDATE providers SET unmanaged_state = 'offline' WHERE id = :provider_id"),
{"provider_id": provider_id},
)
await db_transaction.commit()

message_data = {
"jsonrpc": "2.0",
"id": "123",
"method": "message/send",
"params": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": "Hello"}],
"messageId": "111",
"kind": "message",
}
},
}
response = client.post("/", json=message_data)
assert response.status_code == 200
data = response.json()
assert data["id"] == "123"
assert "error" in data
assert data["error"]["code"] == InvalidRequestError().code
assert "provider is offline" in data["error"]["message"]
Loading