diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 8a916b5cc..85c82520c 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -42,8 +42,10 @@ Message, Part, Skill, + TaskArtifactUpdateEvent, TaskIdParams, TaskSendParams, + TaskStatusUpdateEvent, TextPart as A2ATextPart, ) from fasta2a.storage import InMemoryStorage, Storage @@ -72,6 +74,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, def agent_to_a2a( agent: Agent[AgentDepsT, OutputDataT], *, + enable_streaming: bool = False, storage: Storage | None = None, broker: Broker | None = None, # Agent card @@ -91,7 +94,7 @@ def agent_to_a2a( """Create a FastA2A server from an agent.""" storage = storage or InMemoryStorage() broker = broker or InMemoryBroker() - worker = AgentWorker(agent=agent, broker=broker, storage=storage) + worker = AgentWorker(agent=agent, broker=broker, storage=storage, enable_streaming=enable_streaming) lifespan = lifespan or partial(worker_lifespan, worker=worker, agent=agent) @@ -104,6 +107,7 @@ def agent_to_a2a( description=description, provider=provider, skills=skills, + streaming=enable_streaming, debug=debug, routes=routes, middleware=middleware, @@ -117,6 +121,7 @@ class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT] """A worker that uses an agent to execute tasks.""" agent: Agent[AgentDepsT, WorkerOutputT] + enable_streaming: bool = False async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id']) @@ -132,39 +137,119 @@ async def run_task(self, params: TaskSendParams) -> None: await self.storage.update_task(task['id'], state='working') + # Send working status streaming event + await self.broker.send_stream_event( + task['id'], + TaskStatusUpdateEvent( + task_id=task['id'], + context_id=task['context_id'], + kind='status-update', + status={'state': 'working'}, + final=False, + ), + ) + # Load context - contains pydantic-ai message history from previous tasks in this conversation message_history = await self.storage.load_context(task['context_id']) or [] message_history.extend(self.build_message_history(task.get('history', []))) try: - result = await self.agent.run(message_history=message_history) # type: ignore - - await self.storage.update_context(task['context_id'], result.all_messages()) - - # Convert new messages to A2A format for task history - a2a_messages: list[Message] = [] - - for message in result.new_messages(): - if isinstance(message, ModelRequest): - # Skip user prompts - they're already in task history - continue - else: - # Convert response parts to A2A format - a2a_parts = self._response_parts_to_a2a(message.parts) - if a2a_parts: # Add if there are visible parts (text/thinking) - a2a_messages.append( - Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4())) + # Stream processing with agent.iter() + async with self.agent.iter(message_history=message_history, deps=None) as run: # type: ignore + node = run.next_node + while not self.agent.is_end_node(node): + # Check if this node has a model response + if hasattr(node, 'model_response'): + model_response = getattr(node, 'model_response') + # Convert model response parts to A2A parts + a2a_parts = self._response_parts_to_a2a(model_response.parts) + + if a2a_parts and self.enable_streaming: + # Send incremental message event with unique ID + incremental_message = Message( + role='agent', + parts=a2a_parts, + kind='message', + message_id=str(uuid.uuid4()), # Generate unique ID per message + ) + # Stream the incremental message + await self.broker.send_stream_event(task['id'], incremental_message) + + # Move to next node + current = node + node = await run.next(current) + + if self.enable_streaming: + # Update context with current messages after each step + await self.storage.update_context(task['context_id'], run.ctx.state.message_history) + + # Run finished - get the final result and update context with final messages + assert run.result is not None # Agent iteration should always produce a result + await self.storage.update_context(task['context_id'], run.result.all_messages()) + + # Convert new messages to A2A format for task history + a2a_messages: list[Message] = [] + + for message in run.result.new_messages(): + if isinstance(message, ModelRequest): + # Skip user prompts - they're already in task history + continue + else: + # Convert response parts to A2A format + a2a_parts = self._response_parts_to_a2a(message.parts) + if a2a_parts: # Add if there are visible parts (text/thinking) + a2a_messages.append( + Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4())) + ) + + # Handle final result and create artifacts using build_artifacts method + artifacts = self.build_artifacts(run.result.output) + + # Send artifact update events for all artifacts (only for structured outputs) + if not isinstance(run.result.output, str): + for artifact in artifacts: + await self.broker.send_stream_event( + task['id'], + TaskArtifactUpdateEvent( + task_id=task['id'], + context_id=task['context_id'], + kind='artifact-update', + artifact=artifact, + last_chunk=True, + ), ) - - artifacts = self.build_artifacts(result.output) except Exception: await self.storage.update_task(task['id'], state='failed') + + # Send failure status streaming event + await self.broker.send_stream_event( + task['id'], + TaskStatusUpdateEvent( + task_id=task['id'], + context_id=task['context_id'], + kind='status-update', + status={'state': 'failed'}, + final=True, + ), + ) raise else: await self.storage.update_task( task['id'], state='completed', new_artifacts=artifacts, new_messages=a2a_messages ) + # Send completion status streaming event + await self.broker.send_stream_event( + task['id'], + TaskStatusUpdateEvent( + task_id=task['id'], + context_id=task['context_id'], + kind='status-update', + status={'state': 'completed'}, + final=True, + ), + ) + async def cancel_task(self, params: TaskIdParams) -> None: pass diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 5f22d7329..99dc50348 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1950,6 +1950,7 @@ def to_ag_ui( def to_a2a( self, *, + enable_streaming: bool = False, storage: Storage | None = None, broker: Broker | None = None, # Agent card @@ -1988,6 +1989,7 @@ def to_a2a( return agent_to_a2a( self, + enable_streaming=enable_streaming, storage=storage, broker=broker, name=name, diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c13d1da54..6ab2b2b43 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -1,4 +1,5 @@ import uuid +from typing import Any, cast import anyio import httpx @@ -24,6 +25,7 @@ from .conftest import IsDatetime, IsStr, try_import with try_import() as imports_successful: + from fasta2a.broker import StreamEvent from fasta2a.client import A2AClient from fasta2a.schema import DataPart, FilePart, Message, TextPart from fasta2a.storage import InMemoryStorage @@ -991,3 +993,180 @@ async def test_a2a_multiple_send_task_messages(): ], } ) + + +async def test_streaming_emits_incremental_messages(mocker: Any) -> None: + """Verify that enable_streaming=True produces incremental messages during agent execution.""" + from fasta2a.broker import InMemoryBroker + + # Create a model that produces multiple text parts to simulate streaming + def return_multiple_text_parts(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + return ModelResponse( + parts=[ + PydanticAITextPart(content='First part of response'), + PydanticAITextPart(content='Second part of response'), + PydanticAITextPart(content='Final part'), + ] + ) + + streaming_model = FunctionModel(return_multiple_text_parts) + + # Create agent with streaming enabled + agent = Agent(model=streaming_model, output_type=str) + storage = InMemoryStorage() + broker = InMemoryBroker() + + # Spy on the broker's send_stream_event method to capture calls + mock_send: Any = mocker.spy(broker, 'send_stream_event') + + app = agent.to_a2a(enable_streaming=True, storage=storage, broker=broker) + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + + result = response['result'] + assert result['kind'] == 'task' + task_id = result['id'] + + # Wait for task completion + await anyio.sleep(0.1) + final_response = await a2a_client.get_task(task_id) + assert 'result' in final_response + final_result = final_response['result'] + assert final_result['status']['state'] == 'completed' + + # Verify streaming events were captured + assert mock_send.call_count > 0 + + # Extract events from mock calls + captured_events: list[StreamEvent] = [call.args[1] for call in mock_send.call_args_list] + + # Check that we got different types of events + event_kinds = [event.get('kind') for event in captured_events if event.get('kind')] + + # Look for agent messages + message_events: list[Message] = [] + for event in captured_events: + if event.get('kind') == 'message' and event.get('role') == 'agent': + message_events.append(cast(Message, event)) + + # Should have status updates at minimum + assert 'status-update' in event_kinds + + # Verify we got at least one agent message during streaming + assert len(message_events) > 0, f'Expected agent messages during streaming, got events: {event_kinds}' + + # Verify the agent message contains the expected content + agent_message = message_events[0] + assert agent_message['role'] == 'agent' + assert agent_message['kind'] == 'message' + assert 'message_id' in agent_message + assert 'parts' in agent_message + parts = agent_message['parts'] + assert len(parts) == 3 # Should have 3 text parts + first_part = parts[0] + assert first_part.get('kind') == 'text' + assert first_part.get('text') == 'First part of response' + + +async def test_streaming_disabled_sends_only_final_results(mocker: Any) -> None: + """Verify enable_streaming=False sends only status updates and final results, no incremental messages.""" + from fasta2a.broker import InMemoryBroker + + # Create a model that produces multiple text parts - same as streaming test for comparison + def return_multiple_text_parts(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + return ModelResponse( + parts=[ + PydanticAITextPart(content='First part of response'), + PydanticAITextPart(content='Second part of response'), + PydanticAITextPart(content='Final part'), + ] + ) + + streaming_model = FunctionModel(return_multiple_text_parts) + + # Create agent with streaming DISABLED (default behavior) + agent = Agent(model=streaming_model, output_type=str) + storage = InMemoryStorage() + broker = InMemoryBroker() + + # Spy on the broker's send_stream_event method to capture calls + mock_send: Any = mocker.spy(broker, 'send_stream_event') + + # Note: enable_streaming defaults to False, but being explicit for clarity + app = agent.to_a2a(enable_streaming=False, storage=storage, broker=broker) + + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as http_client: + a2a_client = A2AClient(http_client=http_client) + + message = Message( + role='user', + parts=[TextPart(text='Hello, world!', kind='text')], + kind='message', + message_id=str(uuid.uuid4()), + ) + response = await a2a_client.send_message(message=message) + assert 'error' not in response + assert 'result' in response + + result = response['result'] + assert result['kind'] == 'task' + task_id = result['id'] + + # Wait for task completion + await anyio.sleep(0.1) + final_response = await a2a_client.get_task(task_id) + assert 'result' in final_response + final_result = final_response['result'] + assert final_result['status']['state'] == 'completed' + + # Verify streaming events were captured + assert mock_send.call_count > 0 + + # Extract events from mock calls + captured_events: list[StreamEvent] = [call.args[1] for call in mock_send.call_args_list] + + # Analyze event types + event_kinds = [event.get('kind') for event in captured_events if event.get('kind')] + + # Look for any agent messages (should be NONE when streaming disabled) + agent_messages: list[Message] = [] + for event in captured_events: + if event.get('kind') == 'message' and event.get('role') == 'agent': + agent_messages.append(cast(Message, event)) + + # Verify expected behavior: only status updates, NO incremental agent messages + assert 'status-update' in event_kinds, 'Should have status updates' + assert len(agent_messages) == 0, ( + f'Should have NO agent messages when streaming disabled, but got: {agent_messages}' + ) + + # Verify clean event stream: only status updates (working -> completed) + status_events = [event for event in captured_events if event.get('kind') == 'status-update'] + assert len(status_events) >= 2, 'Should have at least working and completed status updates' + + # Verify final result is complete and correct + assert 'artifacts' in final_result + artifacts = final_result['artifacts'] + assert len(artifacts) == 1 + artifact = cast(dict[str, Any], artifacts[0]) + assert artifact['name'] == 'result' + assert len(artifact['parts']) == 1 + artifact_part = cast(dict[str, Any], artifact['parts'][0]) + assert artifact_part['kind'] == 'text' + # Final result should contain all text parts concatenated + assert 'First part of response' in artifact_part['text']