Skip to content

Commit 51026f4

Browse files
committed
Add recording steps and hook to the request handler
1 parent b6002fe commit 51026f4

File tree

9 files changed

+327
-40
lines changed

9 files changed

+327
-40
lines changed

src/a2a/client/base_client.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncIterator
2+
from typing import cast
23

34
from a2a.client.client import (
45
Client,
@@ -12,6 +13,12 @@
1213
from a2a.client.middleware import ClientCallInterceptor
1314
from a2a.client.transports.base import ClientTransport
1415
from a2a.extensions.base import Extension
16+
from a2a.extensions.trace import (
17+
AgentInvocation,
18+
CallTypeEnum,
19+
StepAction,
20+
TraceExtension,
21+
)
1522
from a2a.types import (
1623
AgentCard,
1724
GetTaskPushNotificationConfigParams,
@@ -68,9 +75,31 @@ async def send_message(
6875
Yields:
6976
An async iterator of `ClientEvent` or a final `Message` response.
7077
"""
78+
trace_extension: TraceExtension | None = None
7179
for extension in self._extensions:
80+
if isinstance(extension, TraceExtension):
81+
trace_extension = cast(TraceExtension, extension)
7282
extension.on_client_message(request)
7383

84+
step = None
85+
if trace_extension:
86+
trace_id = request.metadata.get('trace', {}).get('trace_id')
87+
parent_step_id = request.metadata.get('trace', {}).get(
88+
'parent_step_id'
89+
)
90+
step = trace_extension.start_step(
91+
trace_id=trace_id,
92+
parent_step_id=parent_step_id,
93+
call_type=CallTypeEnum.AGENT,
94+
step_action=StepAction(
95+
agent_invocation=AgentInvocation(
96+
agent_url=self._card.url,
97+
agent_name=self._card.name,
98+
requests=request.model_dump(mode='json'),
99+
)
100+
),
101+
)
102+
74103
config = MessageSendConfiguration(
75104
accepted_output_modes=self._config.accepted_output_modes,
76105
blocking=not self._config.polling,
@@ -82,33 +111,44 @@ async def send_message(
82111
)
83112
params = MessageSendParams(message=request, configuration=config)
84113

85-
if not self._config.streaming or not self._card.capabilities.streaming:
86-
response = await self._transport.send_message(
114+
try:
115+
if (
116+
not self._config.streaming
117+
or not self._card.capabilities.streaming
118+
):
119+
response = await self._transport.send_message(
120+
params, context=context
121+
)
122+
result = (
123+
(response, None)
124+
if isinstance(response, Task)
125+
else response
126+
)
127+
await self.consume(result, self._card)
128+
yield result
129+
return
130+
131+
tracker = ClientTaskManager()
132+
stream = self._transport.send_message_streaming(
87133
params, context=context
88134
)
89-
result = (
90-
(response, None) if isinstance(response, Task) else response
91-
)
92-
await self.consume(result, self._card)
93-
yield result
94-
return
95135

96-
tracker = ClientTaskManager()
97-
stream = self._transport.send_message_streaming(params, context=context)
98-
99-
first_event = await anext(stream)
100-
# The response from a server may be either exactly one Message or a
101-
# series of Task updates. Separate out the first message for special
102-
# case handling, which allows us to simplify further stream processing.
103-
if isinstance(first_event, Message):
104-
await self.consume(first_event, self._card)
105-
yield first_event
106-
return
107-
108-
yield await self._process_response(tracker, first_event)
109-
110-
async for event in stream:
111-
yield await self._process_response(tracker, event)
136+
first_event = await anext(stream)
137+
# The response from a server may be either exactly one Message or a
138+
# series of Task updates. Separate out the first message for special
139+
# case handling, which allows us to simplify further stream processing.
140+
if isinstance(first_event, Message):
141+
await self.consume(first_event, self._card)
142+
yield first_event
143+
return
144+
145+
yield await self._process_response(tracker, first_event)
146+
147+
async for event in stream:
148+
yield await self._process_response(tracker, event)
149+
finally:
150+
if trace_extension and step:
151+
trace_extension.end_step(step.step_id)
112152

113153
async def _process_response(
114154
self,

src/a2a/extensions/trace.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3-
from datetime import datetime
3+
import time
4+
import uuid
5+
from datetime import datetime, timezone
46
from enum import Enum
57
from typing import Any
68

@@ -51,7 +53,7 @@ class Step(A2ABaseModel):
5153
additional_attributes: dict[str, str] | None = None
5254
latency: int | None = None
5355
start_time: datetime
54-
end_time: datetime
56+
end_time: datetime | None = None
5557

5658

5759
class ResponseTrace(A2ABaseModel):
@@ -64,18 +66,81 @@ class ResponseTrace(A2ABaseModel):
6466
class TraceExtension(Extension):
6567
"""An extension for traceability."""
6668

69+
def __init__(self, **kwargs: Any) -> None:
70+
super().__init__(**kwargs)
71+
self.traces: dict[str, ResponseTrace] = {}
72+
self._current_steps: dict[str, Step] = {}
73+
74+
def _generate_id(self, prefix: str) -> str:
75+
return f'{prefix}-{uuid.uuid4()}'
76+
77+
def start_trace(self) -> ResponseTrace:
78+
"""Starts a new trace."""
79+
trace_id = self._generate_id('trace')
80+
trace = ResponseTrace(trace_id=trace_id, steps=[])
81+
self.traces[trace_id] = trace
82+
return trace
83+
84+
def start_step(
85+
self,
86+
trace_id: str,
87+
parent_step_id: str | None,
88+
call_type: CallTypeEnum,
89+
step_action: StepAction,
90+
) -> Step:
91+
"""Starts a new step."""
92+
step_id = self._generate_id('step')
93+
step = Step(
94+
step_id=step_id,
95+
trace_id=trace_id,
96+
parent_step_id=parent_step_id,
97+
call_type=call_type,
98+
step_action=step_action,
99+
start_time=datetime.now(timezone.utc),
100+
)
101+
self._current_steps[step_id] = step
102+
return step
103+
104+
def end_step(
105+
self,
106+
step_id: str,
107+
cost: int | None = None,
108+
total_tokens: int | None = None,
109+
additional_attributes: dict[str, str] | None = None,
110+
) -> None:
111+
"""Ends a step."""
112+
if step_id not in self._current_steps:
113+
return
114+
115+
step = self._current_steps.pop(step_id)
116+
step.end_time = datetime.now(timezone.utc)
117+
step.latency = int(
118+
(step.end_time - step.start_time).total_seconds() * 1000
119+
)
120+
step.cost = cost
121+
step.total_tokens = total_tokens
122+
step.additional_attributes = additional_attributes
123+
124+
if step.trace_id in self.traces:
125+
self.traces[step.trace_id].steps.append(step)
126+
67127
def on_client_message(self, message: Any) -> None:
68128
"""Appends trace information to the message."""
69-
# This is a placeholder implementation.
129+
trace = self.start_trace()
70130
if message.metadata is None:
71131
message.metadata = {}
72-
message.metadata['trace'] = 'client-trace'
132+
message.metadata['trace'] = trace.model_dump(mode='json')
73133

74134
def on_server_message(self, message: Any) -> None:
75135
"""Processes trace information from the message."""
76-
# This is a placeholder implementation.
77-
if hasattr(message, 'metadata') and 'trace' in message.metadata:
78-
print(f"Received trace: {message.metadata['trace']}")
136+
if (
137+
hasattr(message, 'metadata')
138+
and message.metadata is not None
139+
and 'trace' in message.metadata
140+
):
141+
trace_data = message.metadata['trace']
142+
trace = ResponseTrace.model_validate(trace_data)
143+
self.traces[trace.trace_id] = trace
79144

80145

81-
AgentInvocation.model_rebuild()
146+
AgentInvocation.model_rebuild()

src/a2a/server/agent_execution/agent_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
23

34
from a2a.server.agent_execution.context import RequestContext
45
from a2a.server.events.event_queue import EventQueue
@@ -13,7 +14,10 @@ class AgentExecutor(ABC):
1314

1415
@abstractmethod
1516
async def execute(
16-
self, context: RequestContext, event_queue: EventQueue
17+
self,
18+
context: RequestContext,
19+
event_queue: EventQueue,
20+
request_handler: Any,
1721
) -> None:
1822
"""Execute the agent's logic for a given request context.
1923
@@ -26,6 +30,7 @@ async def execute(
2630
Args:
2731
context: The request context containing the message, task ID, etc.
2832
event_queue: The queue to publish events to.
33+
request_handler: The request handler that is executing the agent.
2934
"""
3035

3136
@abstractmethod

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
TaskStore,
2828
)
2929
from a2a.extensions.base import Extension
30+
from a2a.extensions.trace import (
31+
CallTypeEnum,
32+
StepAction,
33+
ToolInvocation,
34+
TraceExtension,
35+
)
3036
from a2a.types import (
3137
DeleteTaskPushNotificationConfigParams,
3238
GetTaskPushNotificationConfigParams,
@@ -176,9 +182,40 @@ async def _run_event_stream(
176182
request: The request context for the agent.
177183
queue: The event queue for the agent to publish to.
178184
"""
179-
await self.agent_executor.execute(request, queue)
185+
await self.agent_executor.execute(request, queue, self)
180186
await queue.close()
181187

188+
async def handle_tool_call(
189+
self,
190+
trace_id: str,
191+
parent_step_id: str,
192+
tool_name: str,
193+
parameters: dict[str, Any],
194+
) -> None:
195+
"""Handles a tool call from the agent executor."""
196+
trace_extension: TraceExtension | None = None
197+
for extension in self._extensions:
198+
if isinstance(extension, TraceExtension):
199+
trace_extension = cast(TraceExtension, extension)
200+
201+
if not trace_extension:
202+
return
203+
204+
step = trace_extension.start_step(
205+
trace_id=trace_id,
206+
parent_step_id=parent_step_id,
207+
call_type=CallTypeEnum.TOOL,
208+
step_action=StepAction(
209+
tool_invocation=ToolInvocation(
210+
tool_name=tool_name,
211+
parameters=parameters,
212+
)
213+
),
214+
)
215+
# In a real implementation, you would execute the tool here.
216+
# For this example, we'll just end the step immediately.
217+
trace_extension.end_step(step.step_id)
218+
182219
async def _setup_message_execution(
183220
self,
184221
params: MessageSendParams,

tests/extensions/debug_trace.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
import sys
3+
from unittest.mock import Mock
4+
5+
from a2a.client.base_client import BaseClient
6+
from a2a.extensions.trace import TraceExtension
7+
from a2a.types import Message, TextPart, Role, Part
8+
9+
def debug_trace():
10+
print("Starting trace debug...")
11+
12+
# Create the extension
13+
trace_extension = TraceExtension()
14+
15+
# Create a trace directly to see its structure
16+
trace = trace_extension.start_trace()
17+
print(f"Direct trace object: {trace}")
18+
print(f"Direct trace dict: {trace.model_dump(mode='json')}")
19+
20+
# Create a message
21+
message = Message(
22+
message_id='test_message',
23+
role=Role.user,
24+
parts=[Part(TextPart(text='Hello, world!'))],
25+
)
26+
27+
print(f"Initial message metadata: {message.metadata}")
28+
29+
# Call the extension method
30+
trace_extension.on_client_message(message)
31+
32+
print(f"After extension metadata: {message.metadata}")
33+
34+
if message.metadata and 'trace' in message.metadata:
35+
trace_data = message.metadata['trace']
36+
print(f"Trace data type: {type(trace_data)}")
37+
print(f"Trace data: {trace_data}")
38+
39+
if isinstance(trace_data, dict):
40+
print(f"Trace data keys: {list(trace_data.keys())}")
41+
if 'trace_id' in trace_data:
42+
print(f"Found trace_id: {trace_data['trace_id']}")
43+
else:
44+
print("trace_id not found in trace data")
45+
else:
46+
print("Trace data is not a dict")
47+
else:
48+
print("No trace data found in metadata")
49+
50+
if __name__ == "__main__":
51+
debug_trace()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python3
2+
3+
# Simple test to check ResponseTrace serialization
4+
from a2a.extensions.trace import TraceExtension, ResponseTrace
5+
6+
# Create extension and trace
7+
ext = TraceExtension()
8+
trace = ext.start_trace()
9+
10+
print("Trace object:", trace)
11+
print("Trace type:", type(trace))
12+
print("Trace fields:", trace.__dict__)
13+
print("Model dump:", trace.model_dump(mode='json'))
14+
15+
# Test creating trace data like in the extension
16+
if True: # message.metadata is None
17+
metadata = {}
18+
metadata['trace'] = trace.model_dump(mode='json')
19+
20+
print("Metadata:", metadata)
21+
print("Trace in metadata:", metadata['trace'])
22+
print("Keys in trace:", metadata['trace'].keys() if isinstance(metadata['trace'], dict) else "not a dict")

0 commit comments

Comments
 (0)