Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 99 additions & 12 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
InternalError,
InvalidParamsError,
InvalidRequestError,
JSONParseError,
JSONRPCError,
JSONRPCErrorResponse,
JSONRPCRequest,
JSONRPCResponse,
ListTaskPushNotificationConfigRequest,
MethodNotFoundError,
SendMessageRequest,
SendStreamingMessageRequest,
SendStreamingMessageResponse,
Expand Down Expand Up @@ -89,6 +91,8 @@
Response = Any
HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any

MAX_CONTENT_LENGTH = 1_000_000


class StarletteUserProxy(A2AUser):
"""Adapts the Starlette User class to the A2A user representation."""
Expand Down Expand Up @@ -151,35 +155,81 @@
(SSE).
"""

# Method-to-model mapping for centralized routing
# Define the union type for all supported request models
A2ARequestModel = (
SendMessageRequest
| SendStreamingMessageRequest
| GetTaskRequest
| CancelTaskRequest
| SetTaskPushNotificationConfigRequest
| GetTaskPushNotificationConfigRequest
| ListTaskPushNotificationConfigRequest
| DeleteTaskPushNotificationConfigRequest
| TaskResubscriptionRequest
| GetAuthenticatedExtendedCardRequest
)

# Pydantic model fields like 'method' are instance attributes, not class attributes.
# So, 'Type.method' does not exist until you instantiate the model.
# To get the default value for the 'method' field at the class level,
# you must use Type.model_fields["method"].default.
METHOD_TO_MODEL: dict[str, type[A2ARequestModel]] = {
SendMessageRequest.model_fields['method'].default: SendMessageRequest,
SendStreamingMessageRequest.model_fields[
'method'
].default: SendStreamingMessageRequest,
GetTaskRequest.model_fields['method'].default: GetTaskRequest,
CancelTaskRequest.model_fields['method'].default: CancelTaskRequest,
SetTaskPushNotificationConfigRequest.model_fields[
'method'
].default: SetTaskPushNotificationConfigRequest,
GetTaskPushNotificationConfigRequest.model_fields[
'method'
].default: GetTaskPushNotificationConfigRequest,
ListTaskPushNotificationConfigRequest.model_fields[
'method'
].default: ListTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest.model_fields[
'method'
].default: DeleteTaskPushNotificationConfigRequest,
TaskResubscriptionRequest.model_fields[
'method'
].default: TaskResubscriptionRequest,
GetAuthenticatedExtendedCardRequest.model_fields[
'method'
].default: GetAuthenticatedExtendedCardRequest,
}

def __init__( # noqa: PLR0913
self,
agent_card: AgentCard,
http_handler: RequestHandler,
extended_agent_card: AgentCard | None = None,
context_builder: CallContextBuilder | None = None,
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
extended_card_modifier: Callable[
[AgentCard, ServerCallContext], AgentCard
]
| None = None,
) -> None:
"""Initializes the JSONRPCApplication.

