Skip to content

Commit d0bc883

Browse files
Merge branch 'main' into add-fastapi-app
2 parents 54747f8 + 1107151 commit d0bc883

File tree

5 files changed

+111
-13
lines changed

5 files changed

+111
-13
lines changed

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,20 @@
4545

4646
logger = logging.getLogger(__name__)
4747

48-
# Register Starlette User as an implementation of a2a.auth.user.User
49-
A2AUser.register(BaseUser)
48+
49+
class StarletteUserProxy(A2AUser):
50+
"""Adapts the Starlette User class to the A2A user representation."""
51+
52+
def __init__(self, user: BaseUser):
53+
self._user = user
54+
55+
@property
56+
def is_authenticated(self):
57+
return self._user.is_authenticated
58+
59+
@property
60+
def user_name(self):
61+
return self._user.display_name
5062

5163

5264
class CallContextBuilder(ABC):
@@ -64,7 +76,7 @@ def build(self, request: Request) -> ServerCallContext:
6476
user = UnauthenticatedUser()
6577
state = {}
6678
with contextlib.suppress(Exception):
67-
user = request.user
79+
user = StarletteUserProxy(request.user)
6880
state['auth'] = request.auth
6981
return ServerCallContext(user=user, state=state)
7082

@@ -139,7 +151,7 @@ def _generate_error_response(
139151
log_level,
140152
f'Request Error (ID: {request_id}): '
141153
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
142-
f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}',
154+
f'{", Data=" + str(error_resp.error.data) if error_resp.error.data else ""}',
143155
)
144156
return JSONResponse(
145157
error_resp.model_dump(mode='json', exclude_none=True),

src/a2a/server/tasks/task_updater.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22

3+
from datetime import datetime, timezone
34
from typing import Any
45

56
from a2a.server.events import EventQueue
@@ -34,15 +35,23 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3435
self.context_id = context_id
3536

3637
def update_status(
37-
self, state: TaskState, message: Message | None = None, final=False
38+
self,
39+
state: TaskState,
40+
message: Message | None = None,
41+
final=False,
42+
timestamp: str | None = None,
3843
):
3944
"""Updates the status of the task and publishes a `TaskStatusUpdateEvent`.
4045
4146
Args:
4247
state: The new state of the task.
4348
message: An optional message associated with the status update.
4449
final: If True, indicates this is the final status update for the task.
50+
timestamp: Optional ISO 8601 datetime string. Defaults to current time.
4551
"""
52+
current_timestamp = (
53+
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
54+
)
4655
self.event_queue.enqueue_event(
4756
TaskStatusUpdateEvent(
4857
taskId=self.task_id,
@@ -51,6 +60,7 @@ def update_status(
5160
status=TaskStatus(
5261
state=state,
5362
message=message,
63+
timestamp=current_timestamp,
5464
),
5565
)
5666
)

src/a2a/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,10 @@ class PushNotificationConfig(BaseModel):
583583
"""
584584

585585
authentication: PushNotificationAuthenticationInfo | None = None
586+
id: str | None = None
587+
"""
588+
Push Notification ID - created by server to support multiple callbacks
589+
"""
586590
token: str | None = None
587591
"""
588592
Token unique to this task/session.

src/a2a/utils/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None:
6969

7070
# Find existing artifact by its id
7171
for i, art in enumerate(task.artifacts):
72-
if hasattr(art, 'artifactId') and art.artifactId == artifact_id:
72+
if art.artifactId == artifact_id:
7373
existing_artifact = art
7474
existing_artifact_list_index = i
7575
break

tests/server/test_integration.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55

66
import pytest
77

8+
from starlette.authentication import (
9+
AuthCredentials,
10+
AuthenticationBackend,
11+
BaseUser,
12+
SimpleUser,
13+
)
14+
from starlette.middleware import Middleware
15+
from starlette.middleware.authentication import AuthenticationMiddleware
16+
from starlette.requests import HTTPConnection
817
from starlette.responses import JSONResponse
918
from starlette.routing import Route
1019
from starlette.testclient import TestClient
@@ -21,8 +30,12 @@
2130
InternalError,
2231
InvalidRequestError,
2332
JSONParseError,
33+
Message,
2434
Part,
2535
PushNotificationConfig,
36+
Role,
37+
SendMessageResponse,
38+
SendMessageSuccessResponse,
2639
Task,
2740
TaskArtifactUpdateEvent,
2841
TaskPushNotificationConfig,
@@ -122,21 +135,20 @@ def handler():
122135
def starlette_app(agent_card: AgentCard, handler: mock.AsyncMock):
123136
return A2AStarletteApplication(agent_card, handler)
124137

125-
126138
@pytest.fixture
127-
def starlette_client(app: A2AStarletteApplication):
139+
def starlette_client(app: A2AStarletteApplication, **kwargs):
128140
"""Create a test client with the Starlette app."""
129-
return TestClient(app.build())
141+
return TestClient(app.build(**kwargs))
142+
130143

131144
@pytest.fixture
132145
def fastapi_app(agent_card: AgentCard, handler: mock.AsyncMock):
133146
return A2AFastAPIApplication(agent_card, handler)
134147

135-
136148
@pytest.fixture
137-
def fastapi_client(app: A2AFastAPIApplication):
149+
def fastapi_client(app: A2AFastAPIApplication, **kwargs):
138150
"""Create a test client with the FastAPI app."""
139-
return TestClient(app.build())
151+
return TestClient(app.build(**kwargs))
140152

141153

142154
# === BASIC FUNCTIONALITY TESTS ===
@@ -345,7 +357,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock):
345357
mock_task = Task(
346358
id='task1',
347359
contextId='session-xyz',
348-
state='completed',
349360
status=task_status,
350361
)
351362
handler.on_message_send.return_value = mock_task
@@ -514,6 +525,67 @@ def test_get_push_notification_config(
514525
handler.on_get_task_push_notification_config.assert_awaited_once()
515526

516527

528+
def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock):
529+
class TestAuthMiddleware(AuthenticationBackend):
530+
async def authenticate(
531+
self, conn: HTTPConnection
532+
) -> tuple[AuthCredentials, BaseUser] | None:
533+
# For the purposes of this test, all requests are authenticated!
534+
return (AuthCredentials(['authenticated']), SimpleUser('test_user'))
535+
536+
client = TestClient(
537+
app.build(
538+
middleware=[
539+
Middleware(
540+
AuthenticationMiddleware, backend=TestAuthMiddleware()
541+
)
542+
]
543+
)
544+
)
545+
546+
# Set the output message to be the authenticated user name
547+
handler.on_message_send.side_effect = lambda params, context: Message(
548+
contextId='session-xyz',
549+
messageId='112',
550+
role=Role.agent,
551+
parts=[
552+
Part(TextPart(text=context.user.user_name)),
553+
],
554+
)
555+
556+
# Send request
557+
response = client.post(
558+
'/',
559+
json={
560+
'jsonrpc': '2.0',
561+
'id': '123',
562+
'method': 'message/send',
563+
'params': {
564+
'message': {
565+
'role': 'agent',
566+
'parts': [{'kind': 'text', 'text': 'Hello'}],
567+
'messageId': '111',
568+
'kind': 'message',
569+
'taskId': 'task1',
570+
'contextId': 'session-xyz',
571+
}
572+
},
573+
},
574+
)
575+
576+
# Verify response
577+
assert response.status_code == 200
578+
result = SendMessageResponse.model_validate(response.json())
579+
assert isinstance(result.root, SendMessageSuccessResponse)
580+
assert isinstance(result.root.result, Message)
581+
message = result.root.result
582+
assert isinstance(message.parts[0].root, TextPart)
583+
assert message.parts[0].root.text == 'test_user'
584+
585+
# Verify handler was called
586+
handler.on_message_send.assert_awaited_once()
587+
588+
517589
# === STREAMING TESTS ===
518590

519591

0 commit comments

Comments
 (0)