Skip to content

Commit 77f9b11

Browse files
jezekra1 authored and edengilbert committed
fix(server): send proxy errors as a2a errors (i-am-bee#1594)
Signed-off-by: Radek Ježek <[email protected]>
Signed-off-by: Eden Gilbert <[email protected]>
1 parent 3131abb commit 77f9b11

File tree

5 files changed

+134
-27
lines changed

5 files changed

+134
-27
lines changed

apps/agentstack-cli/src/agentstack_cli/api.py

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,18 @@
11
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
22
# SPDX-License-Identifier: Apache-2.0
3-
3+
import json
44
import re
55
import urllib
66
import urllib.parse
77
from collections.abc import AsyncIterator
88
from contextlib import asynccontextmanager
99
from datetime import timedelta
10+
from textwrap import indent
1011
from typing import Any
1112

1213
import httpx
1314
import openai
14-
from a2a.client import Client, ClientConfig, ClientFactory
15+
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
1516
from a2a.types import AgentCard
1617
from httpx import HTTPStatusError
1718
from httpx._types import RequestFiles
@@ -103,16 +104,29 @@ async def api_stream(
103104

104105
@asynccontextmanager
105106
async def a2a_client(agent_card: AgentCard, use_auth: bool = True) -> AsyncIterator[Client]:
106-
async with httpx.AsyncClient(
107-
headers=(
108-
{"Authorization": f"Bearer {token}"}
109-
if use_auth and (token := config.auth_manager.load_auth_token())
110-
else {}
111-
),
112-
follow_redirects=True,
113-
timeout=timedelta(hours=1).total_seconds(),
114-
) as httpx_client:
115-
yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(card=agent_card)
107+
try:
108+
async with httpx.AsyncClient(
109+
headers=(
110+
{"Authorization": f"Bearer {token}"}
111+
if use_auth and (token := config.auth_manager.load_auth_token())
112+
else {}
113+
),
114+
follow_redirects=True,
115+
timeout=timedelta(hours=1).total_seconds(),
116+
) as httpx_client:
117+
yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(
118+
card=agent_card
119+
)
120+
except A2AClientHTTPError as ex:
121+
card_data = json.dumps(
122+
agent_card.model_dump(include={"url", "additional_interfaces", "preferred_transport"}), indent=2
123+
)
124+
raise RuntimeError(
125+
f"The agent is not reachable, please check that the agent card is configured properly.\n"
126+
f"Agent connection info:\n{indent(card_data, prefix=' ')}\n"
127+
"Full Error:\n"
128+
f"{indent(str(ex), prefix=' ')}"
129+
) from ex
116130

117131

118132
@asynccontextmanager

apps/agentstack-server/src/agentstack_server/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,9 @@ def __init__(
6969
super().__init__("Insufficient permissions", status_code)
7070

7171

72+
class InvalidProviderCallError(PlatformError): ...
73+
74+
7275
class InvalidVectorDimensionError(PlatformError): ...
7376

7477

apps/agentstack-server/src/agentstack_server/service_layer/services/a2a.py

Lines changed: 43 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
import inspect
66
import logging
77
import uuid
8-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable
9-
from contextlib import asynccontextmanager
8+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator
9+
from contextlib import asynccontextmanager, contextmanager
1010
from datetime import timedelta
1111
from typing import NamedTuple, cast
1212
from urllib.parse import urljoin, urlparse
@@ -24,6 +24,7 @@
2424
AgentCard,
2525
DeleteTaskPushNotificationConfigParams,
2626
GetTaskPushNotificationConfigParams,
27+
InternalError,
2728
InvalidRequestError,
2829
ListTaskPushNotificationConfigParams,
2930
Message,
@@ -47,9 +48,10 @@
4748
NetworkProviderLocation,
4849
Provider,
4950
ProviderDeploymentState,
51+
UnmanagedState,
5052
)
5153
from agentstack_server.domain.models.user import User
52-
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError
54+
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError, InvalidProviderCallError
5355
from agentstack_server.service_layer.deployment_manager import (
5456
IProviderDeploymentManager,
5557
)
@@ -104,10 +106,10 @@ class A2AServerResponse(NamedTuple):
104106

105107

106108
def _handle_exception[T: Callable](fn: T) -> T:
107-
@functools.wraps(fn)
108-
async def _fn(*args, **kwargs):
109+
@contextmanager
110+
def _handle_exception_impl() -> Iterator[None]:
109111
try:
110-
return await fn(*args, **kwargs)
112+
yield
111113
except EntityNotFoundError as e:
112114
if "task" in e.entity:
113115
raise ServerError(error=TaskNotFoundError()) from e
@@ -116,33 +118,50 @@ async def _fn(*args, **kwargs):
116118
raise ServerError(error=InvalidRequestError(message=str(e))) from e
117119
except A2AClientJSONRPCError as e:
118120
raise ServerError(error=e.error) from e
121+
except InvalidProviderCallError as e:
122+
raise ServerError(error=InvalidRequestError(message=f"Invalid request to agent: {e!r}")) from e
123+
except Exception as e:
124+
raise ServerError(error=InternalError(message=f"Internal error: {e!r}")) from e
125+
126+
@functools.wraps(fn)
127+
async def _fn(*args, **kwargs):
128+
with _handle_exception_impl():
129+
return await fn(*args, **kwargs)
119130

120131
@functools.wraps(fn)
121132
async def _fn_iter(*args, **kwargs):
122-
try:
133+
with _handle_exception_impl():
123134
async for item in fn(*args, **kwargs):
124135
yield item
125-
except A2AClientJSONRPCError as e:
126-
raise ServerError(error=e.error) from e
127136

128137
return _fn_iter if inspect.isasyncgenfunction(fn) else _fn # pyright: ignore [reportReturnType]
129138

130139

131140
class ProxyRequestHandler(RequestHandler):
132141
def __init__(
133142
self,
134-
agent_card: AgentCard,
143+
*,
135144
provider_id: UUID,
136145
uow: IUnitOfWorkFactory,
137146
user: User,
147+
# Calling the factory has side effects, such as rotating the agent
148+
agent_card_factory: Callable[[], Awaitable[AgentCard]] | None = None,
149+
agent_card: AgentCard | None = None,
138150
):
151+
if agent_card_factory is None and agent_card is None:
152+
raise ValueError("One of agent_card_factory or agent_card must be provided")
153+
self._agent_card_factory = agent_card_factory
139154
self._agent_card = agent_card
140155
self._provider_id = provider_id
141156
self._user = user
142157
self._uow = uow
143158

144159
@asynccontextmanager
145160
async def _client_transport(self) -> AsyncIterator[ClientTransport]:
161+
if self._agent_card is None:
162+
assert self._agent_card_factory is not None
163+
self._agent_card = await self._agent_card_factory()
164+
146165
async with httpx.AsyncClient(follow_redirects=True, timeout=timedelta(hours=1).total_seconds()) as httpx_client:
147166
client: BaseClient = cast(
148167
BaseClient,
@@ -306,10 +325,13 @@ def __init__(
306325
self._expire_requests_after = timedelta(days=configuration.a2a_proxy.requests_expire_after_days)
307326

308327
async def get_request_handler(self, *, provider: Provider, user: User) -> RequestHandler:
309-
url = await self.ensure_agent(provider_id=provider.id)
310-
agent_card = create_deployment_agent_card(provider.agent_card, deployment_base=str(url))
328+
async def agent_card_factory() -> AgentCard:
329+
# Delay ensure_agent to the handler so that errors are wrapped properly
330+
url = await self.ensure_agent(provider_id=provider.id)
331+
return create_deployment_agent_card(provider.agent_card, deployment_base=str(url))
332+
311333
return ProxyRequestHandler(
312-
agent_card=agent_card,
334+
agent_card_factory=agent_card_factory,
313335
provider_id=provider.id,
314336
uow=self._uow,
315337
user=user,
@@ -332,6 +354,11 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
332354
await uow.commit()
333355

334356
if not provider.managed:
357+
if provider.unmanaged_state is UnmanagedState.OFFLINE:
358+
raise InvalidProviderCallError(
359+
f"Cannot send message to provider {provider_id}: provider is offline"
360+
)
361+
335362
assert isinstance(provider.source, NetworkProviderLocation)
336363
return provider.source.a2a_url
337364

@@ -340,7 +367,9 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
340367
should_wait = False
341368
match state:
342369
case ProviderDeploymentState.ERROR:
343-
raise RuntimeError("Provider is in an error state")
370+
raise InvalidProviderCallError(
371+
f"Cannot send message to provider {provider_id}: provider is in an error state"
372+
)
344373
case (
345374
ProviderDeploymentState.MISSING
346375
| ProviderDeploymentState.RUNNING

apps/agentstack-server/template.env

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ OCI_REGISTRY__AGENTSTACK-REGISTRY-SVC.DEFAULT:5001__INSECURE=true
3636
# AUTH__BASIC__ENABLED=true
3737
# AUTH__BASIC__ADMIN_PASSWORD=test-password
3838
# AUTH__BASIC__ADMIN_PASSWORD=test-password
39-
# AUTH__BASIC__JWT_SECRET_KEY=test-key
39+
# AUTH__JWT_SECRET_KEY=test-key
4040

4141
# AUTH__OIDC__ENABLED=true
4242
# AUTH__OIDC__ADMIN_EMAILS='["[email protected]"]'

apps/agentstack-server/tests/e2e/routes/test_a2a_proxy.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -963,6 +963,34 @@ def test_task_ownership_different_user_cannot_access_task(client: Client, handle
963963
assert data["result"]["id"] == "task1"
964964

965965

966+
async def test_unknown_task_raises_error(client: Client, handler: mock.AsyncMock, db_transaction):
967+
"""Test that sending a message that references an unknown task returns a TaskNotFoundError."""
968+
# Send message with non-existing task
969+
client.auth = ("admin", "test-password")
970+
response = client.post(
971+
"/",
972+
json={
973+
"jsonrpc": "2.0",
974+
"id": "123",
975+
"method": "message/send",
976+
"params": {
977+
"message": {
978+
"role": "agent",
979+
"parts": [{"kind": "text", "text": "Hello"}],
980+
"taskId": "unknown-task",
981+
"messageId": "111",
982+
"kind": "message",
983+
"contextId": "session-xyz",
984+
}
985+
},
986+
},
987+
)
988+
989+
assert response.status_code == 200
990+
data = response.json()
991+
assert data["error"]["code"] in [TaskNotFoundError().code]
992+
993+
966994
async def test_task_ownership_new_task_creation_via_message_send(
967995
client: Client, handler: mock.AsyncMock, db_transaction
968996
):
@@ -1195,3 +1223,36 @@ async def test_task_and_context_both_specified_single_query(client: Client, hand
11951223
{"context_id": "dual-context-456"},
11961224
)
11971225
assert context_result.fetchone() is not None
1226+
1227+
1228+
async def test_invalid_request_raises_a2a_error(client: Client, handler: mock.AsyncMock, db_transaction):
1229+
"""Test that an invalid request to an offline provider returns an A2A error."""
1230+
1231+
# set provider as offline
1232+
provider_id = str(client.base_url).rstrip("/").split("/")[-1]
1233+
await db_transaction.execute(
1234+
text("UPDATE providers SET unmanaged_state = 'offline' WHERE id = :provider_id"),
1235+
{"provider_id": provider_id},
1236+
)
1237+
await db_transaction.commit()
1238+
1239+
message_data = {
1240+
"jsonrpc": "2.0",
1241+
"id": "123",
1242+
"method": "message/send",
1243+
"params": {
1244+
"message": {
1245+
"role": "agent",
1246+
"parts": [{"kind": "text", "text": "Hello"}],
1247+
"messageId": "111",
1248+
"kind": "message",
1249+
}
1250+
},
1251+
}
1252+
response = client.post("/", json=message_data)
1253+
assert response.status_code == 200
1254+
data = response.json()
1255+
assert data["id"] == "123"
1256+
assert "error" in data
1257+
assert data["error"]["code"] == InvalidRequestError().code
1258+
assert "provider is offline" in data["error"]["message"]

0 commit comments

Comments
 (0)