Skip to content

Commit 1843f8a

Browse files
authored
Merge branch 'main' into vladkol/dynamic-agent-url
2 parents 9d63223 + d2e869f commit 1843f8a

File tree

3 files changed

+107
-16
lines changed

3 files changed

+107
-16
lines changed

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
GetTaskPushNotificationConfigRequest,
3030
GetTaskRequest,
3131
InternalError,
32+
InvalidParamsError,
3233
InvalidRequestError,
3334
JSONParseError,
3435
JSONRPCError,
3536
JSONRPCErrorResponse,
3637
JSONRPCRequest,
3738
JSONRPCResponse,
3839
ListTaskPushNotificationConfigRequest,
40+
MethodNotFoundError,
3941
SendMessageRequest,
4042
SendStreamingMessageRequest,
4143
SendStreamingMessageResponse,
@@ -91,6 +93,8 @@
9193
URL = Any
9294
HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any
9395

96+
MAX_CONTENT_LENGTH = 1_000_000
97+
9498

9599
class StarletteUserProxy(A2AUser):
96100
"""Adapts the Starlette User class to the A2A user representation."""
@@ -153,6 +157,25 @@ class JSONRPCApplication(ABC):
153157
(SSE).
154158
"""
155159

160+
# Method-to-model mapping for centralized routing
161+
A2ARequestModel = (
162+
SendMessageRequest
163+
| SendStreamingMessageRequest
164+
| GetTaskRequest
165+
| CancelTaskRequest
166+
| SetTaskPushNotificationConfigRequest
167+
| GetTaskPushNotificationConfigRequest
168+
| ListTaskPushNotificationConfigRequest
169+
| DeleteTaskPushNotificationConfigRequest
170+
| TaskResubscriptionRequest
171+
| GetAuthenticatedExtendedCardRequest
172+
)
173+
174+
METHOD_TO_MODEL: dict[str, type[A2ARequestModel]] = {
175+
model.model_fields['method'].default: model
176+
for model in A2ARequestModel.__args__
177+
}
178+
156179
def __init__( # noqa: PLR0913
157180
self,
158181
agent_card: AgentCard,
@@ -273,17 +296,60 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
273296
body = await request.json()
274297
if isinstance(body, dict):
275298
request_id = body.get('id')
299+
# Ensure request_id is valid for JSON-RPC response (str/int/None only)
300+
if request_id is not None and not isinstance(
301+
request_id, str | int
302+
):
303+
request_id = None
304+
# Treat very large payloads as invalid request (-32600) before routing
305+
with contextlib.suppress(Exception):
306+
content_length = int(request.headers.get('content-length', '0'))
307+
if content_length and content_length > MAX_CONTENT_LENGTH:
308+
return self._generate_error_response(
309+
request_id,
310+
A2AError(
311+
root=InvalidRequestError(
312+
message='Payload too large'
313+
)
314+
),
315+
)
316+
logger.debug('Request body: %s', body)
317+
# 1) Validate base JSON-RPC structure only (-32600 on failure)
318+
try:
319+
base_request = JSONRPCRequest.model_validate(body)
320+
except ValidationError as e:
321+
logger.exception('Failed to validate base JSON-RPC request')
322+
return self._generate_error_response(
323+
request_id,
324+
A2AError(
325+
root=InvalidRequestError(data=json.loads(e.json()))
326+
),
327+
)
276328

277-
# First, validate the basic JSON-RPC structure. This is crucial
278-
# because the A2ARequest model is a discriminated union where some
279-
# request types have default values for the 'method' field
280-
JSONRPCRequest.model_validate(body)
329+
# 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
330+
method = base_request.method
281331

282-
a2a_request = A2ARequest.model_validate(body)
332+
model_class = self.METHOD_TO_MODEL.get(method)
333+
if not model_class:
334+
return self._generate_error_response(
335+
request_id, A2AError(root=MethodNotFoundError())
336+
)
337+
try:
338+
specific_request = model_class.model_validate(body)
339+
except ValidationError as e:
340+
logger.exception('Failed to validate base JSON-RPC request')
341+
return self._generate_error_response(
342+
request_id,
343+
A2AError(
344+
root=InvalidParamsError(data=json.loads(e.json()))
345+
),
346+
)
283347

348+
# 3) Build call context and wrap the request for downstream handling
284349
call_context = self._context_builder.build(request)
285350

286-
request_id = a2a_request.root.id
351+
request_id = specific_request.id
352+
a2a_request = A2ARequest(root=specific_request)
287353
request_obj = a2a_request.root
288354

289355
if isinstance(
@@ -307,12 +373,6 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
307373
return self._generate_error_response(
308374
None, A2AError(root=JSONParseError(message=str(e)))
309375
)
310-
except ValidationError as e:
311-
traceback.print_exc()
312-
return self._generate_error_response(
313-
request_id,
314-
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
315-
)
316376
except HTTPException as e:
317377
if e.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE:
318378
return self._generate_error_response(

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
MessageSendParams,
3737
Task,
3838
TaskIdParams,
39+
TaskNotCancelableError,
3940
TaskNotFoundError,
4041
TaskPushNotificationConfig,
4142
TaskQueryParams,
@@ -111,6 +112,26 @@ async def on_get_task(
111112
task: Task | None = await self.task_store.get(params.id)
112113
if not task:
113114
raise ServerError(error=TaskNotFoundError())
115+
116+
# Apply historyLength parameter if specified
117+
if params.history_length is not None and task.history:
118+
# Limit history to the most recent N messages
119+
limited_history = (
120+
task.history[-params.history_length :]
121+
if params.history_length > 0
122+
else []
123+
)
124+
# Create a new task instance with limited history
125+
task = Task(
126+
id=task.id,
127+
context_id=task.context_id,
128+
status=task.status,
129+
artifacts=task.artifacts,
130+
history=limited_history,
131+
metadata=task.metadata,
132+
kind=task.kind,
133+
)
134+
114135
return task
115136

116137
async def on_cancel_task(
@@ -124,6 +145,14 @@ async def on_cancel_task(
124145
if not task:
125146
raise ServerError(error=TaskNotFoundError())
126147

148+
# Check if task is in a non-cancelable state (completed, canceled, failed, rejected)
149+
if task.status.state in TERMINAL_TASK_STATES:
150+
raise ServerError(
151+
error=TaskNotCancelableError(
152+
message=f'Task cannot be canceled - current state: {task.status.state}'
153+
)
154+
)
155+
127156
task_manager = TaskManager(
128157
task_id=task.id,
129158
context_id=task.context_id,

tests/server/test_integration.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@
2929
Artifact,
3030
DataPart,
3131
InternalError,
32+
InvalidParamsError,
3233
InvalidRequestError,
3334
JSONParseError,
3435
Message,
36+
MethodNotFoundError,
3537
Part,
3638
PushNotificationConfig,
3739
Role,
@@ -837,7 +839,7 @@ def test_invalid_request_structure(client: TestClient):
837839
response = client.post(
838840
'/',
839841
json={
840-
# Missing required fields
842+
'jsonrpc': 'aaaa', # Missing or wrong required fields
841843
'id': '123',
842844
'method': 'foo/bar',
843845
},
@@ -976,7 +978,7 @@ def test_unknown_method(client: TestClient):
976978
data = response.json()
977979
assert 'error' in data
978980
# This should produce an UnsupportedOperationError error code
979-
assert data['error']['code'] == InvalidRequestError().code
981+
assert data['error']['code'] == MethodNotFoundError().code
980982

981983

982984
def test_validation_error(client: TestClient):
@@ -987,7 +989,7 @@ def test_validation_error(client: TestClient):
987989
json={
988990
'jsonrpc': '2.0',
989991
'id': '123',
990-
'method': 'messages/send',
992+
'method': 'message/send',
991993
'params': {
992994
'message': {
993995
# Missing required fields
@@ -999,7 +1001,7 @@ def test_validation_error(client: TestClient):
9991001
assert response.status_code == 200
10001002
data = response.json()
10011003
assert 'error' in data
1002-
assert data['error']['code'] == InvalidRequestError().code
1004+
assert data['error']['code'] == InvalidParamsError().code
10031005

10041006

10051007
def test_unhandled_exception(client: TestClient, handler: mock.AsyncMock):

0 commit comments

Comments
 (0)