Skip to content

Commit 88d45eb

Browse files
authored
fix: Correctly adapt starlette BaseUser to A2A User (#133)
1 parent 27e2874 commit 88d45eb

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,20 @@
4545

4646
logger = logging.getLogger(__name__)
4747

48-
# Register Starlette User as an implementation of a2a.auth.user.User
49-
A2AUser.register(BaseUser)
48+
49+
class StarletteUserProxy(A2AUser):
50+
"""Adapts the Starlette User class to the A2A user representation."""
51+
52+
def __init__(self, user: BaseUser):
53+
self._user = user
54+
55+
@property
56+
def is_authenticated(self):
57+
return self._user.is_authenticated
58+
59+
@property
60+
def user_name(self):
61+
return self._user.display_name
5062

5163

5264
class CallContextBuilder(ABC):
@@ -64,7 +76,7 @@ def build(self, request: Request) -> ServerCallContext:
6476
user = UnauthenticatedUser()
6577
state = {}
6678
with contextlib.suppress(Exception):
67-
user = request.user
79+
user = StarletteUserProxy(request.user)
6880
state['auth'] = request.auth
6981
return ServerCallContext(user=user, state=state)
7082

tests/server/test_integration.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55

66
import pytest
77

8+
from starlette.authentication import (
9+
AuthCredentials,
10+
AuthenticationBackend,
11+
BaseUser,
12+
SimpleUser,
13+
)
14+
from starlette.middleware import Middleware
15+
from starlette.middleware.authentication import AuthenticationMiddleware
16+
from starlette.requests import HTTPConnection
817
from starlette.responses import JSONResponse
918
from starlette.routing import Route
1019
from starlette.testclient import TestClient
@@ -18,8 +27,12 @@
1827
InternalError,
1928
InvalidRequestError,
2029
JSONParseError,
30+
Message,
2131
Part,
2232
PushNotificationConfig,
33+
Role,
34+
SendMessageResponse,
35+
SendMessageSuccessResponse,
2336
Task,
2437
TaskArtifactUpdateEvent,
2538
TaskPushNotificationConfig,
@@ -121,9 +134,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock):
121134

122135

123136
@pytest.fixture
124-
def client(app: A2AStarletteApplication):
137+
def client(app: A2AStarletteApplication, **kwargs):
125138
"""Create a test client with the app."""
126-
return TestClient(app.build())
139+
return TestClient(app.build(**kwargs))
127140

128141

129142
# === BASIC FUNCTIONALITY TESTS ===
@@ -249,7 +262,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock):
249262
mock_task = Task(
250263
id='task1',
251264
contextId='session-xyz',
252-
state='completed',
253265
status=task_status,
254266
)
255267
handler.on_message_send.return_value = mock_task
@@ -418,6 +430,67 @@ def test_get_push_notification_config(
418430
handler.on_get_task_push_notification_config.assert_awaited_once()
419431

420432

433+
def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock):
434+
class TestAuthMiddleware(AuthenticationBackend):
435+
async def authenticate(
436+
self, conn: HTTPConnection
437+
) -> tuple[AuthCredentials, BaseUser] | None:
438+
# For the purposes of this test, all requests are authenticated!
439+
return (AuthCredentials(['authenticated']), SimpleUser('test_user'))
440+
441+
client = TestClient(
442+
app.build(
443+
middleware=[
444+
Middleware(
445+
AuthenticationMiddleware, backend=TestAuthMiddleware()
446+
)
447+
]
448+
)
449+
)
450+
451+
# Set the output message to be the authenticated user name
452+
handler.on_message_send.side_effect = lambda params, context: Message(
453+
contextId='session-xyz',
454+
messageId='112',
455+
role=Role.agent,
456+
parts=[
457+
Part(TextPart(text=context.user.user_name)),
458+
],
459+
)
460+
461+
# Send request
462+
response = client.post(
463+
'/',
464+
json={
465+
'jsonrpc': '2.0',
466+
'id': '123',
467+
'method': 'message/send',
468+
'params': {
469+
'message': {
470+
'role': 'agent',
471+
'parts': [{'kind': 'text', 'text': 'Hello'}],
472+
'messageId': '111',
473+
'kind': 'message',
474+
'taskId': 'task1',
475+
'contextId': 'session-xyz',
476+
}
477+
},
478+
},
479+
)
480+
481+
# Verify response
482+
assert response.status_code == 200
483+
result = SendMessageResponse.model_validate(response.json())
484+
assert isinstance(result.root, SendMessageSuccessResponse)
485+
assert isinstance(result.root.result, Message)
486+
message = result.root.result
487+
assert isinstance(message.parts[0].root, TextPart)
488+
assert message.parts[0].root.text == 'test_user'
489+
490+
# Verify handler was called
491+
handler.on_message_send.assert_awaited_once()
492+
493+
421494
# === STREAMING TESTS ===
422495

423496

0 commit comments

Comments
 (0)