Skip to content

Commit 45202a5

Browse files
committed
Introduce simple server request auth structures
1 parent 17c96dc commit 45202a5

File tree

3 files changed

+83
-22
lines changed

3 files changed

+83
-22
lines changed

src/a2a/auth/user.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,29 @@
1-
class User:
1+
"""Authenticated user information."""
2+
3+
from abc import ABC, abstractmethod
4+
5+
6+
class User(ABC):
27
"""A representation of an authenticated user."""
3-
def __init__(self):
4-
pass
8+
9+
@abstractmethod
10+
@property
11+
def is_authenticated(self) -> bool:
12+
"""Returns whether the current user is authenticated."""
13+
14+
@abstractmethod
15+
@property
16+
def user_name(self) -> str:
17+
"""Returns the user name of the current user."""
18+
19+
20+
class UnauthenticatedUser(User):
21+
"""A representation that no user has been authenticated in the request."""
22+
23+
@property
24+
def is_authenticated(self):
25+
return False
26+
27+
@property
28+
def user_name(self) -> str:
29+
return ''

src/a2a/server/apps/starlette_app.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,52 @@
1+
import contextlib
12
import json
23
import logging
34
import traceback
5+
46
from abc import ABC, abstractmethod
57
from collections.abc import AsyncGenerator
68
from typing import Any
79

810
from pydantic import ValidationError
911
from sse_starlette.sse import EventSourceResponse
1012
from starlette.applications import Starlette
13+
from starlette.authentication import BaseUser
1114
from starlette.requests import Request
1215
from starlette.responses import JSONResponse, Response
1316
from starlette.routing import Route
1417

18+
from a2a.auth.user import User as A2AUser
1519
from a2a.server.context import ServerCallContext
1620
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
1721
from a2a.server.request_handlers.request_handler import RequestHandler
18-
from a2a.types import (A2AError, A2ARequest, AgentCard, CancelTaskRequest,
19-
GetTaskPushNotificationConfigRequest, GetTaskRequest,
20-
InternalError, InvalidRequestError, JSONParseError,
21-
JSONRPCError, JSONRPCErrorResponse, JSONRPCResponse,
22-
SendMessageRequest, SendStreamingMessageRequest,
23-
SendStreamingMessageResponse,
24-
SetTaskPushNotificationConfigRequest,
25-
TaskResubscriptionRequest, UnsupportedOperationError)
22+
from a2a.types import (
23+
A2AError,
24+
A2ARequest,
25+
AgentCard,
26+
CancelTaskRequest,
27+
GetTaskPushNotificationConfigRequest,
28+
GetTaskRequest,
29+
InternalError,
30+
InvalidRequestError,
31+
JSONParseError,
32+
JSONRPCError,
33+
JSONRPCErrorResponse,
34+
JSONRPCResponse,
35+
SendMessageRequest,
36+
SendStreamingMessageRequest,
37+
SendStreamingMessageResponse,
38+
SetTaskPushNotificationConfigRequest,
39+
TaskResubscriptionRequest,
40+
UnsupportedOperationError,
41+
)
2642
from a2a.utils.errors import MethodNotImplementedError
2743

44+
2845
logger = logging.getLogger(__name__)
2946

47+
# Register Starlette User as an implementation of a2a.auth.user.User
48+
A2AUser.register(BaseUser)
49+
3050

3151
class CallContextBuilder(ABC):
3252
"""A class for building ServerCallContexts using the Starlette Request."""
@@ -36,6 +56,18 @@ def build(self, request: Request) -> ServerCallContext:
3656
"""Builds a ServerCallContext from a Starlette Request."""
3757

3858

59+
class DefaultCallContextBuilder(CallContextBuilder):
60+
"""A default implementation of CallContextBuilder."""
61+
62+
def build(self, request: Request) -> ServerCallContext:
63+
user = None
64+
state = {}
65+
with contextlib.suppress(Exception):
66+
user = request.user
67+
state['auth'] = request.auth
68+
return ServerCallContext(user=user, state=state)
69+
70+
3971
class A2AStarletteApplication:
4072
"""A Starlette application implementing the A2A protocol server endpoints.
4173
@@ -75,7 +107,7 @@ def __init__(
75107
logger.error(
76108
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
77109
)
78-
self._context_builder = context_builder
110+
self._context_builder = context_builder or DefaultCallContextBuilder()
79111

80112
def _generate_error_response(
81113
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -137,11 +169,7 @@ async def _handle_requests(self, request: Request) -> Response:
137169
try:
138170
body = await request.json()
139171
a2a_request = A2ARequest.model_validate(body)
140-
call_context = (
141-
self._context_builder.build(request)
142-
if self._context_builder
143-
else None
144-
)
172+
call_context = self._context_builder.build(request)
145173

146174
request_id = a2a_request.root.id
147175
request_obj = a2a_request.root
@@ -344,7 +372,9 @@ async def _handle_get_authenticated_extended_agent_card(
344372
# extended_agent_card was provided during server initialization,
345373
# return a 404
346374
return JSONResponse(
347-
{'error': 'Authenticated extended agent card is supported but not configured on the server.'},
375+
{
376+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
377+
},
348378
status_code=404,
349379
)
350380

src/a2a/server/context.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import collections.abc
44
import typing
55

6+
from a2a.auth.user import UnauthenticatedUser, User
7+
68

79
State = collections.abc.MutableMapping[str, typing.Any]
810

@@ -13,10 +15,14 @@ class ServerCallContext:
1315
This class allows storing arbitrary user data in the state attribute.
1416
"""
1517

16-
def __init__(self, state: State | None = None):
17-
if state is None:
18-
state = {}
19-
self._state = state
18+
def __init__(self, state: State | None = None, user: User | None = None):
19+
self._state = state or {}
20+
self._user = user or UnauthenticatedUser()
21+
22+
@property
23+
def user(self) -> User:
24+
"""Get the user associated with this context, or UnauthenticatedUser."""
25+
return self._user
2026

2127
@property
2228
def state(self) -> State:

0 commit comments

Comments
 (0)