Skip to content

Commit 643a2e0

Browse files
author
lkawka
committed
Add client support for task/list method
1 parent e5853a6 commit 643a2e0

File tree

8 files changed

+265
-6
lines changed

8 files changed

+265
-6
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from collections.abc import AsyncGenerator
44

5+
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
6+
57

68
try:
79
import grpc
@@ -154,8 +156,11 @@ async def list_tasks(
154156
context: ClientCallContext | None = None,
155157
) -> ListTasksResult:
156158
"""Retrieves tasks for an agent."""
157-
# TODO: #515 - Implement method
158-
raise NotImplementedError('tasks/list not implemented')
159+
response = await self.stub.ListTasks(
160+
proto_utils.ToProto.list_tasks_request(request)
161+
)
162+
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
163+
return proto_utils.FromProto.list_tasks_result(response, page_size)
159164

160165
async def cancel_task(
161166
self,

src/a2a/client/transports/jsonrpc.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
GetTaskResponse,
3232
JSONRPCErrorResponse,
3333
ListTasksParams,
34+
ListTasksRequest,
35+
ListTasksResponse,
3436
ListTasksResult,
3537
Message,
3638
MessageSendParams,
@@ -231,8 +233,18 @@ async def list_tasks(
231233
context: ClientCallContext | None = None,
232234
) -> ListTasksResult:
233235
"""Retrieves tasks for an agent."""
234-
# TODO: #515 - Implement method
235-
raise NotImplementedError('tasks/list not implemented')
236+
rpc_request = ListTasksRequest(params=request, id=str(uuid4()))
237+
payload, modified_kwargs = await self._apply_interceptors(
238+
'tasks/list',
239+
rpc_request.model_dump(mode='json', exclude_none=True),
240+
self._get_http_args(context),
241+
context,
242+
)
243+
response_data = await self._send_request(payload, modified_kwargs)
244+
response = ListTasksResponse.model_validate(response_data)
245+
if isinstance(response.root, JSONRPCErrorResponse):
246+
raise A2AClientJSONRPCError(response.root)
247+
return response.root.result
236248

237249
async def cancel_task(
238250
self,

src/a2a/client/transports/rest.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from google.protobuf.json_format import MessageToDict, Parse, ParseDict
1010
from httpx_sse import SSEError, aconnect_sse
11+
from pydantic import BaseModel
1112

1213
from a2a.client.card_resolver import A2ACardResolver
1314
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
@@ -29,6 +30,7 @@
2930
TaskStatusUpdateEvent,
3031
)
3132
from a2a.utils import proto_utils
33+
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
3234
from a2a.utils.telemetry import SpanKind, trace_class
3335

3436

@@ -231,8 +233,20 @@ async def list_tasks(
231233
context: ClientCallContext | None = None,
232234
) -> ListTasksResult:
233235
"""Retrieves tasks for an agent."""
234-
# TODO: #515 - Implement method
235-
raise NotImplementedError('tasks/list not implemented')
236+
_, modified_kwargs = await self._apply_interceptors(
237+
request.model_dump(mode='json', exclude_none=True),
238+
self._get_http_args(context),
239+
context,
240+
)
241+
response_data = await self._send_get_request(
242+
'/v1/tasks',
243+
_model_to_query_params(request),
244+
modified_kwargs,
245+
)
246+
response = a2a_pb2.ListTasksResponse()
247+
ParseDict(response_data, response)
248+
page_size = request.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
249+
return proto_utils.FromProto.list_tasks_result(response, page_size)
236250

237251
async def cancel_task(
238252
self,
@@ -375,3 +389,21 @@ async def get_card(
375389
async def close(self) -> None:
376390
"""Closes the httpx client."""
377391
await self.httpx_client.aclose()
392+
393+
394+
def _model_to_query_params(instance: BaseModel) -> dict[str, str]:
395+
data = instance.model_dump(mode='json', exclude_none=True)
396+
return _json_to_query_params(data)
397+
398+
399+
def _json_to_query_params(data: dict[str, Any]) -> dict[str, str]:
400+
query_dict = {}
401+
for key, value in data.items():
402+
if isinstance(value, list):
403+
query_dict[key] = ','.join(map(str, value))
404+
elif isinstance(value, bool):
405+
query_dict[key] = str(value).lower()
406+
else:
407+
query_dict[key] = str(value)
408+
409+
return query_dict

src/a2a/utils/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
PREV_AGENT_CARD_WELL_KNOWN_PATH = '/.well-known/agent.json'
55
EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard'
66
DEFAULT_RPC_URL = '/'
7+
DEFAULT_LIST_TASKS_PAGE_SIZE = 50
8+
"""Default page size for the `tasks/list` method."""

src/a2a/utils/proto_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any
99

1010
from google.protobuf import json_format, struct_pb2
11+
from google.protobuf.timestamp_pb2 import Timestamp
1112

1213
from a2a import types
1314
from a2a.grpc import a2a_pb2
@@ -566,6 +567,24 @@ def role(cls, role: types.Role) -> a2a_pb2.Role:
566567
case _:
567568
return a2a_pb2.Role.ROLE_UNSPECIFIED
568569

570+
@classmethod
571+
def list_tasks_request(
572+
cls, params: types.ListTasksParams
573+
) -> a2a_pb2.ListTasksRequest:
574+
last_updated_time = None
575+
if params.last_updated_after is not None:
576+
last_updated_time = Timestamp()
577+
last_updated_time.FromMilliseconds(params.last_updated_after)
578+
return a2a_pb2.ListTasksRequest(
579+
context_id=params.context_id,
580+
status=cls.task_state(params.status) if params.status else None,
581+
page_size=params.page_size,
582+
page_token=params.page_token,
583+
history_length=params.history_length,
584+
last_updated_time=last_updated_time,
585+
include_artifacts=params.include_artifacts,
586+
)
587+
569588

570589
class FromProto:
571590
"""Converts proto types to Python types."""
@@ -796,6 +815,28 @@ def task_id_params(
796815
)
797816
return types.TaskIdParams(id=m.group(1))
798817

818+
@classmethod
819+
def list_tasks_result(
820+
cls,
821+
response: a2a_pb2.ListTasksResponse,
822+
page_size: int,
823+
) -> types.ListTasksResult:
824+
"""Converts a ListTasksResponse to a ListTasksResult.
825+
826+
Args:
827+
response: The ListTasksResponse to convert.
828+
page_size: The maximum number of tasks returned in this response.
829+
830+
Returns:
831+
A `ListTasksResult` object.
832+
"""
833+
return types.ListTasksResult(
834+
next_page_token=response.next_page_token,
835+
page_size=page_size,
836+
tasks=[cls.task(t) for t in response.tasks],
837+
total_size=response.total_size,
838+
)
839+
799840
@classmethod
800841
def task_push_notification_config_request(
801842
cls,

tests/client/test_grpc_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TaskStatus,
2626
TaskStatusUpdateEvent,
2727
TextPart,
28+
ListTasksParams,
2829
)
2930
from a2a.utils import get_text_parts, proto_utils
3031
from a2a.utils.errors import ServerError
@@ -37,6 +38,7 @@ def mock_grpc_stub() -> AsyncMock:
3738
stub.SendMessage = AsyncMock()
3839
stub.SendStreamingMessage = MagicMock()
3940
stub.GetTask = AsyncMock()
41+
stub.ListTasks = AsyncMock()
4042
stub.CancelTask = AsyncMock()
4143
stub.CreateTaskPushNotificationConfig = AsyncMock()
4244
stub.GetTaskPushNotificationConfig = AsyncMock()
@@ -91,6 +93,16 @@ def sample_task() -> Task:
9193
)
9294

9395

96+
@pytest.fixture
97+
def sample_task_2() -> Task:
98+
"""Provides a sample Task object."""
99+
return Task(
100+
id='task-2',
101+
context_id='ctx-2',
102+
status=TaskStatus(state=TaskState.failed),
103+
)
104+
105+
94106
@pytest.fixture
95107
def sample_message() -> Message:
96108
"""Provides a sample Message object."""
@@ -283,6 +295,32 @@ async def test_get_task(
283295
assert response.id == sample_task.id
284296

285297

298+
@pytest.mark.asyncio
299+
async def test_list_tasks(
300+
grpc_transport: GrpcTransport,
301+
mock_grpc_stub: AsyncMock,
302+
sample_task: Task,
303+
sample_task_2: Task,
304+
):
305+
"""Test listing tasks."""
306+
mock_grpc_stub.ListTasks.return_value = a2a_pb2.ListTasksResponse(
307+
tasks=[
308+
proto_utils.ToProto.task(t) for t in [sample_task, sample_task_2]
309+
],
310+
total_size=2,
311+
)
312+
params = ListTasksParams()
313+
314+
result = await grpc_transport.list_tasks(params)
315+
316+
mock_grpc_stub.ListTasks.assert_awaited_once_with(
317+
proto_utils.ToProto.list_tasks_request(params)
318+
)
319+
assert result.total_size == 2
320+
assert not result.next_page_token
321+
assert [t.id for t in result.tasks] == [sample_task.id, sample_task_2.id]
322+
323+
286324
@pytest.mark.asyncio
287325
async def test_get_task_with_history(
288326
grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task

tests/client/test_jsonrpc_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
TaskIdParams,
3232
TaskPushNotificationConfig,
3333
TaskQueryParams,
34+
ListTasksParams,
35+
ListTasksResult,
3436
)
3537
from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH
3638

@@ -560,6 +562,42 @@ async def test_get_task_success(
560562
sent_payload = mock_send_request.call_args.args[0]
561563
assert sent_payload['method'] == 'tasks/get'
562564

565+
@pytest.mark.asyncio
566+
async def test_list_tasks_success(
567+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
568+
):
569+
client = JsonRpcTransport(
570+
httpx_client=mock_httpx_client, agent_card=mock_agent_card
571+
)
572+
params = ListTasksParams()
573+
mock_rpc_response = {
574+
'id': '123',
575+
'jsonrpc': '2.0',
576+
'result': {
577+
'nextPageToken': '',
578+
'tasks': [MINIMAL_TASK],
579+
'pageSize': 10,
580+
'totalSize': 1,
581+
},
582+
}
583+
584+
with patch.object(
585+
client, '_send_request', new_callable=AsyncMock
586+
) as mock_send_request:
587+
mock_send_request.return_value = mock_rpc_response
588+
response = await client.list_tasks(request=params)
589+
590+
assert isinstance(response, ListTasksResult)
591+
assert (
592+
response.model_dump()
593+
== ListTasksResult(
594+
next_page_token='',
595+
page_size=10,
596+
tasks=[Task.model_validate(MINIMAL_TASK)],
597+
total_size=1,
598+
).model_dump()
599+
)
600+
563601
@pytest.mark.asyncio
564602
async def test_cancel_task_success(
565603
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock

0 commit comments

Comments
 (0)