Args:
agent_card: The AgentCard describing the agent's capabilities.
http_handler: The handler instance responsible for processing A2A
requests via http.
extended_agent_card: An optional, distinct AgentCard to be served
at the authenticated extended card endpoint.
context_builder: The CallContextBuilder used to construct the
ServerCallContext passed to the http_handler. If None, no
ServerCallContext is passed.
card_modifier: An optional callback to dynamically modify the public
agent card before it is served.
extended_card_modifier: An optional callback to dynamically modify
the extended agent card before it is served. It receives the
call context.
"""

Check notice on line 232 in src/a2a/server/apps/jsonrpc/jsonrpc_app.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/rest/rest_adapter.py (55-79)
if not _package_starlette_installed:
raise ImportError(
'Packages `starlette` and `sse-starlette` are required to use the'
Expand Down Expand Up @@ -267,17 +317,60 @@
body = await request.json()
if isinstance(body, dict):
request_id = body.get('id')
# Ensure request_id is valid for JSON-RPC response (str/int/None only)
if request_id is not None and not isinstance(
request_id, str | int
):
request_id = None
# Treat very large payloads as invalid request (-32600) before routing
with contextlib.suppress(Exception):
content_length = int(request.headers.get('content-length', '0'))
if content_length and content_length > MAX_CONTENT_LENGTH:
return self._generate_error_response(
request_id,
A2AError(
root=InvalidRequestError(
message='Payload too large'
)
),
)
logger.debug(f'Request body: {body}')
# 1) Validate base JSON-RPC structure only (-32600 on failure)
try:
base_request = JSONRPCRequest.model_validate(body)
except ValidationError as e:
logger.exception('Failed to validate base JSON-RPC request')
return self._generate_error_response(
request_id,
A2AError(
root=InvalidRequestError(data=json.loads(e.json()))
),
)

# First, validate the basic JSON-RPC structure. This is crucial
# because the A2ARequest model is a discriminated union where some
# request types have default values for the 'method' field
JSONRPCRequest.model_validate(body)
# 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
method = base_request.method

a2a_request = A2ARequest.model_validate(body)
model_class = self.METHOD_TO_MODEL.get(method)
if not model_class:
return self._generate_error_response(
request_id, A2AError(root=MethodNotFoundError())
)
try:
specific_request = model_class.model_validate(body)
except ValidationError as e:
logger.exception('Failed to validate base JSON-RPC request')
return self._generate_error_response(
request_id,
A2AError(
root=InvalidParamsError(data=json.loads(e.json()))
),
)

# 3) Build call context and wrap the request for downstream handling
call_context = self._context_builder.build(request)

request_id = a2a_request.root.id
request_id = specific_request.id
a2a_request = A2ARequest(root=specific_request)
request_obj = a2a_request.root

if isinstance(
Expand All @@ -301,12 +394,6 @@
return self._generate_error_response(
None, A2AError(root=JSONParseError(message=str(e)))
)
except ValidationError as e:
traceback.print_exc()
return self._generate_error_response(
request_id,
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
)
except HTTPException as e:
if e.status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE:
return self._generate_error_response(
Expand Down
29 changes: 29 additions & 0 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
MessageSendParams,
Task,
TaskIdParams,
TaskNotCancelableError,
TaskNotFoundError,
TaskPushNotificationConfig,
TaskQueryParams,
Expand Down Expand Up @@ -111,6 +112,26 @@
task: Task | None = await self.task_store.get(params.id)
if not task:
raise ServerError(error=TaskNotFoundError())

# Apply historyLength parameter if specified
if params.history_length is not None and task.history:
# Limit history to the most recent N messages
limited_history = (
task.history[-params.history_length :]
if params.history_length > 0
else []
)
# Create a new task instance with limited history
task = Task(
id=task.id,
context_id=task.context_id,
status=task.status,
artifacts=task.artifacts,
history=limited_history,
metadata=task.metadata,
kind=task.kind,
)

return task

async def on_cancel_task(
Expand All @@ -124,17 +145,25 @@
if not task:
raise ServerError(error=TaskNotFoundError())

# Check if task is in a non-cancelable state (completed, canceled, failed, rejected)
if task.status.state in TERMINAL_TASK_STATES:
raise ServerError(
error=TaskNotCancelableError(
message=f'Task cannot be canceled - current state: {task.status.state}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
)
result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
queue = EventQueue()

Check notice on line 166 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (478-492)

await self.agent_executor.cancel(
RequestContext(
Expand Down
10 changes: 6 additions & 4 deletions tests/server/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
Artifact,
DataPart,
InternalError,
InvalidParamsError,
InvalidRequestError,
JSONParseError,
Message,
MethodNotFoundError,
Part,
PushNotificationConfig,
Role,
Expand Down Expand Up @@ -837,7 +839,7 @@ def test_invalid_request_structure(client: TestClient):
response = client.post(
'/',
json={
# Missing required fields
'jsonrpc': 'aaaa', # Missing or wrong required fields
'id': '123',
'method': 'foo/bar',
},
Expand Down Expand Up @@ -976,7 +978,7 @@ def test_unknown_method(client: TestClient):
data = response.json()
assert 'error' in data
# This should produce an UnsupportedOperationError error code
assert data['error']['code'] == InvalidRequestError().code
assert data['error']['code'] == MethodNotFoundError().code


def test_validation_error(client: TestClient):
Expand All @@ -987,7 +989,7 @@ def test_validation_error(client: TestClient):
json={
'jsonrpc': '2.0',
'id': '123',
'method': 'messages/send',
'method': 'message/send',
'params': {
'message': {
# Missing required fields
Expand All @@ -999,7 +1001,7 @@ def test_validation_error(client: TestClient):
assert response.status_code == 200
data = response.json()
assert 'error' in data
assert data['error']['code'] == InvalidRequestError().code
assert data['error']['code'] == InvalidParamsError().code


def test_unhandled_exception(client: TestClient, handler: mock.AsyncMock):
Expand Down
Loading