Skip to content

Commit ee32329

Browse files
committed
fix(server): send proxy errors as a2a errors
Signed-off-by: Radek Ježek <radek.jezek@ibm.com>
1 parent 2884913 commit ee32329

File tree

4 files changed

+70
-8
lines changed

4 files changed

+70
-8
lines changed

apps/agentstack-server/src/agentstack_server/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def __init__(
6969
super().__init__("Insufficient permissions", status_code)
7070

7171

72+
class InvalidProviderCallError(PlatformError): ...
73+
74+
7275
class InvalidVectorDimensionError(PlatformError): ...
7376

7477

apps/agentstack-server/src/agentstack_server/service_layer/services/a2a.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import logging
77
import uuid
8-
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable
8+
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, Callable
99
from contextlib import asynccontextmanager
1010
from datetime import timedelta
1111
from typing import NamedTuple, cast
@@ -24,6 +24,7 @@
2424
AgentCard,
2525
DeleteTaskPushNotificationConfigParams,
2626
GetTaskPushNotificationConfigParams,
27+
InternalError,
2728
InvalidRequestError,
2829
ListTaskPushNotificationConfigParams,
2930
Message,
@@ -47,9 +48,10 @@
4748
NetworkProviderLocation,
4849
Provider,
4950
ProviderDeploymentState,
51+
UnmanagedState,
5052
)
5153
from agentstack_server.domain.models.user import User
52-
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError
54+
from agentstack_server.exceptions import EntityNotFoundError, ForbiddenUpdateError, InvalidProviderCallError
5355
from agentstack_server.service_layer.deployment_manager import (
5456
IProviderDeploymentManager,
5557
)
@@ -116,6 +118,10 @@ async def _fn(*args, **kwargs):
116118
raise ServerError(error=InvalidRequestError(message=str(e))) from e
117119
except A2AClientJSONRPCError as e:
118120
raise ServerError(error=e.error) from e
121+
except InvalidProviderCallError as e:
122+
raise ServerError(error=InvalidRequestError(message=f"Invalid request to agent: {e!r}")) from e
123+
except Exception as e:
124+
raise ServerError(error=InternalError(message=f"Internal error: {e!r}")) from e
119125

120126
@functools.wraps(fn)
121127
async def _fn_iter(*args, **kwargs):
@@ -131,18 +137,28 @@ async def _fn_iter(*args, **kwargs):
131137
class ProxyRequestHandler(RequestHandler):
132138
def __init__(
133139
self,
134-
agent_card: AgentCard,
140+
*,
135141
provider_id: UUID,
136142
uow: IUnitOfWorkFactory,
137143
user: User,
144+
# Calling the factory have side-effects, such as rotating the agent
145+
agent_card_factory: Callable[[], Awaitable[AgentCard]] | None = None,
146+
agent_card: AgentCard | None = None,
138147
):
148+
if agent_card_factory is None and agent_card is None:
149+
raise ValueError("One of agent_card_factory or agent_card must be provided")
150+
self._agent_card_factory = agent_card_factory
139151
self._agent_card = agent_card
140152
self._provider_id = provider_id
141153
self._user = user
142154
self._uow = uow
143155

144156
@asynccontextmanager
145157
async def _client_transport(self) -> AsyncIterator[ClientTransport]:
158+
if self._agent_card is None:
159+
assert self._agent_card_factory is not None
160+
self._agent_card = await self._agent_card_factory()
161+
146162
async with httpx.AsyncClient(follow_redirects=True, timeout=timedelta(hours=1).total_seconds()) as httpx_client:
147163
client: BaseClient = cast(
148164
BaseClient,
@@ -306,10 +322,13 @@ def __init__(
306322
self._expire_requests_after = timedelta(days=configuration.a2a_proxy.requests_expire_after_days)
307323

308324
async def get_request_handler(self, *, provider: Provider, user: User) -> RequestHandler:
309-
url = await self.ensure_agent(provider_id=provider.id)
310-
agent_card = create_deployment_agent_card(provider.agent_card, deployment_base=str(url))
325+
async def agent_card_factory() -> AgentCard:
326+
# Delay ensure_agent to the handler so that errors are wrapped properly
327+
url = await self.ensure_agent(provider_id=provider.id)
328+
return create_deployment_agent_card(provider.agent_card, deployment_base=str(url))
329+
311330
return ProxyRequestHandler(
312-
agent_card=agent_card,
331+
agent_card_factory=agent_card_factory,
313332
provider_id=provider.id,
314333
uow=self._uow,
315334
user=user,
@@ -332,6 +351,11 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
332351
await uow.commit()
333352

334353
if not provider.managed:
354+
if provider.unmanaged_state is UnmanagedState.OFFLINE:
355+
raise InvalidProviderCallError(
356+
f"Cannot send message to provider {provider_id}: provider is offline"
357+
)
358+
335359
assert isinstance(provider.source, NetworkProviderLocation)
336360
return provider.source.a2a_url
337361

@@ -340,7 +364,9 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
340364
should_wait = False
341365
match state:
342366
case ProviderDeploymentState.ERROR:
343-
raise RuntimeError("Provider is in an error state")
367+
raise InvalidProviderCallError(
368+
f"Cannot send message to provider {provider_id}: provider is in an error state"
369+
)
344370
case (
345371
ProviderDeploymentState.MISSING
346372
| ProviderDeploymentState.RUNNING

apps/agentstack-server/template.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ OCI_REGISTRY__AGENTSTACK-REGISTRY-SVC.DEFAULT:5001__INSECURE=true
3636
# AUTH__BASIC__ENABLED=true
3737
# AUTH__BASIC__ADMIN_PASSWORD=test-password
3838
# AUTH__BASIC__ADMIN_PASSWORD=test-password
39-
# AUTH__BASIC__JWT_SECRET_KEY=test-key
39+
# AUTH__JWT_SECRET_KEY=test-key
4040

4141
# AUTH__OIDC__ENABLED=true
4242
# AUTH__OIDC__ADMIN_EMAILS='["abc@ibm.com"]'

apps/agentstack-server/tests/e2e/routes/test_a2a_proxy.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,3 +1195,36 @@ async def test_task_and_context_both_specified_single_query(client: Client, hand
11951195
{"context_id": "dual-context-456"},
11961196
)
11971197
assert context_result.fetchone() is not None
1198+
1199+
1200+
async def test_invalid_request_raises_a2a_error(client: Client, handler: mock.AsyncMock, db_transaction):
1201+
"""Test that an invalid request to an offline provider returns an A2A error."""
1202+
1203+
# set provider as offline
1204+
provider_id = str(client.base_url).rstrip("/").split("/")[-1]
1205+
await db_transaction.execute(
1206+
text("UPDATE providers SET unmanaged_state = 'offline' WHERE id = :provider_id"),
1207+
{"provider_id": provider_id},
1208+
)
1209+
await db_transaction.commit()
1210+
1211+
message_data = {
1212+
"jsonrpc": "2.0",
1213+
"id": "123",
1214+
"method": "message/send",
1215+
"params": {
1216+
"message": {
1217+
"role": "agent",
1218+
"parts": [{"kind": "text", "text": "Hello"}],
1219+
"messageId": "111",
1220+
"kind": "message",
1221+
}
1222+
},
1223+
}
1224+
response = client.post("/", json=message_data)
1225+
assert response.status_code == 200
1226+
data = response.json()
1227+
assert data["id"] == "123"
1228+
assert "error" in data
1229+
assert data["error"]["code"] == InvalidRequestError().code
1230+
assert "provider is offline" in data["error"]["message"]

0 commit comments

Comments
 (0)