Skip to content

Commit 53aceea

Browse files
Merge branch 'main' into fix/streaming-endpoint-deadlock
2 parents ce0c04f + d2e869f commit 53aceea

File tree

16 files changed

+213
-93
lines changed

16 files changed

+213
-93
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__
66
.pytest_cache
77
.ruff_cache
88
.venv
9+
test_venv/
910
coverage.xml
1011
.nox
11-
spec.json
12+
spec.json

.ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ ignore = [
3030
"ANN003",
3131
"ANN401",
3232
"TRY003",
33-
"G004",
3433
"TRY201",
3534
"FIX002",
3635
]

src/a2a/client/auth/interceptor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ async def intercept(
6262
):
6363
headers['Authorization'] = f'Bearer {credential}'
6464
logger.debug(
65-
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
65+
"Added Bearer token for scheme '%s' (type: %s).",
66+
scheme_name,
67+
scheme_def.type,
6668
)
6769
http_kwargs['headers'] = headers
6870
return request_payload, http_kwargs
@@ -74,7 +76,9 @@ async def intercept(
7476
):
7577
headers['Authorization'] = f'Bearer {credential}'
7678
logger.debug(
77-
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
79+
"Added Bearer token for scheme '%s' (type: %s).",
80+
scheme_name,
81+
scheme_def.type,
7882
)
7983
http_kwargs['headers'] = headers
8084
return request_payload, http_kwargs
@@ -83,7 +87,8 @@ async def intercept(
8387
case APIKeySecurityScheme(in_=In.header):
8488
headers[scheme_def.name] = credential
8589
logger.debug(
86-
f"Added API Key Header for scheme '{scheme_name}'."
90+
"Added API Key Header for scheme '%s'.",
91+
scheme_name,
8792
)
8893
http_kwargs['headers'] = headers
8994
return request_payload, http_kwargs

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
GetTaskPushNotificationConfigRequest,
2929
GetTaskRequest,
3030
InternalError,
31+
InvalidParamsError,
3132
InvalidRequestError,
3233
JSONParseError,
3334
JSONRPCError,
3435
JSONRPCErrorResponse,
3536
JSONRPCRequest,
3637
JSONRPCResponse,
3738
ListTaskPushNotificationConfigRequest,
39+
MethodNotFoundError,
3840
SendMessageRequest,
3941
SendStreamingMessageRequest,
4042
SendStreamingMessageResponse,
@@ -89,6 +91,8 @@
8991
Response = Any
9092
HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any
9193

94+
MAX_CONTENT_LENGTH = 1_000_000
95+
9296

9397
class StarletteUserProxy(A2AUser):
9498
"""Adapts the Starlette User class to the A2A user representation."""
@@ -151,6 +155,25 @@ class JSONRPCApplication(ABC):
151155
(SSE).
152156
"""
153157

158+
# Method-to-model mapping for centralized routing
159+
A2ARequestModel = (
160+
SendMessageRequest
161+
| SendStreamingMessageRequest
162+
| GetTaskRequest
163+
| CancelTaskRequest
164+
| SetTaskPushNotificationConfigRequest
165+
| GetTaskPushNotificationConfigRequest
166+
| ListTaskPushNotificationConfigRequest
167+
| DeleteTaskPushNotificationConfigRequest
168+
| TaskResubscriptionRequest
169+
| GetAuthenticatedExtendedCardRequest
170+
)
171+
172+
METHOD_TO_MODEL: dict[str, type[A2ARequestModel]] = {
173+
model.model_fields['method'].default: model
174+
for model in A2ARequestModel.__args__
175+
}
176+
154177
def __init__( # noqa: PLR0913
155178
self,
156179
agent_card: AgentCard,
@@ -233,9 +256,13 @@ def _generate_error_response(
233256
)
234257
logger.log(
235258
log_level,
236-
f'Request Error (ID: {request_id}): '
237-
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
238-
f'{", Data=" + str(error_resp.error.data) if error_resp.error.data else ""}',
259+
"Request Error (ID: %s): Code=%s, Message='%s'%s",
260+
request_id,
261+
error_resp.error.code,
262+
error_resp.error.message,
263+
', Data=' + str(error_resp.error.data)
264+
if error_resp.error.data
265+
else '',
239266
)
240267
return JSONResponse(
241268
error_resp.model_dump(mode='json', exclude_none=True),
@@ -267,17 +294,60 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
267294
body = await request.json()
268295
if isinstance(body, dict):
269296
request_id = body.get('id')
297+
# Ensure request_id is valid for JSON-RPC response (str/int/None only)
298+
if request_id is not None and not isinstance(
299+
request_id, str | int
300+
):
301+
request_id = None
302+
# Treat very large payloads as invalid request (-32600) before routing
303+
with contextlib.suppress(Exception):
304+
content_length = int(request.headers.get('content-length', '0'))
305+
if content_length and content_length > MAX_CONTENT_LENGTH:
306+
return self._generate_error_response(
307+
request_id,
308+
A2AError(
309+
root=InvalidRequestError(
310+
message='Payload too large'
311+
)
312+
),
313+
)
314+
logger.debug('Request body: %s', body)
315+
# 1) Validate base JSON-RPC structure only (-32600 on failure)
316+
try:
317+
base_request = JSONRPCRequest.model_validate(body)
318+
except ValidationError as e:
319+
logger.exception('Failed to validate base JSON-RPC request')
320+
return self._generate_error_response(
321+
request_id,
322+
A2AError(
323+
root=InvalidRequestError(data=json.loads(e.json()))
324+
),
325+
)
270326

271-
# First, validate the basic JSON-RPC structure. This is crucial
272-
# because the A2ARequest model is a discriminated union where some
273-
# request types have default values for the 'method' field
274-
JSONRPCRequest.model_validate(body)
327+
# 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
328+
method = base_request.method
275329

276-
a2a_request = A2ARequest.model_validate(body)
330+
model_class = self.METHOD_TO_MODEL.get(method)
331+
if not model_class:
332+
return self._generate_error_response(
333+
request_id, A2AError(root=MethodNotFoundError())
334+
)
335+
try:
336+
specific_request = model_class.model_validate(body)
337+
except ValidationError as e:
338+
logger.exception('Failed to validate base JSON-RPC request')
339+
return self._generate_error_response(
340+
request_id,
341+
A2AError(
342+
root=InvalidParamsError(data=json.loads(e.json()))
343+
),
344+
)
277345

346+
# 3) Build call context and wrap the request for downstream handling
278347
call_context = self._context_builder.build(request)
279348

280-
request_id = a2a_request.root.id
349+
request_id = specific_request.id
350+
a2a_request = A2ARequest(root=specific_request)
281351
request_obj = a2a_request.root
282352

283353
if isinstance(
@@ -301,12 +371,6 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
301371
return self._generate_error_response(
302372
None, A2AError(root=JSONParseError(message=str(e)))
303373
)
304-
except ValidationError as e:
305-
traceback.print_exc()
306-
return self._generate_error_response(
307-
request_id,
308-
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
309-
)
310374
except HTTPException as e:
311375
if e.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE:
312376
return self._generate_error_response(
@@ -422,7 +486,7 @@ async def _process_non_streaming_request(
422486
)
423487
case _:
424488
logger.error(
425-
f'Unhandled validated request type: {type(request_obj)}'
489+
'Unhandled validated request type: %s', type(request_obj)
426490
)
427491
error = UnsupportedOperationError(
428492
message=f'Request type {type(request_obj).__name__} is unknown.'
@@ -497,8 +561,10 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
497561
"""
498562
if request.url.path == PREV_AGENT_CARD_WELL_KNOWN_PATH:
499563
logger.warning(
500-
f"Deprecated agent card endpoint '{PREV_AGENT_CARD_WELL_KNOWN_PATH}' accessed. "
501-
f"Please use '{AGENT_CARD_WELL_KNOWN_PATH}' instead. This endpoint will be removed in a future version."
564+
"Deprecated agent card endpoint '%s' accessed. "
565+
"Please use '%s' instead. This endpoint will be removed in a future version.",
566+
PREV_AGENT_CARD_WELL_KNOWN_PATH,
567+
AGENT_CARD_WELL_KNOWN_PATH,
502568
)
503569

504570
card_to_serve = self.agent_card

src/a2a/server/events/event_consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def consume_one(self) -> Event:
6262
InternalError(message='Agent did not return any response')
6363
) from e
6464

