33from unittest import mock
44
55import pytest
6+ from starlette .authentication import (
7+ AuthCredentials ,
8+ AuthenticationBackend ,
9+ BaseUser ,
10+ SimpleUser ,
11+ )
12+ from starlette .middleware import Middleware
13+ from starlette .middleware .authentication import AuthenticationMiddleware
14+ from starlette .requests import HTTPConnection
615from starlette .responses import JSONResponse
716from starlette .routing import Route
817from starlette .testclient import TestClient
918
1019from a2a .server .apps .starlette_app import A2AStarletteApplication
11- from a2a .types import (AgentCapabilities , AgentCard , Artifact , DataPart ,
12- InternalError , InvalidRequestError , JSONParseError ,
13- Part , PushNotificationConfig , Task ,
14- TaskArtifactUpdateEvent , TaskPushNotificationConfig ,
15- TaskState , TaskStatus , TextPart ,
16- UnsupportedOperationError )
20+ from a2a .types import (
21+ AgentCapabilities ,
22+ AgentCard ,
23+ Artifact ,
24+ DataPart ,
25+ InternalError ,
26+ InvalidRequestError ,
27+ JSONParseError ,
28+ Message ,
29+ Part ,
30+ PushNotificationConfig ,
31+ Role ,
32+ SendMessageResponse ,
33+ SendMessageSuccessResponse ,
34+ Task ,
35+ TaskArtifactUpdateEvent ,
36+ TaskPushNotificationConfig ,
37+ TaskState ,
38+ TaskStatus ,
39+ TextPart ,
40+ UnsupportedOperationError ,
41+ )
1742from a2a .utils .errors import MethodNotImplementedError
1843
1944# === TEST SETUP ===
@@ -106,9 +131,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock):
106131
107132
108133@pytest .fixture
109- def client (app : A2AStarletteApplication ):
134+ def client (app : A2AStarletteApplication , ** kwargs ):
110135 """Create a test client with the app."""
111- return TestClient (app .build ())
136+ return TestClient (app .build (** kwargs ))
112137
113138
114139# === BASIC FUNCTIONALITY TESTS ===
@@ -135,7 +160,7 @@ def test_authenticated_extended_agent_card_endpoint_not_supported(
135160 # So, building the app and trying to hit it should result in 404 from Starlette itself
136161 client = TestClient (app_instance .build ())
137162 response = client .get ('/agent/authenticatedExtendedCard' )
138- assert response .status_code == 404 # Starlette's default for no route
163+ assert response .status_code == 404 # Starlette's default for no route
139164
140165
141166def test_authenticated_extended_agent_card_endpoint_supported_with_specific_extended_card (
@@ -144,7 +169,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte
144169 handler : mock .AsyncMock ,
145170):
146171 """Test extended card endpoint returns the specific extended card when provided."""
147- agent_card .supportsAuthenticatedExtendedCard = True # Main card must support it
172+ agent_card .supportsAuthenticatedExtendedCard = (
173+ True # Main card must support it
174+ )
148175 app_instance = A2AStarletteApplication (
149176 agent_card , handler , extended_agent_card = extended_agent_card_fixture
150177 )
@@ -157,10 +184,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte
157184 assert data ['name' ] == extended_agent_card_fixture .name
158185 assert data ['version' ] == extended_agent_card_fixture .version
159186 assert len (data ['skills' ]) == len (extended_agent_card_fixture .skills )
160- assert any (
161- skill ['id' ] == 'skill-extended' for skill in data ['skills' ]
162- ), "Extended skill not found in served card"
163-
187+ assert any (skill ['id' ] == 'skill-extended' for skill in data ['skills' ]), (
188+ 'Extended skill not found in served card'
189+ )
164190
165191
166192def test_agent_card_custom_url (
@@ -233,7 +259,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock):
233259 mock_task = Task (
234260 id = 'task1' ,
235261 contextId = 'session-xyz' ,
236- state = 'completed' ,
237262 status = task_status ,
238263 )
239264 handler .on_message_send .return_value = mock_task
@@ -402,6 +427,67 @@ def test_get_push_notification_config(
402427 handler .on_get_task_push_notification_config .assert_awaited_once ()
403428
404429
430+ def test_server_auth (app : A2AStarletteApplication , handler : mock .AsyncMock ):
431+ class TestAuthMiddleware (AuthenticationBackend ):
432+ async def authenticate (
433+ self , conn : HTTPConnection
434+ ) -> tuple [AuthCredentials , BaseUser ] | None :
435+ # For the purposes of this test, all requests are authenticated!
436+ return (AuthCredentials (['authenticated' ]), SimpleUser ('test_user' ))
437+
438+ client = TestClient (
439+ app .build (
440+ middleware = [
441+ Middleware (
442+ AuthenticationMiddleware , backend = TestAuthMiddleware ()
443+ )
444+ ]
445+ )
446+ )
447+
448+ # Set the output message to be the authenticated user name
449+ handler .on_message_send .side_effect = lambda params , context : Message (
450+ contextId = 'session-xyz' ,
451+ messageId = '112' ,
452+ role = Role .agent ,
453+ parts = [
454+ Part (TextPart (text = context .user .user_name )),
455+ ],
456+ )
457+
458+ # Send request
459+ response = client .post (
460+ '/' ,
461+ json = {
462+ 'jsonrpc' : '2.0' ,
463+ 'id' : '123' ,
464+ 'method' : 'message/send' ,
465+ 'params' : {
466+ 'message' : {
467+ 'role' : 'agent' ,
468+ 'parts' : [{'kind' : 'text' , 'text' : 'Hello' }],
469+ 'messageId' : '111' ,
470+ 'kind' : 'message' ,
471+ 'taskId' : 'task1' ,
472+ 'contextId' : 'session-xyz' ,
473+ }
474+ },
475+ },
476+ )
477+
478+ # Verify response
479+ assert response .status_code == 200
480+ result = SendMessageResponse .model_validate (response .json ())
481+ assert isinstance (result .root , SendMessageSuccessResponse )
482+ assert isinstance (result .root .result , Message )
483+ message = result .root .result
484+ assert isinstance (message .parts [0 ].root , TextPart )
485+ assert message .parts [0 ].root .text == 'test_user'
486+
487+ # Verify handler was called
488+ handler .on_message_send .assert_awaited_once ()
489+
490+
405491# === STREAMING TESTS ===
406492
407493
0 commit comments