55import inspect
66import logging
77import uuid
8- from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Callable
9- from contextlib import asynccontextmanager
8+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Awaitable , Callable , Iterator
9+ from contextlib import asynccontextmanager , contextmanager
1010from datetime import timedelta
1111from typing import NamedTuple , cast
1212from urllib .parse import urljoin , urlparse
2424 AgentCard ,
2525 DeleteTaskPushNotificationConfigParams ,
2626 GetTaskPushNotificationConfigParams ,
27+ InternalError ,
2728 InvalidRequestError ,
2829 ListTaskPushNotificationConfigParams ,
2930 Message ,
4748 NetworkProviderLocation ,
4849 Provider ,
4950 ProviderDeploymentState ,
51+ UnmanagedState ,
5052)
5153from agentstack_server .domain .models .user import User
52- from agentstack_server .exceptions import EntityNotFoundError , ForbiddenUpdateError
54+ from agentstack_server .exceptions import EntityNotFoundError , ForbiddenUpdateError , InvalidProviderCallError
5355from agentstack_server .service_layer .deployment_manager import (
5456 IProviderDeploymentManager ,
5557)
@@ -104,10 +106,10 @@ class A2AServerResponse(NamedTuple):
104106
105107
106108def _handle_exception [T : Callable ](fn : T ) -> T :
107- @functools . wraps ( fn )
108- async def _fn ( * args , ** kwargs ) :
109+ @contextmanager
110+ def _handle_exception_impl () -> Iterator [ None ] :
109111 try :
110- return await fn ( * args , ** kwargs )
112+ yield
111113 except EntityNotFoundError as e :
112114 if "task" in e .entity :
113115 raise ServerError (error = TaskNotFoundError ()) from e
@@ -116,33 +118,50 @@ async def _fn(*args, **kwargs):
116118 raise ServerError (error = InvalidRequestError (message = str (e ))) from e
117119 except A2AClientJSONRPCError as e :
118120 raise ServerError (error = e .error ) from e
121+ except InvalidProviderCallError as e :
122+ raise ServerError (error = InvalidRequestError (message = f"Invalid request to agent: { e !r} " )) from e
123+ except Exception as e :
124+ raise ServerError (error = InternalError (message = f"Internal error: { e !r} " )) from e
125+
126+ @functools .wraps (fn )
127+ async def _fn (* args , ** kwargs ):
128+ with _handle_exception_impl ():
129+ return await fn (* args , ** kwargs )
119130
120131 @functools .wraps (fn )
121132 async def _fn_iter (* args , ** kwargs ):
122- try :
133+ with _handle_exception_impl () :
123134 async for item in fn (* args , ** kwargs ):
124135 yield item
125- except A2AClientJSONRPCError as e :
126- raise ServerError (error = e .error ) from e
127136
128137 return _fn_iter if inspect .isasyncgenfunction (fn ) else _fn # pyright: ignore [reportReturnType]
129138
130139
131140class ProxyRequestHandler (RequestHandler ):
132141 def __init__ (
133142 self ,
134- agent_card : AgentCard ,
143+ * ,
135144 provider_id : UUID ,
136145 uow : IUnitOfWorkFactory ,
137146 user : User ,
147+ # Calling the factory have side-effects, such as rotating the agent
148+ agent_card_factory : Callable [[], Awaitable [AgentCard ]] | None = None ,
149+ agent_card : AgentCard | None = None ,
138150 ):
151+ if agent_card_factory is None and agent_card is None :
152+ raise ValueError ("One of agent_card_factory or agent_card must be provided" )
153+ self ._agent_card_factory = agent_card_factory
139154 self ._agent_card = agent_card
140155 self ._provider_id = provider_id
141156 self ._user = user
142157 self ._uow = uow
143158
144159 @asynccontextmanager
145160 async def _client_transport (self ) -> AsyncIterator [ClientTransport ]:
161+ if self ._agent_card is None :
162+ assert self ._agent_card_factory is not None
163+ self ._agent_card = await self ._agent_card_factory ()
164+
146165 async with httpx .AsyncClient (follow_redirects = True , timeout = timedelta (hours = 1 ).total_seconds ()) as httpx_client :
147166 client : BaseClient = cast (
148167 BaseClient ,
@@ -306,10 +325,13 @@ def __init__(
306325 self ._expire_requests_after = timedelta (days = configuration .a2a_proxy .requests_expire_after_days )
307326
308327 async def get_request_handler (self , * , provider : Provider , user : User ) -> RequestHandler :
309- url = await self .ensure_agent (provider_id = provider .id )
310- agent_card = create_deployment_agent_card (provider .agent_card , deployment_base = str (url ))
328+ async def agent_card_factory () -> AgentCard :
329+ # Delay ensure_agent to the handler so that errors are wrapped properly
330+ url = await self .ensure_agent (provider_id = provider .id )
331+ return create_deployment_agent_card (provider .agent_card , deployment_base = str (url ))
332+
311333 return ProxyRequestHandler (
312- agent_card = agent_card ,
334+ agent_card_factory = agent_card_factory ,
313335 provider_id = provider .id ,
314336 uow = self ._uow ,
315337 user = user ,
@@ -332,6 +354,11 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
332354 await uow .commit ()
333355
334356 if not provider .managed :
357+ if provider .unmanaged_state is UnmanagedState .OFFLINE :
358+ raise InvalidProviderCallError (
359+ f"Cannot send message to provider { provider_id } : provider is offline"
360+ )
361+
335362 assert isinstance (provider .source , NetworkProviderLocation )
336363 return provider .source .a2a_url
337364
@@ -340,7 +367,9 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
340367 should_wait = False
341368 match state :
342369 case ProviderDeploymentState .ERROR :
343- raise RuntimeError ("Provider is in an error state" )
370+ raise InvalidProviderCallError (
371+ f"Cannot send message to provider { provider_id } : provider is in an error state"
372+ )
344373 case (
345374 ProviderDeploymentState .MISSING
346375 | ProviderDeploymentState .RUNNING
0 commit comments