|
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 |
|
| 8 | +from starlette.authentication import ( |
| 9 | + AuthCredentials, |
| 10 | + AuthenticationBackend, |
| 11 | + BaseUser, |
| 12 | + SimpleUser, |
| 13 | +) |
| 14 | +from starlette.middleware import Middleware |
| 15 | +from starlette.middleware.authentication import AuthenticationMiddleware |
| 16 | +from starlette.requests import HTTPConnection |
8 | 17 | from starlette.responses import JSONResponse |
9 | 18 | from starlette.routing import Route |
10 | 19 | from starlette.testclient import TestClient |
|
18 | 27 | InternalError, |
19 | 28 | InvalidRequestError, |
20 | 29 | JSONParseError, |
| 30 | + Message, |
21 | 31 | Part, |
22 | 32 | PushNotificationConfig, |
| 33 | + Role, |
| 34 | + SendMessageResponse, |
| 35 | + SendMessageSuccessResponse, |
23 | 36 | Task, |
24 | 37 | TaskArtifactUpdateEvent, |
25 | 38 | TaskPushNotificationConfig, |
@@ -121,9 +134,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock): |
121 | 134 |
|
122 | 135 |
|
123 | 136 | @pytest.fixture |
124 | | -def client(app: A2AStarletteApplication): |
| 137 | +def client(app: A2AStarletteApplication, **kwargs): |
125 | 138 | """Create a test client with the app.""" |
126 | | - return TestClient(app.build()) |
| 139 | + return TestClient(app.build(**kwargs)) |
127 | 140 |
|
128 | 141 |
|
129 | 142 | # === BASIC FUNCTIONALITY TESTS === |
@@ -249,7 +262,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock): |
249 | 262 | mock_task = Task( |
250 | 263 | id='task1', |
251 | 264 | contextId='session-xyz', |
252 | | - state='completed', |
253 | 265 | status=task_status, |
254 | 266 | ) |
255 | 267 | handler.on_message_send.return_value = mock_task |
@@ -418,6 +430,67 @@ def test_get_push_notification_config( |
418 | 430 | handler.on_get_task_push_notification_config.assert_awaited_once() |
419 | 431 |
|
420 | 432 |
|
| 433 | +def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): |
| 434 | + class TestAuthMiddleware(AuthenticationBackend): |
| 435 | + async def authenticate( |
| 436 | + self, conn: HTTPConnection |
| 437 | + ) -> tuple[AuthCredentials, BaseUser] | None: |
| 438 | + # For the purposes of this test, all requests are authenticated! |
| 439 | + return (AuthCredentials(['authenticated']), SimpleUser('test_user')) |
| 440 | + |
| 441 | + client = TestClient( |
| 442 | + app.build( |
| 443 | + middleware=[ |
| 444 | + Middleware( |
| 445 | + AuthenticationMiddleware, backend=TestAuthMiddleware() |
| 446 | + ) |
| 447 | + ] |
| 448 | + ) |
| 449 | + ) |
| 450 | + |
| 451 | + # Set the output message to be the authenticated user name |
| 452 | + handler.on_message_send.side_effect = lambda params, context: Message( |
| 453 | + contextId='session-xyz', |
| 454 | + messageId='112', |
| 455 | + role=Role.agent, |
| 456 | + parts=[ |
| 457 | + Part(TextPart(text=context.user.user_name)), |
| 458 | + ], |
| 459 | + ) |
| 460 | + |
| 461 | + # Send request |
| 462 | + response = client.post( |
| 463 | + '/', |
| 464 | + json={ |
| 465 | + 'jsonrpc': '2.0', |
| 466 | + 'id': '123', |
| 467 | + 'method': 'message/send', |
| 468 | + 'params': { |
| 469 | + 'message': { |
| 470 | + 'role': 'agent', |
| 471 | + 'parts': [{'kind': 'text', 'text': 'Hello'}], |
| 472 | + 'messageId': '111', |
| 473 | + 'kind': 'message', |
| 474 | + 'taskId': 'task1', |
| 475 | + 'contextId': 'session-xyz', |
| 476 | + } |
| 477 | + }, |
| 478 | + }, |
| 479 | + ) |
| 480 | + |
| 481 | + # Verify response |
| 482 | + assert response.status_code == 200 |
| 483 | + result = SendMessageResponse.model_validate(response.json()) |
| 484 | + assert isinstance(result.root, SendMessageSuccessResponse) |
| 485 | + assert isinstance(result.root.result, Message) |
| 486 | + message = result.root.result |
| 487 | + assert isinstance(message.parts[0].root, TextPart) |
| 488 | + assert message.parts[0].root.text == 'test_user' |
| 489 | + |
| 490 | + # Verify handler was called |
| 491 | + handler.on_message_send.assert_awaited_once() |
| 492 | + |
| 493 | + |
421 | 494 | # === STREAMING TESTS === |
422 | 495 |
|
423 | 496 |
|
|
0 commit comments