diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index f4a8d03d..9389c949 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,4 +1,5 @@ from collections.abc import AsyncIterator +from typing import cast from a2a.client.client import ( Client, @@ -11,6 +12,13 @@ from a2a.client.errors import A2AClientInvalidStateError from a2a.client.middleware import ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.base import Extension +from a2a.extensions.trace import ( + AgentInvocation, + CallTypeEnum, + StepAction, + TraceExtension, +) from a2a.types import ( AgentCard, GetTaskPushNotificationConfigParams, @@ -41,6 +49,12 @@ def __init__( self._card = card self._config = config self._transport = transport + self._extensions: list[Extension] = [] + + def install_extension(self, extension: Extension) -> None: + """Installs an extension on the client.""" + extension.install(self) + self._extensions.append(extension) async def send_message( self, @@ -61,6 +75,31 @@ async def send_message( Yields: An async iterator of `ClientEvent` or a final `Message` response. """ + trace_extension: TraceExtension | None = None + for extension in self._extensions: + if isinstance(extension, TraceExtension): + trace_extension = cast(TraceExtension, extension) + extension.on_client_message(request) + + step = None + if trace_extension: + trace_id = request.metadata.get('trace', {}).get('trace_id') + parent_step_id = request.metadata.get('trace', {}).get( + 'parent_step_id' + ) + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=CallTypeEnum.AGENT, + step_action=StepAction( + agent_invocation=AgentInvocation( + agent_url=self._card.url, + agent_name=self._card.name, + requests=request.model_dump(mode='json'), + ) + ), + ) + config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, @@ -72,33 +111,44 @@ async def send_message( ) params = MessageSendParams(message=request, configuration=config) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport.send_message( + try: + if ( + not self._config.streaming + or not self._card.capabilities.streaming + ): + response = await self._transport.send_message( + params, context=context + ) + result = ( + (response, None) + if isinstance(response, Task) + else response + ) + await self.consume(result, self._card) + yield result + return + + tracker = ClientTaskManager() + stream = self._transport.send_message_streaming( params, context=context ) - result = ( - (response, None) if isinstance(response, Task) else response - ) - await self.consume(result, self._card) - yield result - return - - tracker = ClientTaskManager() - stream = self._transport.send_message_streaming(params, context=context) - first_event = await anext(stream) - # The response from a server may be either exactly one Message or a - # series of Task updates. Separate out the first message for special - # case handling, which allows us to simplify further stream processing. - if isinstance(first_event, Message): - await self.consume(first_event, self._card) - yield first_event - return - - yield await self._process_response(tracker, first_event) - - async for event in stream: - yield await self._process_response(tracker, event) + first_event = await anext(stream) + # The response from a server may be either exactly one Message or a + # series of Task updates. Separate out the first message for special + # case handling, which allows us to simplify further stream processing. + if isinstance(first_event, Message): + await self.consume(first_event, self._card) + yield first_event + return + + yield await self._process_response(tracker, first_event) + + async for event in stream: + yield await self._process_response(tracker, event) + finally: + if trace_extension and step: + trace_extension.end_step(step.step_id) async def _process_response( self, diff --git a/src/a2a/extensions/__init__.py b/src/a2a/extensions/__init__.py index e69de29b..86a77acf 100644 --- a/src/a2a/extensions/__init__.py +++ b/src/a2a/extensions/__init__.py @@ -0,0 +1,6 @@ +"""A2A extensions.""" + +from a2a.extensions.base import Extension +from a2a.extensions import common, trace + +__all__ = ['Extension', 'common', 'trace'] diff --git a/src/a2a/extensions/base.py b/src/a2a/extensions/base.py new file mode 100644 index 00000000..28dbfef7 --- /dev/null +++ b/src/a2a/extensions/base.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from a2a.client.client import A2AClient + from a2a.server.server import A2AServer + + +class Extension: + """Base class for all extensions.""" + + def __init__(self, **kwargs: Any) -> None: + ... + + def on_client_message(self, message: Any) -> None: + """Called when a message is sent from the client.""" + ... + + def on_server_message(self, message: Any) -> None: + """Called when a message is received by the server.""" + ... + + def install(self, client_or_server: A2AClient | A2AServer) -> None: + """Called when the extension is installed on a client or server.""" + ... diff --git a/src/a2a/extensions/trace.py b/src/a2a/extensions/trace.py new file mode 100644 index 00000000..dfa4be01 --- /dev/null +++ b/src/a2a/extensions/trace.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import time +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +from a2a._base import A2ABaseModel +from a2a.extensions.base import Extension + + +class CallTypeEnum(str, Enum): + """The type of the operation a step represents.""" + + AGENT = 'AGENT' + TOOL = 'TOOL' + + +class ToolInvocation(A2ABaseModel): + """A tool invocation.""" + + tool_name: str + parameters: dict[str, Any] + + +class AgentInvocation(A2ABaseModel): + """An agent invocation.""" + + agent_url: str + agent_name: str + requests: dict[str, Any] + response_trace: ResponseTrace | None = None + + +class StepAction(A2ABaseModel): + """The action of a step.""" + + tool_invocation: ToolInvocation | None = None + agent_invocation: AgentInvocation | None = None + + +class Step(A2ABaseModel): + """A single operation within a trace.""" + + step_id: str + trace_id: str + parent_step_id: str | None = None + call_type: CallTypeEnum + step_action: StepAction + cost: int | None = None + total_tokens: int | None = None + additional_attributes: dict[str, str] | None = None + latency: int | None = None + start_time: datetime + end_time: datetime | None = None + + +class ResponseTrace(A2ABaseModel): + """A trace message that contains a collection of spans.""" + + trace_id: str + steps: list[Step] + + +class TraceExtension(Extension): + """An extension for traceability.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.traces: dict[str, ResponseTrace] = {} + self._current_steps: dict[str, Step] = {} + + def _generate_id(self, prefix: str) -> str: + return f'{prefix}-{uuid.uuid4()}' + + def start_trace(self) -> ResponseTrace: + """Starts a new trace.""" + trace_id = self._generate_id('trace') + trace = ResponseTrace(trace_id=trace_id, steps=[]) + self.traces[trace_id] = trace + return trace + + def start_step( + self, + trace_id: str, + parent_step_id: str | None, + call_type: CallTypeEnum, + step_action: StepAction, + ) -> Step: + """Starts a new step.""" + step_id = self._generate_id('step') + step = Step( + step_id=step_id, + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=call_type, + step_action=step_action, + start_time=datetime.now(timezone.utc), + ) + self._current_steps[step_id] = step + return step + + def end_step( + self, + step_id: str, + cost: int | None = None, + total_tokens: int | None = None, + additional_attributes: dict[str, str] | None = None, + ) -> None: + """Ends a step.""" + if step_id not in self._current_steps: + return + + step = self._current_steps.pop(step_id) + step.end_time = datetime.now(timezone.utc) + step.latency = int( + (step.end_time - step.start_time).total_seconds() * 1000 + ) + step.cost = cost + step.total_tokens = total_tokens + step.additional_attributes = additional_attributes + + if step.trace_id in self.traces: + self.traces[step.trace_id].steps.append(step) + + def on_client_message(self, message: Any) -> None: + """Appends trace information to the message.""" + trace = self.start_trace() + if message.metadata is None: + message.metadata = {} + message.metadata['trace'] = trace.model_dump(mode='json') + + def on_server_message(self, message: Any) -> None: + """Processes trace information from the message.""" + if ( + hasattr(message, 'metadata') + and message.metadata is not None + and 'trace' in message.metadata + ): + trace_data = message.metadata['trace'] + trace = ResponseTrace.model_validate(trace_data) + self.traces[trace.trace_id] = trace + + +AgentInvocation.model_rebuild() diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 38be9c11..eabe191d 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue @@ -13,7 +14,10 @@ class AgentExecutor(ABC): @abstractmethod async def execute( - self, context: RequestContext, event_queue: EventQueue + self, + context: RequestContext, + event_queue: EventQueue, + request_handler: Any, ) -> None: """Execute the agent's logic for a given request context. @@ -26,6 +30,7 @@ async def execute( Args: context: The request context containing the message, task ID, etc. event_queue: The queue to publish events to. + request_handler: The request handler that is executing the agent. """ @abstractmethod diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 25fe2525..2a524c50 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncGenerator -from typing import cast +from typing import Any, cast from a2a.server.agent_execution import ( AgentExecutor, @@ -26,6 +26,13 @@ TaskManager, TaskStore, ) +from a2a.extensions.base import Extension +from a2a.extensions.trace import ( + CallTypeEnum, + StepAction, + ToolInvocation, + TraceExtension, +) from a2a.types import ( DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, @@ -101,6 +108,12 @@ def __init__( # noqa: PLR0913 # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() + self._extensions: list[Extension] = [] + + def install_extension(self, extension: Extension, server: Any) -> None: + """Installs an extension on the server.""" + extension.install(server) + self._extensions.append(extension) async def on_get_task( self, @@ -169,9 +182,40 @@ async def _run_event_stream( request: The request context for the agent. queue: The event queue for the agent to publish to. """ - await self.agent_executor.execute(request, queue) + await self.agent_executor.execute(request, queue, self) await queue.close() + async def handle_tool_call( + self, + trace_id: str, + parent_step_id: str, + tool_name: str, + parameters: dict[str, Any], + ) -> None: + """Handles a tool call from the agent executor.""" + trace_extension: TraceExtension | None = None + for extension in self._extensions: + if isinstance(extension, TraceExtension): + trace_extension = cast(TraceExtension, extension) + + if not trace_extension: + return + + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=CallTypeEnum.TOOL, + step_action=StepAction( + tool_invocation=ToolInvocation( + tool_name=tool_name, + parameters=parameters, + ) + ), + ) + # In a real implementation, you would execute the tool here. + # For this example, we'll just end the step immediately. + trace_extension.end_step(step.step_id) + async def _setup_message_execution( self, params: MessageSendParams, @@ -182,6 +226,9 @@ async def _setup_message_execution( Returns: A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) """ + for extension in self._extensions: + extension.on_server_message(params.message) + # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.task_id, diff --git a/tests/extensions/debug_trace.py b/tests/extensions/debug_trace.py new file mode 100644 index 00000000..b6511106 --- /dev/null +++ b/tests/extensions/debug_trace.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +import sys +from unittest.mock import Mock + +from a2a.client.base_client import BaseClient +from a2a.extensions.trace import TraceExtension +from a2a.types import Message, TextPart, Role, Part + +def debug_trace(): + print("Starting trace debug...") + + # Create the extension + trace_extension = TraceExtension() + + # Create a trace directly to see its structure + trace = trace_extension.start_trace() + print(f"Direct trace object: {trace}") + print(f"Direct trace dict: {trace.model_dump(mode='json')}") + + # Create a message + message = Message( + message_id='test_message', + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], + ) + + print(f"Initial message metadata: {message.metadata}") + + # Call the extension method + trace_extension.on_client_message(message) + + print(f"After extension metadata: {message.metadata}") + + if message.metadata and 'trace' in message.metadata: + trace_data = message.metadata['trace'] + print(f"Trace data type: {type(trace_data)}") + print(f"Trace data: {trace_data}") + + if isinstance(trace_data, dict): + print(f"Trace data keys: {list(trace_data.keys())}") + if 'trace_id' in trace_data: + print(f"Found trace_id: {trace_data['trace_id']}") + else: + print("trace_id not found in trace data") + else: + print("Trace data is not a dict") + else: + print("No trace data found in metadata") + +if __name__ == "__main__": + debug_trace() diff --git a/tests/extensions/simple_trace_test.py b/tests/extensions/simple_trace_test.py new file mode 100644 index 00000000..ee1849f6 --- /dev/null +++ b/tests/extensions/simple_trace_test.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# Simple test to check ResponseTrace serialization +from a2a.extensions.trace import TraceExtension, ResponseTrace + +# Create extension and trace +ext = TraceExtension() +trace = ext.start_trace() + +print("Trace object:", trace) +print("Trace type:", type(trace)) +print("Trace fields:", trace.__dict__) +print("Model dump:", trace.model_dump(mode='json')) + +# Test creating trace data like in the extension +if True: # message.metadata is None + metadata = {} +metadata['trace'] = trace.model_dump(mode='json') + +print("Metadata:", metadata) +print("Trace in metadata:", metadata['trace']) +print("Keys in trace:", metadata['trace'].keys() if isinstance(metadata['trace'], dict) else "not a dict") diff --git a/tests/extensions/test_full_trace_extension.py b/tests/extensions/test_full_trace_extension.py new file mode 100644 index 00000000..2a664e8f --- /dev/null +++ b/tests/extensions/test_full_trace_extension.py @@ -0,0 +1,58 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from a2a.extensions.trace import TraceExtension +from a2a.types import Message, Part, Role, TextPart + + +@pytest.mark.asyncio +async def test_full_trace_extension(): + trace_extension = TraceExtension() + + # Test the trace extension directly + message = Message( + message_id='test_message', + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], + ) + + # Simulate client sending a message - creates trace + trace_extension.on_client_message(message) + + # Verify trace was created and stored in metadata + assert 'trace' in message.metadata + trace_data = message.metadata['trace'] + assert 'traceId' in trace_data + trace_id = trace_data['traceId'] + + # Simulate server receiving a message - loads trace + trace_extension.on_server_message(message) + + # Verify trace was loaded into extension + assert trace_id in trace_extension.traces + trace = trace_extension.traces[trace_id] + assert len(trace.steps) == 0 # Initially no steps + + # Simulate a tool call being made + from a2a.extensions.trace import StepAction, ToolInvocation, CallTypeEnum + step_action = StepAction(tool_invocation=ToolInvocation( + tool_name='test_tool', + parameters={'param1': 'value1'} + )) + + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=None, + call_type=CallTypeEnum.TOOL, + step_action=step_action + ) + + # End the step + trace_extension.end_step(step.step_id) + + # Verify the trace + assert len(trace_extension.traces) == 1 + trace = trace_extension.traces[trace_id] + assert len(trace.steps) == 1 + assert trace.steps[0].call_type == CallTypeEnum.TOOL diff --git a/tests/extensions/test_trace.py b/tests/extensions/test_trace.py new file mode 100644 index 00000000..850d8007 --- /dev/null +++ b/tests/extensions/test_trace.py @@ -0,0 +1,63 @@ +from datetime import datetime, timezone + +from a2a.extensions.trace import ( + AgentInvocation, + CallTypeEnum, + ResponseTrace, + Step, + StepAction, + ToolInvocation, +) + + +def test_trace_serialization(): + start_time = datetime(2025, 3, 15, 12, 0, 0, tzinfo=timezone.utc) + end_time = datetime(2025, 3, 15, 12, 0, 0, 250000, tzinfo=timezone.utc) + + trace = ResponseTrace( + trace_id='trace-example-12345', + steps=[ + Step( + step_id='step-1-agent', + trace_id='trace-example-1234p', + call_type=CallTypeEnum.AGENT, + step_action=StepAction( + agent_invocation=AgentInvocation( + agent_name='weather_agent', + agent_url='http://google3/some/agent/url', + requests={ + 'user_prompt': "What's the weather in Paris and what should I wear?" + }, + ) + ), + cost=150, + total_tokens=75, + additional_attributes={'user_country': 'US'}, + latency=250, + start_time=start_time, + end_time=end_time, + ), + Step( + step_id='step-2-tool', + trace_id='trace-example-12345', + parent_step_id='step-1-agent', + call_type=CallTypeEnum.TOOL, + step_action=StepAction( + tool_invocation=ToolInvocation( + tool_name='google_map_api_tool', + parameters={'location': 'Paris, FR'}, + ) + ), + cost=50, + total_tokens=20, + latency=100, + start_time=start_time, + end_time=end_time, + ), + ], + ) + + trace_dict = trace.model_dump(mode='json') + deserialized_trace = ResponseTrace.model_validate(trace_dict) + + assert trace == deserialized_trace diff --git a/tests/extensions/test_trace_extension.py b/tests/extensions/test_trace_extension.py new file mode 100644 index 00000000..a32329f1 --- /dev/null +++ b/tests/extensions/test_trace_extension.py @@ -0,0 +1,44 @@ +from unittest.mock import Mock + +import pytest + +from a2a.client.base_client import BaseClient +from a2a.extensions.trace import TraceExtension +from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler +from a2a.types import Message, TextPart, Part, Role + + +@pytest.mark.asyncio +async def test_trace_extension(): + client = BaseClient(card=Mock(), config=Mock(), transport=Mock(), consumers=[], middleware=[]) + server_handler = DefaultRequestHandler( + agent_executor=Mock(), + task_store=Mock(), + ) + + trace_extension = TraceExtension() + client.install_extension(trace_extension) + server_handler.install_extension(trace_extension, server=Mock()) + + message = Message( + message_id='test_message', + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], + ) + + # Simulate client sending a message + for extension in client._extensions: + extension.on_client_message(message) + + assert 'trace' in message.metadata + # The trace_id field is serialized as traceId due to camelCase alias generator + assert isinstance(message.metadata['trace']['traceId'], str) + + # Simulate server receiving a message + for extension in server_handler._extensions: + extension.on_server_message(message) + + # Check that the server-side handler was called + # (in this case, it just prints a message) + # We can't easily check the output of print, so we'll just + # assume it worked if no exceptions were raised. diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6cb21662..0ea8380e 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -2,6 +2,7 @@ import logging import time +from typing import Any from unittest.mock import ( AsyncMock, MagicMock, @@ -57,7 +58,9 @@ class DummyAgentExecutor(AgentExecutor): - async def execute(self, context: RequestContext, event_queue: EventQueue): + async def execute( + self, context: RequestContext, event_queue: EventQueue, request_handler: Any + ): task_updater = TaskUpdater( event_queue, context.task_id, context.context_id ) @@ -584,7 +587,12 @@ async def test_on_message_send_task_id_mismatch(): class HelloAgentExecutor(AgentExecutor): - async def execute(self, context: RequestContext, event_queue: EventQueue): + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + request_handler: Any, + ): task = context.current_task if not task: assert context.message is not None, (