55import inspect
66import logging
77import uuid
8- from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Callable
8+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Awaitable , Callable
99from contextlib import asynccontextmanager
1010from datetime import timedelta
1111from typing import NamedTuple , cast
2424 AgentCard ,
2525 DeleteTaskPushNotificationConfigParams ,
2626 GetTaskPushNotificationConfigParams ,
27+ InternalError ,
2728 InvalidRequestError ,
2829 ListTaskPushNotificationConfigParams ,
2930 Message ,
4748 NetworkProviderLocation ,
4849 Provider ,
4950 ProviderDeploymentState ,
51+ UnmanagedState ,
5052)
5153from agentstack_server .domain .models .user import User
52- from agentstack_server .exceptions import EntityNotFoundError , ForbiddenUpdateError
54+ from agentstack_server .exceptions import EntityNotFoundError , ForbiddenUpdateError , InvalidProviderCallError
5355from agentstack_server .service_layer .deployment_manager import (
5456 IProviderDeploymentManager ,
5557)
@@ -116,6 +118,10 @@ async def _fn(*args, **kwargs):
116118 raise ServerError (error = InvalidRequestError (message = str (e ))) from e
117119 except A2AClientJSONRPCError as e :
118120 raise ServerError (error = e .error ) from e
121+ except InvalidProviderCallError as e :
122+ raise ServerError (error = InvalidRequestError (message = f"Invalid request to agent: { e !r} " )) from e
123+ except Exception as e :
124+ raise ServerError (error = InternalError (message = f"Internal error: { e !r} " )) from e
119125
120126 @functools .wraps (fn )
121127 async def _fn_iter (* args , ** kwargs ):
@@ -131,18 +137,28 @@ async def _fn_iter(*args, **kwargs):
131137class ProxyRequestHandler (RequestHandler ):
132138 def __init__ (
133139 self ,
134- agent_card : AgentCard ,
140+ * ,
135141 provider_id : UUID ,
136142 uow : IUnitOfWorkFactory ,
137143 user : User ,
144+ # Calling the factory have side-effects, such as rotating the agent
145+ agent_card_factory : Callable [[], Awaitable [AgentCard ]] | None = None ,
146+ agent_card : AgentCard | None = None ,
138147 ):
148+ if agent_card_factory is None and agent_card is None :
149+ raise ValueError ("One of agent_card_factory or agent_card must be provided" )
150+ self ._agent_card_factory = agent_card_factory
139151 self ._agent_card = agent_card
140152 self ._provider_id = provider_id
141153 self ._user = user
142154 self ._uow = uow
143155
144156 @asynccontextmanager
145157 async def _client_transport (self ) -> AsyncIterator [ClientTransport ]:
158+ if self ._agent_card is None :
159+ assert self ._agent_card_factory is not None
160+ self ._agent_card = await self ._agent_card_factory ()
161+
146162 async with httpx .AsyncClient (follow_redirects = True , timeout = timedelta (hours = 1 ).total_seconds ()) as httpx_client :
147163 client : BaseClient = cast (
148164 BaseClient ,
@@ -306,10 +322,13 @@ def __init__(
306322 self ._expire_requests_after = timedelta (days = configuration .a2a_proxy .requests_expire_after_days )
307323
308324 async def get_request_handler (self , * , provider : Provider , user : User ) -> RequestHandler :
309- url = await self .ensure_agent (provider_id = provider .id )
310- agent_card = create_deployment_agent_card (provider .agent_card , deployment_base = str (url ))
325+ async def agent_card_factory () -> AgentCard :
326+ # Delay ensure_agent to the handler so that errors are wrapped properly
327+ url = await self .ensure_agent (provider_id = provider .id )
328+ return create_deployment_agent_card (provider .agent_card , deployment_base = str (url ))
329+
311330 return ProxyRequestHandler (
312- agent_card = agent_card ,
331+ agent_card_factory = agent_card_factory ,
313332 provider_id = provider .id ,
314333 uow = self ._uow ,
315334 user = user ,
@@ -332,6 +351,11 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
332351 await uow .commit ()
333352
334353 if not provider .managed :
354+ if provider .unmanaged_state is UnmanagedState .OFFLINE :
355+ raise InvalidProviderCallError (
356+ f"Cannot send message to provider { provider_id } : provider is offline"
357+ )
358+
335359 assert isinstance (provider .source , NetworkProviderLocation )
336360 return provider .source .a2a_url
337361
@@ -340,7 +364,9 @@ async def ensure_agent(self, *, provider_id: UUID) -> HttpUrl:
340364 should_wait = False
341365 match state :
342366 case ProviderDeploymentState .ERROR :
343- raise RuntimeError ("Provider is in an error state" )
367+ raise InvalidProviderCallError (
368+ f"Cannot send message to provider { provider_id } : provider is in an error state"
369+ )
344370 case (
345371 ProviderDeploymentState .MISSING
346372 | ProviderDeploymentState .RUNNING
0 commit comments