Skip to content

Commit 2cc2a0d

Browse files
authored
feat: Add a User representation to ServerCallContext (#116)
* Start sketching out auth handling ideas * Introduce simple server request auth structures * Make ServerCallContext a pydantic model * Update check-spelling metadata * Fix broken tests * Delete empty file
1 parent 04c7c45 commit 2cc2a0d

File tree

6 files changed

+92
-30
lines changed

6 files changed

+92
-30
lines changed

.github/actions/spelling/excludes.txt

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# See https://github.com/check-spelling/check-spelling/wiki/Configuration-Examples:-excludes
2+
(?:^|/)(?i).gitignore\E$
3+
(?:^|/)(?i)CODE_OF_CONDUCT.md\E$
24
(?:^|/)(?i)COPYRIGHT
35
(?:^|/)(?i)LICEN[CS]E
4-
(?:^|/)(?i)CODE_OF_CONDUCT.md\E$
5-
(?:^|/)(?i).gitignore\E$
66
(?:^|/)3rdparty/
77
(?:^|/)go\.sum$
88
(?:^|/)package(?:-lock|)\.json$
@@ -33,6 +33,7 @@
3333
\.gif$
3434
\.git-blame-ignore-revs$
3535
\.gitattributes$
36+
\.gitignore\E$
3637
\.gitkeep$
3738
\.graffle$
3839
\.gz$
@@ -62,6 +63,7 @@
6263
\.pyc$
6364
\.pylintrc$
6465
\.qm$
66+
\.ruff.toml$
6567
\.s$
6668
\.sig$
6769
\.so$
@@ -71,6 +73,7 @@
7173
\.tgz$
7274
\.tiff?$
7375
\.ttf$
76+
\.vscode/
7477
\.wav$
7578
\.webm$
7679
\.webp$
@@ -82,8 +85,7 @@
8285
\.zip$
8386
^\.github/actions/spelling/
8487
^\.github/workflows/
85-
\.gitignore\E$
86-
\.vscode/
87-
noxfile.py
88-
\.ruff.toml$
88+
^\Qsrc/a2a/auth/__init__.py\E$
89+
^\Qsrc/a2a/server/request_handlers/context.py\E$
8990
CHANGELOG.md
91+
noxfile.py

.github/actions/spelling/expect.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
AUser
12
excinfo
23
GVsb
34
notif

src/a2a/auth/__init__.py

Whitespace-only changes.

src/a2a/auth/user.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Authenticated user information."""
2+
3+
from abc import ABC, abstractmethod
4+
5+
6+
class User(ABC):
7+
"""A representation of an authenticated user."""
8+
9+
@property
10+
@abstractmethod
11+
def is_authenticated(self) -> bool:
12+
"""Returns whether the current user is authenticated."""
13+
14+
@property
15+
@abstractmethod
16+
def user_name(self) -> str:
17+
"""Returns the user name of the current user."""
18+
19+
20+
class UnauthenticatedUser(User):
21+
"""A representation that no user has been authenticated in the request."""
22+
23+
@property
24+
def is_authenticated(self):
25+
return False
26+
27+
@property
28+
def user_name(self) -> str:
29+
return ''

src/a2a/server/apps/starlette_app.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,53 @@
1+
import contextlib
12
import json
23
import logging
34
import traceback
5+
46
from abc import ABC, abstractmethod
57
from collections.abc import AsyncGenerator
68
from typing import Any
79

810
from pydantic import ValidationError
911
from sse_starlette.sse import EventSourceResponse
1012
from starlette.applications import Starlette
13+
from starlette.authentication import BaseUser
1114
from starlette.requests import Request
1215
from starlette.responses import JSONResponse, Response
1316
from starlette.routing import Route
1417

18+
from a2a.auth.user import UnauthenticatedUser
19+
from a2a.auth.user import User as A2AUser
1520
from a2a.server.context import ServerCallContext
1621
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
1722
from a2a.server.request_handlers.request_handler import RequestHandler
18-
from a2a.types import (A2AError, A2ARequest, AgentCard, CancelTaskRequest,
19-
GetTaskPushNotificationConfigRequest, GetTaskRequest,
20-
InternalError, InvalidRequestError, JSONParseError,
21-
JSONRPCError, JSONRPCErrorResponse, JSONRPCResponse,
22-
SendMessageRequest, SendStreamingMessageRequest,
23-
SendStreamingMessageResponse,
24-
SetTaskPushNotificationConfigRequest,
25-
TaskResubscriptionRequest, UnsupportedOperationError)
23+
from a2a.types import (
24+
A2AError,
25+
A2ARequest,
26+
AgentCard,
27+
CancelTaskRequest,
28+
GetTaskPushNotificationConfigRequest,
29+
GetTaskRequest,
30+
InternalError,
31+
InvalidRequestError,
32+
JSONParseError,
33+
JSONRPCError,
34+
JSONRPCErrorResponse,
35+
JSONRPCResponse,
36+
SendMessageRequest,
37+
SendStreamingMessageRequest,
38+
SendStreamingMessageResponse,
39+
SetTaskPushNotificationConfigRequest,
40+
TaskResubscriptionRequest,
41+
UnsupportedOperationError,
42+
)
2643
from a2a.utils.errors import MethodNotImplementedError
2744

45+
2846
logger = logging.getLogger(__name__)
2947

48+
# Register Starlette User as an implementation of a2a.auth.user.User
49+
A2AUser.register(BaseUser)
50+
3051

3152
class CallContextBuilder(ABC):
3253
"""A class for building ServerCallContexts using the Starlette Request."""
@@ -36,6 +57,18 @@ def build(self, request: Request) -> ServerCallContext:
3657
"""Builds a ServerCallContext from a Starlette Request."""
3758

3859

60+
class DefaultCallContextBuilder(CallContextBuilder):
61+
"""A default implementation of CallContextBuilder."""
62+
63+
def build(self, request: Request) -> ServerCallContext:
64+
user = UnauthenticatedUser()
65+
state = {}
66+
with contextlib.suppress(Exception):
67+
user = request.user
68+
state['auth'] = request.auth
69+
return ServerCallContext(user=user, state=state)
70+
71+
3972
class A2AStarletteApplication:
4073
"""A Starlette application implementing the A2A protocol server endpoints.
4174
@@ -75,7 +108,7 @@ def __init__(
75108
logger.error(
76109
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
77110
)
78-
self._context_builder = context_builder
111+
self._context_builder = context_builder or DefaultCallContextBuilder()
79112

80113
def _generate_error_response(
81114
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -137,11 +170,7 @@ async def _handle_requests(self, request: Request) -> Response:
137170
try:
138171
body = await request.json()
139172
a2a_request = A2ARequest.model_validate(body)
140-
call_context = (
141-
self._context_builder.build(request)
142-
if self._context_builder
143-
else None
144-
)
173+
call_context = self._context_builder.build(request)
145174

146175
request_id = a2a_request.root.id
147176
request_obj = a2a_request.root
@@ -344,7 +373,9 @@ async def _handle_get_authenticated_extended_agent_card(
344373
# extended_agent_card was provided during server initialization,
345374
# return a 404
346375
return JSONResponse(
347-
{'error': 'Authenticated extended agent card is supported but not configured on the server.'},
376+
{
377+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
378+
},
348379
status_code=404,
349380
)
350381

src/a2a/server/context.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,21 @@
33
import collections.abc
44
import typing
55

6+
from pydantic import BaseModel, ConfigDict, Field
7+
8+
from a2a.auth.user import UnauthenticatedUser, User
9+
610

711
State = collections.abc.MutableMapping[str, typing.Any]
812

913

10-
class ServerCallContext:
14+
class ServerCallContext(BaseModel):
1115
"""A context passed when calling a server method.
1216
1317
This class allows storing arbitrary user data in the state attribute.
1418
"""
1519

16-
def __init__(self, state: State | None = None):
17-
if state is None:
18-
state = {}
19-
self._state = state
20+
model_config = ConfigDict(arbitrary_types_allowed=True)
2021

21-
@property
22-
def state(self) -> State:
23-
"""Get the user-provided state."""
24-
return self._state
22+
state: State = Field(default={})
23+
user: User = Field(default=UnauthenticatedUser())

0 commit comments

Comments
 (0)