|
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 |
|
| 8 | +from starlette.authentication import ( |
| 9 | + AuthCredentials, |
| 10 | + AuthenticationBackend, |
| 11 | + BaseUser, |
| 12 | + SimpleUser, |
| 13 | +) |
| 14 | +from starlette.middleware import Middleware |
| 15 | +from starlette.middleware.authentication import AuthenticationMiddleware |
| 16 | +from starlette.requests import HTTPConnection |
8 | 17 | from starlette.responses import JSONResponse |
9 | 18 | from starlette.routing import Route |
10 | 19 | from starlette.testclient import TestClient |
|
21 | 30 | InternalError, |
22 | 31 | InvalidRequestError, |
23 | 32 | JSONParseError, |
| 33 | + Message, |
24 | 34 | Part, |
25 | 35 | PushNotificationConfig, |
| 36 | + Role, |
| 37 | + SendMessageResponse, |
| 38 | + SendMessageSuccessResponse, |
26 | 39 | Task, |
27 | 40 | TaskArtifactUpdateEvent, |
28 | 41 | TaskPushNotificationConfig, |
@@ -122,21 +135,20 @@ def handler(): |
122 | 135 | def starlette_app(agent_card: AgentCard, handler: mock.AsyncMock): |
123 | 136 | return A2AStarletteApplication(agent_card, handler) |
124 | 137 |
|
125 | | - |
126 | 138 | @pytest.fixture |
127 | | -def starlette_client(app: A2AStarletteApplication): |
| 139 | +def starlette_client(app: A2AStarletteApplication, **kwargs): |
128 | 140 | """Create a test client with the Starlette app.""" |
129 | | - return TestClient(app.build()) |
| 141 | + return TestClient(app.build(**kwargs)) |
| 142 | + |
130 | 143 |
|
131 | 144 | @pytest.fixture |
132 | 145 | def fastapi_app(agent_card: AgentCard, handler: mock.AsyncMock): |
133 | 146 | return A2AFastAPIApplication(agent_card, handler) |
134 | 147 |
|
135 | | - |
136 | 148 | @pytest.fixture |
137 | | -def fastapi_client(app: A2AFastAPIApplication): |
| 149 | +def fastapi_client(app: A2AFastAPIApplication, **kwargs): |
138 | 150 | """Create a test client with the FastAPI app.""" |
139 | | - return TestClient(app.build()) |
| 151 | + return TestClient(app.build(**kwargs)) |
140 | 152 |
|
141 | 153 |
|
142 | 154 | # === BASIC FUNCTIONALITY TESTS === |
@@ -345,7 +357,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock): |
345 | 357 | mock_task = Task( |
346 | 358 | id='task1', |
347 | 359 | contextId='session-xyz', |
348 | | - state='completed', |
349 | 360 | status=task_status, |
350 | 361 | ) |
351 | 362 | handler.on_message_send.return_value = mock_task |
@@ -514,6 +525,67 @@ def test_get_push_notification_config( |
514 | 525 | handler.on_get_task_push_notification_config.assert_awaited_once() |
515 | 526 |
|
516 | 527 |
|
| 528 | +def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): |
| 529 | + class TestAuthMiddleware(AuthenticationBackend): |
| 530 | + async def authenticate( |
| 531 | + self, conn: HTTPConnection |
| 532 | + ) -> tuple[AuthCredentials, BaseUser] | None: |
| 533 | + # For the purposes of this test, all requests are authenticated! |
| 534 | + return (AuthCredentials(['authenticated']), SimpleUser('test_user')) |
| 535 | + |
| 536 | + client = TestClient( |
| 537 | + app.build( |
| 538 | + middleware=[ |
| 539 | + Middleware( |
| 540 | + AuthenticationMiddleware, backend=TestAuthMiddleware() |
| 541 | + ) |
| 542 | + ] |
| 543 | + ) |
| 544 | + ) |
| 545 | + |
| 546 | + # Set the output message to be the authenticated user name |
| 547 | + handler.on_message_send.side_effect = lambda params, context: Message( |
| 548 | + contextId='session-xyz', |
| 549 | + messageId='112', |
| 550 | + role=Role.agent, |
| 551 | + parts=[ |
| 552 | + Part(TextPart(text=context.user.user_name)), |
| 553 | + ], |
| 554 | + ) |
| 555 | + |
| 556 | + # Send request |
| 557 | + response = client.post( |
| 558 | + '/', |
| 559 | + json={ |
| 560 | + 'jsonrpc': '2.0', |
| 561 | + 'id': '123', |
| 562 | + 'method': 'message/send', |
| 563 | + 'params': { |
| 564 | + 'message': { |
| 565 | + 'role': 'agent', |
| 566 | + 'parts': [{'kind': 'text', 'text': 'Hello'}], |
| 567 | + 'messageId': '111', |
| 568 | + 'kind': 'message', |
| 569 | + 'taskId': 'task1', |
| 570 | + 'contextId': 'session-xyz', |
| 571 | + } |
| 572 | + }, |
| 573 | + }, |
| 574 | + ) |
| 575 | + |
| 576 | + # Verify response |
| 577 | + assert response.status_code == 200 |
| 578 | + result = SendMessageResponse.model_validate(response.json()) |
| 579 | + assert isinstance(result.root, SendMessageSuccessResponse) |
| 580 | + assert isinstance(result.root.result, Message) |
| 581 | + message = result.root.result |
| 582 | + assert isinstance(message.parts[0].root, TextPart) |
| 583 | + assert message.parts[0].root.text == 'test_user' |
| 584 | + |
| 585 | + # Verify handler was called |
| 586 | + handler.on_message_send.assert_awaited_once() |
| 587 | + |
| 588 | + |
517 | 589 | # === STREAMING TESTS === |
518 | 590 |
|
519 | 591 |
|
|
0 commit comments