Skip to content

Commit 1b2dc1e

Browse files
committed
fix: type issues
1 parent aa159f3 commit 1b2dc1e

14 files changed

+224
-129
lines changed

tests/client/test_auth_middleware.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from a2a.client.auth.credentials import InMemoryContextCredentialStore
12
import json
23

34
from collections.abc import Callable
45
from dataclasses import dataclass
5-
from typing import Any
6+
from typing import Any, Generator
67

78
import httpx
89
import pytest
@@ -100,13 +101,15 @@ async def send_message(
100101

101102

102103
@pytest.fixture
103-
def store():
104+
def store() -> Generator[InMemoryContextCredentialStore, Any, None]:
104105
store = InMemoryContextCredentialStore()
105106
yield store
106107

107108

108109
@pytest.mark.asyncio
109-
async def test_auth_interceptor_skips_when_no_agent_card(store):
110+
async def test_auth_interceptor_skips_when_no_agent_card(
111+
store: InMemoryContextCredentialStore,
112+
):
110113
"""
111114
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
112115
"""
@@ -126,7 +129,9 @@ async def test_auth_interceptor_skips_when_no_agent_card(store):
126129

127130

128131
@pytest.mark.asyncio
129-
async def test_in_memory_context_credential_store(store):
132+
async def test_in_memory_context_credential_store(
133+
store: InMemoryContextCredentialStore,
134+
):
130135
"""
131136
Verifies that InMemoryContextCredentialStore correctly stores and retrieves
132137
credentials based on the session ID in the client context.
@@ -284,7 +289,9 @@ class AuthTestCase:
284289
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
285290
)
286291
@respx.mock
287-
async def test_auth_interceptor_variants(test_case, store):
292+
async def test_auth_interceptor_variants(
293+
test_case: AuthTestCase, store: InMemoryContextCredentialStore
294+
):
288295
"""
289296
Parametrized test verifying that AuthInterceptor correctly attaches credentials
290297
based on the defined security scheme in the AgentCard.
@@ -329,7 +336,7 @@ async def test_auth_interceptor_variants(test_case, store):
329336

330337
@pytest.mark.asyncio
331338
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
332-
store,
339+
store: InMemoryContextCredentialStore,
333340
):
334341
"""
335342
Tests that AuthInterceptor skips a scheme if it's listed in security requirements

tests/client/test_base_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020

2121
@pytest.fixture
22-
def mock_transport():
22+
def mock_transport() -> AsyncMock:
2323
return AsyncMock(spec=ClientTransport)
2424

2525

2626
@pytest.fixture
27-
def sample_agent_card():
27+
def sample_agent_card() -> AgentCard:
2828
return AgentCard(
2929
name='Test Agent',
3030
description='An agent for testing',
@@ -38,7 +38,7 @@ def sample_agent_card():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
role=Role.user,
4444
message_id='msg-1',
@@ -47,7 +47,7 @@ def sample_message():
4747

4848

4949
@pytest.fixture
50-
def base_client(sample_agent_card, mock_transport):
50+
def base_client(sample_agent_card: AgentCard, mock_transport: AsyncMock):
5151
config = ClientConfig(streaming=True)
5252
return BaseClient(
5353
card=sample_agent_card,

tests/client/test_client_task_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323

2424
@pytest.fixture
25-
def task_manager():
25+
def task_manager() -> ClientTaskManager:
2626
return ClientTaskManager()
2727

2828

2929
@pytest.fixture
30-
def sample_task():
30+
def sample_task() -> Task:
3131
return Task(
3232
id='task123',
3333
context_id='context456',
@@ -38,7 +38,7 @@ def sample_task():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
message_id='msg1',
4444
role=Role.user,

tests/client/test_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_raising_base_error(self):
178178
(500, 'Server Error', 'HTTP Error 500: Server Error'),
179179
],
180180
)
181-
def test_http_error_parametrized(status_code, message, expected):
181+
def test_http_error_parametrized(status_code: int, message: str, expected: str):
182182
"""Parametrized test for HTTP errors with different status codes."""
183183
error = A2AClientHTTPError(status_code, message)
184184
assert error.status_code == status_code
@@ -194,7 +194,7 @@ def test_http_error_parametrized(status_code, message, expected):
194194
('Parsing failed', 'JSON Error: Parsing failed'),
195195
],
196196
)
197-
def test_json_error_parametrized(message, expected):
197+
def test_json_error_parametrized(message: str, expected: str):
198198
"""Parametrized test for JSON errors with different messages."""
199199
error = A2AClientJSONError(message)
200200
assert error.message == message

tests/client/test_grpc_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def sample_task_status_update_event() -> TaskStatusUpdateEvent:
128128

129129
@pytest.fixture
130130
def sample_task_artifact_update_event(
131-
sample_artifact,
131+
sample_artifact: Artifact,
132132
) -> TaskArtifactUpdateEvent:
133133
"""Provides a sample TaskArtifactUpdateEvent."""
134134
return TaskArtifactUpdateEvent(

tests/server/agent_execution/test_context.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,22 @@ class TestRequestContext:
1919
"""Tests for the RequestContext class."""
2020

2121
@pytest.fixture
22-
def mock_message(self):
22+
def mock_message(self) -> Mock:
2323
"""Fixture for a mock Message."""
24+
2425
return Mock(spec=Message, task_id=None, context_id=None)
2526

2627
@pytest.fixture
27-
def mock_params(self, mock_message):
28+
def mock_params(self, mock_message: Mock) -> Mock:
2829
"""Fixture for a mock MessageSendParams."""
2930
return Mock(spec=MessageSendParams, message=mock_message)
3031

3132
@pytest.fixture
32-
def mock_task(self):
33+
def mock_task(self) -> Mock:
3334
"""Fixture for a mock Task."""
3435
return Mock(spec=Task, id='task-123', context_id='context-456')
3536

36-
def test_init_without_params(self):
37+
def test_init_without_params(self) -> None:
3738
"""Test initialization without parameters."""
3839
context = RequestContext()
3940
assert context.message is None
@@ -42,7 +43,7 @@ def test_init_without_params(self):
4243
assert context.current_task is None
4344
assert context.related_tasks == []
4445

45-
def test_init_with_params_no_ids(self, mock_params):
46+
def test_init_with_params_no_ids(self, mock_params: Mock) -> None:
4647
"""Test initialization with params but no task or context IDs."""
4748
with patch(
4849
'uuid.uuid4',
@@ -65,23 +66,23 @@ def test_init_with_params_no_ids(self, mock_params):
6566
== '00000000-0000-0000-0000-000000000002'
6667
)
6768

68-
def test_init_with_task_id(self, mock_params):
69+
def test_init_with_task_id(self, mock_params: Mock) -> None:
6970
"""Test initialization with task ID provided."""
7071
task_id = 'task-123'
7172
context = RequestContext(request=mock_params, task_id=task_id)
7273

7374
assert context.task_id == task_id
7475
assert mock_params.message.task_id == task_id
7576

76-
def test_init_with_context_id(self, mock_params):
77+
def test_init_with_context_id(self, mock_params: Mock) -> None:
7778
"""Test initialization with context ID provided."""
7879
context_id = 'context-456'
7980
context = RequestContext(request=mock_params, context_id=context_id)
8081

8182
assert context.context_id == context_id
8283
assert mock_params.message.context_id == context_id
8384

84-
def test_init_with_both_ids(self, mock_params):
85+
def test_init_with_both_ids(self, mock_params: Mock) -> None:
8586
"""Test initialization with both task and context IDs provided."""
8687
task_id = 'task-123'
8788
context_id = 'context-456'
@@ -94,7 +95,7 @@ def test_init_with_both_ids(self, mock_params):
9495
assert context.context_id == context_id
9596
assert mock_params.message.context_id == context_id
9697

97-
def test_init_with_task(self, mock_params, mock_task):
98+
def test_init_with_task(self, mock_params: Mock, mock_task: Mock) -> None:
9899
"""Test initialization with a task object."""
99100
context = RequestContext(request=mock_params, task=mock_task)
100101

@@ -105,7 +106,7 @@ def test_get_user_input_no_params(self):
105106
context = RequestContext()
106107
assert context.get_user_input() == ''
107108

108-
def test_attach_related_task(self, mock_task):
109+
def test_attach_related_task(self, mock_task: Mock):
109110
"""Test attach_related_task adds a task to related_tasks."""
110111
context = RequestContext()
111112
assert len(context.related_tasks) == 0
@@ -120,7 +121,7 @@ def test_attach_related_task(self, mock_task):
120121
assert len(context.related_tasks) == 2
121122
assert context.related_tasks[1] == another_task
122123

123-
def test_current_task_property(self, mock_task):
124+
def test_current_task_property(self, mock_task: Mock) -> None:
124125
"""Test current_task getter and setter."""
125126
context = RequestContext()
126127
assert context.current_task is None
@@ -133,13 +134,15 @@ def test_current_task_property(self, mock_task):
133134
context.current_task = new_task
134135
assert context.current_task == new_task
135136

136-
def test_check_or_generate_task_id_no_params(self):
137+
def test_check_or_generate_task_id_no_params(self) -> None:
137138
"""Test _check_or_generate_task_id with no params does nothing."""
138139
context = RequestContext()
139140
context._check_or_generate_task_id()
140141
assert context.task_id is None
141142

142-
def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
143+
def test_check_or_generate_task_id_with_existing_task_id(
144+
self, mock_params: Mock
145+
) -> None:
143146
"""Test _check_or_generate_task_id with existing task ID."""
144147
existing_id = 'existing-task-id'
145148
mock_params.message.task_id = existing_id
@@ -151,8 +154,8 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
151154
assert mock_params.message.task_id == existing_id
152155

153156
def test_check_or_generate_task_id_with_custom_id_generator(
154-
self, mock_params
155-
):
157+
self, mock_params: Mock
158+
) -> None:
156159
"""Test _check_or_generate_task_id uses custom ID generator when provided."""
157160
id_generator = Mock(spec=IDGenerator)
158161
id_generator.generate.return_value = 'custom-task-id'
@@ -164,14 +167,14 @@ def test_check_or_generate_task_id_with_custom_id_generator(
164167

165168
assert context.task_id == 'custom-task-id'
166169

167-
def test_check_or_generate_context_id_no_params(self):
170+
def test_check_or_generate_context_id_no_params(self) -> None:
168171
"""Test _check_or_generate_context_id with no params does nothing."""
169172
context = RequestContext()
170173
context._check_or_generate_context_id()
171174
assert context.context_id is None
172175

173176
def test_check_or_generate_context_id_with_existing_context_id(
174-
self, mock_params
177+
self, mock_params: Mock
175178
):
176179
"""Test _check_or_generate_context_id with existing context ID."""
177180
existing_id = 'existing-context-id'
@@ -184,8 +187,8 @@ def test_check_or_generate_context_id_with_existing_context_id(
184187
assert mock_params.message.context_id == existing_id
185188

186189
def test_check_or_generate_context_id_with_custom_id_generator(
187-
self, mock_params
188-
):
190+
self, mock_params: Mock
191+
) -> None:
189192
"""Test _check_or_generate_context_id uses custom ID generator when provided."""
190193
id_generator = Mock(spec=IDGenerator)
191194
id_generator.generate.return_value = 'custom-context-id'
@@ -198,8 +201,8 @@ def test_check_or_generate_context_id_with_custom_id_generator(
198201
assert context.context_id == 'custom-context-id'
199202

200203
def test_init_raises_error_on_task_id_mismatch(
201-
self, mock_params, mock_task
202-
):
204+
self, mock_params: Mock, mock_task: Mock
205+
) -> None:
203206
"""Test that an error is raised if provided task_id mismatches task.id."""
204207
with pytest.raises(ServerError) as exc_info:
205208
RequestContext(
@@ -208,8 +211,8 @@ def test_init_raises_error_on_task_id_mismatch(
208211
assert 'bad task id' in str(exc_info.value.error.message)
209212

210213
def test_init_raises_error_on_context_id_mismatch(
211-
self, mock_params, mock_task
212-
):
214+
self, mock_params: Mock, mock_task: Mock
215+
) -> None:
213216
"""Test that an error is raised if provided context_id mismatches task.context_id."""
214217
# Set a valid task_id to avoid that error
215218
mock_params.message.task_id = mock_task.id
@@ -224,36 +227,38 @@ def test_init_raises_error_on_context_id_mismatch(
224227

225228
assert 'bad context id' in str(exc_info.value.error.message)
226229

227-
def test_with_related_tasks_provided(self, mock_task):
230+
def test_with_related_tasks_provided(self, mock_task: Mock) -> None:
228231
"""Test initialization with related tasks provided."""
229232
related_tasks = [mock_task, Mock(spec=Task)]
230233
context = RequestContext(related_tasks=related_tasks)
231234

232235
assert context.related_tasks == related_tasks
233236
assert len(context.related_tasks) == 2
234237

235-
def test_message_property_without_params(self):
238+
def test_message_property_without_params(self) -> None:
236239
"""Test message property returns None when no params are provided."""
237240
context = RequestContext()
238241
assert context.message is None
239242

240-
def test_message_property_with_params(self, mock_params):
243+
def test_message_property_with_params(self, mock_params: Mock) -> None:
241244
"""Test message property returns the message from params."""
242245
context = RequestContext(request=mock_params)
243246
assert context.message == mock_params.message
244247

245-
def test_metadata_property_without_content(self):
248+
def test_metadata_property_without_content(self) -> None:
246249
"""Test metadata property returns empty dict when no content are provided."""
247250
context = RequestContext()
248251
assert context.metadata == {}
249252

250-
def test_metadata_property_with_content(self, mock_params):
253+
def test_metadata_property_with_content(self, mock_params: Mock) -> None:
251254
"""Test metadata property returns the metadata from params."""
252255
mock_params.metadata = {'key': 'value'}
253256
context = RequestContext(request=mock_params)
254257
assert context.metadata == {'key': 'value'}
255258

256-
def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
259+
def test_init_with_existing_ids_in_message(
260+
self, mock_message, mock_params
261+
) -> None:
257262
"""Test initialization with existing IDs in the message."""
258263
mock_message.task_id = 'existing-task-id'
259264
mock_message.context_id = 'existing-context-id'
@@ -265,8 +270,8 @@ def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
265270
# No new UUIDs should be generated
266271

267272
def test_init_with_task_id_and_existing_task_id_match(
268-
self, mock_params, mock_task
269-
):
273+
self, mock_params: Mock, mock_task
274+
) -> None:
270275
"""Test initialization succeeds when task_id matches task.id."""
271276
mock_params.message.task_id = mock_task.id
272277

@@ -278,8 +283,8 @@ def test_init_with_task_id_and_existing_task_id_match(
278283
assert context.current_task == mock_task
279284

280285
def test_init_with_context_id_and_existing_context_id_match(
281-
self, mock_params, mock_task
282-
):
286+
self, mock_params: Mock, mock_task: Mock
287+
) -> None:
283288
"""Test initialization succeeds when context_id matches task.context_id."""
284289
mock_params.message.task_id = mock_task.id # Set matching task ID
285290
mock_params.message.context_id = mock_task.context_id
@@ -294,7 +299,7 @@ def test_init_with_context_id_and_existing_context_id_match(
294299
assert context.context_id == mock_task.context_id
295300
assert context.current_task == mock_task
296301

297-
def test_extension_handling(self):
302+
def test_extension_handling(self) -> None:
298303
"""Test extension handling in RequestContext."""
299304
call_context = ServerCallContext(requested_extensions={'foo', 'bar'})
300305
context = RequestContext(call_context=call_context)

0 commit comments

Comments
 (0)