65-
logger.debug(f'Dequeued event of type: {type(event)} in consume_one.')
65+
logger.debug('Dequeued event of type: %s in consume_one.', type(event))
6666

6767
self.queue.task_done()
6868

@@ -95,7 +95,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
9595
self.queue.dequeue_event(), timeout=self._timeout
9696
)
9797
logger.debug(
98-
f'Dequeued event of type: {type(event)} in consume_all.'
98+
'Dequeued event of type: %s in consume_all.', type(event)
9999
)
100100
self.queue.task_done()
101101
logger.debug(

src/a2a/server/events/event_queue.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def enqueue_event(self, event: Event) -> None:
5454
logger.warning('Queue is closed. Event will not be enqueued.')
5555
return
5656

57-
logger.debug(f'Enqueuing event of type: {type(event)}')
57+
logger.debug('Enqueuing event of type: %s', type(event))
5858

5959
# Make sure to use put instead of put_nowait to avoid blocking the event loop.
6060
await self.queue.put(event)
@@ -103,13 +103,13 @@ async def dequeue_event(self, no_wait: bool = False) -> Event:
103103
logger.debug('Attempting to dequeue event (no_wait=True).')
104104
event = self.queue.get_nowait()
105105
logger.debug(
106-
f'Dequeued event (no_wait=True) of type: {type(event)}'
106+
'Dequeued event (no_wait=True) of type: %s', type(event)
107107
)
108108
return event
109109

110110
logger.debug('Attempting to dequeue event (waiting).')
111111
event = await self.queue.get()
112-
logger.debug(f'Dequeued event (waited) of type: {type(event)}')
112+
logger.debug('Dequeued event (waited) of type: %s', type(event))
113113
return event
114114

115115
def task_done(self) -> None:
@@ -193,7 +193,9 @@ async def clear_events(self, clear_child_queues: bool = True) -> None:
193193
while True:
194194
event = self.queue.get_nowait()
195195
logger.debug(
196-
f'Discarding unprocessed event of type: {type(event)}, content: {event}'
196+
'Discarding unprocessed event of type: %s, content: %s',
197+
type(event),
198+
event,
197199
)
198200
self.queue.task_done()
199201
cleared_count += 1
@@ -211,7 +213,8 @@ async def clear_events(self, clear_child_queues: bool = True) -> None:
211213

212214
if cleared_count > 0:
213215
logger.debug(
214-
f'Cleared {cleared_count} unprocessed events from EventQueue.'
216+
'Cleared %d unprocessed events from EventQueue.',
217+
cleared_count,
215218
)
216219

217220
# Clear all child queues (lock released before awaiting child tasks)

src/a2a/server/models.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,9 @@ def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
147147
@override
148148
def __repr__(self) -> str:
149149
"""Return a string representation of the task."""
150-
repr_template = (
151-
'<{CLS}(id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
152-
)
153-
return repr_template.format(
154-
CLS=self.__class__.__name__,
155-
ID=self.id,
156-
CTX_ID=self.context_id,
157-
STATUS=self.status,
150+
return (
151+
f'<{self.__class__.__name__}(id="{self.id}", '
152+
f'context_id="{self.context_id}", status="{self.status}")>'
158153
)
159154

160155

@@ -188,12 +183,9 @@ class TaskModel(TaskMixin, base): # type: ignore
188183
@override
189184
def __repr__(self) -> str:
190185
"""Return a string representation of the task."""
191-
repr_template = '<TaskModel[{TABLE}](id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
192-
return repr_template.format(
193-
TABLE=table_name,
194-
ID=self.id,
195-
CTX_ID=self.context_id,
196-
STATUS=self.status,
186+
return (
187+
f'<TaskModel[{table_name}](id="{self.id}", '
188+
f'context_id="{self.context_id}", status="{self.status}")>'
197189
)
198190

199191
# Set a dynamic name for better debugging
@@ -221,11 +213,9 @@ class PushNotificationConfigMixin:
221213
@override
222214
def __repr__(self) -> str:
223215
"""Return a string representation of the push notification config."""
224-
repr_template = '<{CLS}(task_id="{TID}", config_id="{CID}")>'
225-
return repr_template.format(
226-
CLS=self.__class__.__name__,
227-
TID=self.task_id,
228-
CID=self.config_id,
216+
return (
217+
f'<{self.__class__.__name__}(task_id="{self.task_id}", '
218+
f'config_id="{self.config_id}")>'
229219
)
230220

231221

@@ -241,11 +231,9 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, base): # type: i
241231
@override
242232
def __repr__(self) -> str:
243233
"""Return a string representation of the push notification config."""
244-
repr_template = '<PushNotificationConfigModel[{TABLE}](task_id="{TID}", config_id="{CID}")>'
245-
return repr_template.format(
246-
TABLE=table_name,
247-
TID=self.task_id,
248-
CID=self.config_id,
234+
return (
235+
f'<PushNotificationConfigModel[{table_name}]('
236+
f'task_id="{self.task_id}", config_id="{self.config_id}")>'
249237
)
250238

251239
PushNotificationConfigModel.__name__ = (

0 commit comments

Comments
 (0)