Skip to content

Commit 27f633b

Browse files
committed
feat: add traceability extension support
1 parent dec4b48 commit 27f633b

File tree

8 files changed

+242
-3
lines changed

8 files changed

+242
-3
lines changed

src/a2a/client/base_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from a2a.client.errors import A2AClientInvalidStateError
1212
from a2a.client.middleware import ClientCallInterceptor
1313
from a2a.client.transports.base import ClientTransport
14+
from a2a.extensions.base import Extension
1415
from a2a.types import (
1516
AgentCard,
1617
GetTaskPushNotificationConfigParams,
@@ -41,6 +42,12 @@ def __init__(
4142
self._card = card
4243
self._config = config
4344
self._transport = transport
45+
self._extensions: list[Extension] = []
46+
47+
def install_extension(self, extension: Extension) -> None:
48+
"""Installs an extension on the client."""
49+
extension.install(self)
50+
self._extensions.append(extension)
4451

4552
async def send_message(
4653
self,
@@ -61,6 +68,9 @@ async def send_message(
6168
Yields:
6269
An async iterator of `ClientEvent` or a final `Message` response.
6370
"""
71+
for extension in self._extensions:
72+
extension.on_client_message(request)
73+
6474
config = MessageSendConfiguration(
6575
accepted_output_modes=self._config.accepted_output_modes,
6676
blocking=not self._config.polling,

src/a2a/extensions/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""A2A extensions."""
2+
3+
from .base import Extension
4+
from . import common, trace
5+
6+
__all__ = ['Extension', 'common', 'trace']

src/a2a/extensions/base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from a2a.client.client import A2AClient
7+
from a2a.server.server import A2AServer
8+
9+
10+
class Extension:
11+
"""Base class for all extensions."""
12+
13+
def __init__(self, **kwargs: Any) -> None:
14+
...
15+
16+
def on_client_message(self, message: Any) -> None:
17+
"""Called when a message is sent from the client."""
18+
...
19+
20+
def on_server_message(self, message: Any) -> None:
21+
"""Called when a message is received by the server."""
22+
...
23+
24+
def install(self, client_or_server: A2AClient | A2AServer) -> None:
25+
"""Called when the extension is installed on a client or server."""
26+
...

src/a2a/extensions/trace.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from enum import Enum
5+
from typing import Any
6+
7+
from a2a._base import A2ABaseModel
8+
from a2a.extensions.base import Extension
9+
10+
11+
class CallTypeEnum(str, Enum):
12+
"""The type of the operation a step represents."""
13+
14+
AGENT = 'AGENT'
15+
TOOL = 'TOOL'
16+
17+
18+
class ToolInvocation(A2ABaseModel):
19+
"""A tool invocation."""
20+
21+
tool_name: str
22+
parameters: dict[str, Any]
23+
24+
25+
class AgentInvocation(A2ABaseModel):
26+
"""An agent invocation."""
27+
28+
agent_url: str
29+
agent_name: str
30+
requests: dict[str, Any]
31+
response_trace: ResponseTrace | None = None
32+
33+
34+
class StepAction(A2ABaseModel):
35+
"""The action of a step."""
36+
37+
tool_invocation: ToolInvocation | None = None
38+
agent_invocation: AgentInvocation | None = None
39+
40+
41+
class Step(A2ABaseModel):
42+
"""A single operation within a trace."""
43+
44+
step_id: str
45+
trace_id: str
46+
parent_step_id: str | None = None
47+
call_type: CallTypeEnum
48+
step_action: StepAction
49+
cost: int | None = None
50+
total_tokens: int | None = None
51+
additional_attributes: dict[str, str] | None = None
52+
latency: int | None = None
53+
start_time: datetime
54+
end_time: datetime
55+
56+
57+
class ResponseTrace(A2ABaseModel):
58+
"""A trace message that contains a collection of spans."""
59+
60+
trace_id: str
61+
steps: list[Step]
62+
63+
64+
class TraceExtension(Extension):
65+
"""An extension for traceability."""
66+
67+
def on_client_message(self, message: Any) -> None:
68+
"""Appends trace information to the message."""
69+
# This is a placeholder implementation.
70+
if message.metadata is None:
71+
message.metadata = {}
72+
message.metadata['trace'] = 'client-trace'
73+
74+
def on_server_message(self, message: Any) -> None:
75+
"""Processes trace information from the message."""
76+
# This is a placeholder implementation.
77+
if hasattr(message, 'metadata') and 'trace' in message.metadata:
78+
print(f"Received trace: {message.metadata['trace']}")
79+
80+
81+
AgentInvocation.model_rebuild()

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33

44
from collections.abc import AsyncGenerator
5-
from typing import cast
5+
from typing import Any, cast
66

77
from a2a.server.agent_execution import (
88
AgentExecutor,
@@ -26,6 +26,7 @@
2626
TaskManager,
2727
TaskStore,
2828
)
29+
from a2a.extensions.base import Extension
2930
from a2a.types import (
3031
DeleteTaskPushNotificationConfigParams,
3132
GetTaskPushNotificationConfigParams,
@@ -101,6 +102,12 @@ def __init__( # noqa: PLR0913
101102
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
102103
self._running_agents = {}
103104
self._running_agents_lock = asyncio.Lock()
105+
self._extensions: list[Extension] = []
106+
107+
def install_extension(self, extension: Extension, server: Any) -> None:
108+
"""Installs an extension on the server."""
109+
extension.install(server)
110+
self._extensions.append(extension)
104111

105112
async def on_get_task(
106113
self,
@@ -182,6 +189,9 @@ async def _setup_message_execution(
182189
Returns:
183190
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
184191
"""
192+
for extension in self._extensions:
193+
extension.on_server_message(params.message)
194+
185195
# Create task manager and validate existing task
186196
task_manager = TaskManager(
187197
task_id=params.message.task_id,

tests/extensions/test_trace.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from datetime import datetime, timezone
2+
3+
from a2a.extensions.trace import (
4+
AgentInvocation,
5+
CallTypeEnum,
6+
ResponseTrace,
7+
Step,
8+
StepAction,
9+
ToolInvocation,
10+
)
11+
12+
13+
def test_trace_serialization():
14+
start_time = datetime(2025, 3, 15, 12, 0, 0, tzinfo=timezone.utc)
15+
end_time = datetime(2025, 3, 15, 12, 0, 0, 250000, tzinfo=timezone.utc)
16+
17+
trace = ResponseTrace(
18+
trace_id='trace-example-12345',
19+
steps=[
20+
Step(
21+
step_id='step-1-agent',
22+
trace_id='trace-example-1234p',
23+
call_type=CallTypeEnum.AGENT,
24+
step_action=StepAction(
25+
agent_invocation=AgentInvocation(
26+
agent_name='weather_agent',
27+
agent_url='http://google3/some/agent/url',
28+
requests={
29+
'user_prompt': "What's the weather in Paris and what should I wear?"
30+
},
31+
)
32+
),
33+
cost=150,
34+
total_tokens=75,
35+
additional_attributes={'user_country': 'US'},
36+
latency=250,
37+
start_time=start_time,
38+
end_time=end_time,
39+
),
40+
Step(
41+
step_id='step-2-tool',
42+
trace_id='trace-example-12345',
43+
parent_step_id='step-1-agent',
44+
call_type=CallTypeEnum.TOOL,
45+
step_action=StepAction(
46+
tool_invocation=ToolInvocation(
47+
tool_name='google_map_api_tool',
48+
parameters={'location': 'Paris, FR'},
49+
)
50+
),
51+
cost=50,
52+
total_tokens=20,
53+
latency=100,
54+
start_time=start_time,
55+
end_time=end_time,
56+
),
57+
],
58+
)
59+
60+
trace_dict = trace.model_dump(mode='json')
61+
deserialized_trace = ResponseTrace.model_validate(trace_dict)
62+
63+
assert trace == deserialized_trace
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
5+
from a2a.client.base_client import BaseClient
6+
from a2a.extensions.trace import TraceExtension
7+
from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler
8+
from a2a.types import Message, TextPart
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_trace_extension():
13+
client = BaseClient(card=Mock(), config=Mock(), transport=Mock(), consumers=[], middleware=[])
14+
server_handler = DefaultRequestHandler(
15+
agent_executor=Mock(),
16+
task_store=Mock(),
17+
)
18+
19+
trace_extension = TraceExtension()
20+
client.install_extension(trace_extension)
21+
server_handler.install_extension(trace_extension, server=Mock())
22+
23+
message = Message(
24+
message_id='test_message',
25+
role='user',
26+
parts=[TextPart(text='Hello, world!')],
27+
)
28+
29+
# Simulate client sending a message
30+
for extension in client._extensions:
31+
extension.on_client_message(message)
32+
33+
assert 'trace' in message.metadata
34+
assert message.metadata['trace'] == 'client-trace'
35+
36+
# Simulate server receiving a message
37+
for extension in server_handler._extensions:
38+
extension.on_server_message(message)
39+
40+
# Check that the server-side handler was called
41+
# (in this case, it just prints a message)
42+
# We can't easily check the output of print, so we'll just
43+
# assume it worked if no exceptions were raised.

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)