Skip to content

Add Streaming Support to A2A #2362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 105 additions & 20 deletions pydantic_ai_slim/pydantic_ai/_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
Message,
Part,
Skill,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskSendParams,
TaskStatusUpdateEvent,
TextPart as A2ATextPart,
)
from fasta2a.storage import InMemoryStorage, Storage
Expand Down Expand Up @@ -72,6 +74,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT,
def agent_to_a2a(
agent: Agent[AgentDepsT, OutputDataT],
*,
enable_streaming: bool = False,
storage: Storage | None = None,
broker: Broker | None = None,
# Agent card
Expand All @@ -91,7 +94,7 @@ def agent_to_a2a(
"""Create a FastA2A server from an agent."""
storage = storage or InMemoryStorage()
broker = broker or InMemoryBroker()
worker = AgentWorker(agent=agent, broker=broker, storage=storage)
worker = AgentWorker(agent=agent, broker=broker, storage=storage, enable_streaming=enable_streaming)

lifespan = lifespan or partial(worker_lifespan, worker=worker, agent=agent)

Expand All @@ -104,6 +107,7 @@ def agent_to_a2a(
description=description,
provider=provider,
skills=skills,
streaming=enable_streaming,
debug=debug,
routes=routes,
middleware=middleware,
Expand All @@ -117,6 +121,7 @@ class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]
"""A worker that uses an agent to execute tasks."""

agent: Agent[AgentDepsT, WorkerOutputT]
enable_streaming: bool = False

async def run_task(self, params: TaskSendParams) -> None:
task = await self.storage.load_task(params['id'])
Expand All @@ -132,39 +137,119 @@ async def run_task(self, params: TaskSendParams) -> None:

await self.storage.update_task(task['id'], state='working')

# Send working status streaming event
await self.broker.send_stream_event(
task['id'],
TaskStatusUpdateEvent(
task_id=task['id'],
context_id=task['context_id'],
kind='status-update',
status={'state': 'working'},
final=False,
),
)

# Load context - contains pydantic-ai message history from previous tasks in this conversation
message_history = await self.storage.load_context(task['context_id']) or []
message_history.extend(self.build_message_history(task.get('history', [])))

try:
result = await self.agent.run(message_history=message_history) # type: ignore

await self.storage.update_context(task['context_id'], result.all_messages())

# Convert new messages to A2A format for task history
a2a_messages: list[Message] = []

for message in result.new_messages():
if isinstance(message, ModelRequest):
# Skip user prompts - they're already in task history
continue
else:
# Convert response parts to A2A format
a2a_parts = self._response_parts_to_a2a(message.parts)
if a2a_parts: # Add if there are visible parts (text/thinking)
a2a_messages.append(
Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4()))
# Stream processing with agent.iter()
async with self.agent.iter(message_history=message_history, deps=None) as run: # type: ignore
node = run.next_node
while not self.agent.is_end_node(node):
# Check if this node has a model response
if hasattr(node, 'model_response'):
model_response = getattr(node, 'model_response')
# Convert model response parts to A2A parts
a2a_parts = self._response_parts_to_a2a(model_response.parts)

if a2a_parts and self.enable_streaming:
# Send incremental message event with unique ID
incremental_message = Message(
role='agent',
parts=a2a_parts,
kind='message',
message_id=str(uuid.uuid4()), # Generate unique ID per message
)
# Stream the incremental message
await self.broker.send_stream_event(task['id'], incremental_message)

# Move to next node
current = node
node = await run.next(current)

if self.enable_streaming:
# Update context with current messages after each step
await self.storage.update_context(task['context_id'], run.ctx.state.message_history)

# Run finished - get the final result and update context with final messages
assert run.result is not None # Agent iteration should always produce a result
await self.storage.update_context(task['context_id'], run.result.all_messages())

# Convert new messages to A2A format for task history
a2a_messages: list[Message] = []

for message in run.result.new_messages():
if isinstance(message, ModelRequest):
# Skip user prompts - they're already in task history
continue
else:
# Convert response parts to A2A format
a2a_parts = self._response_parts_to_a2a(message.parts)
if a2a_parts: # Add if there are visible parts (text/thinking)
a2a_messages.append(
Message(role='agent', parts=a2a_parts, kind='message', message_id=str(uuid.uuid4()))
)

# Handle final result and create artifacts using build_artifacts method
artifacts = self.build_artifacts(run.result.output)

# Send artifact update events for all artifacts (only for structured outputs)
if not isinstance(run.result.output, str):
for artifact in artifacts:
await self.broker.send_stream_event(
task['id'],
TaskArtifactUpdateEvent(
task_id=task['id'],
context_id=task['context_id'],
kind='artifact-update',
artifact=artifact,
last_chunk=True,
),
)

artifacts = self.build_artifacts(result.output)
except Exception:
await self.storage.update_task(task['id'], state='failed')

# Send failure status streaming event
await self.broker.send_stream_event(
task['id'],
TaskStatusUpdateEvent(
task_id=task['id'],
context_id=task['context_id'],
kind='status-update',
status={'state': 'failed'},
final=True,
),
)
raise
else:
await self.storage.update_task(
task['id'], state='completed', new_artifacts=artifacts, new_messages=a2a_messages
)

# Send completion status streaming event
await self.broker.send_stream_event(
task['id'],
TaskStatusUpdateEvent(
task_id=task['id'],
context_id=task['context_id'],
kind='status-update',
status={'state': 'completed'},
final=True,
),
)

async def cancel_task(self, params: TaskIdParams) -> None:
pass

Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,7 @@ def to_ag_ui(
def to_a2a(
self,
*,
enable_streaming: bool = False,
storage: Storage | None = None,
broker: Broker | None = None,
# Agent card
Expand Down Expand Up @@ -1988,6 +1989,7 @@ def to_a2a(

return agent_to_a2a(
self,
enable_streaming=enable_streaming,
storage=storage,
broker=broker,
name=name,
Expand Down
179 changes: 179 additions & 0 deletions tests/test_a2a.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from typing import Any, cast

import anyio
import httpx
Expand All @@ -24,6 +25,7 @@
from .conftest import IsDatetime, IsStr, try_import

with try_import() as imports_successful:
from fasta2a.broker import StreamEvent
from fasta2a.client import A2AClient
from fasta2a.schema import DataPart, FilePart, Message, TextPart
from fasta2a.storage import InMemoryStorage
Expand Down Expand Up @@ -991,3 +993,180 @@ async def test_a2a_multiple_send_task_messages():
],
}
)


async def test_streaming_emits_incremental_messages(mocker: Any) -> None:
    """Verify that enable_streaming=True produces incremental messages during agent execution."""
    from fasta2a.broker import InMemoryBroker

    # Model that returns several text parts in a single response, simulating a streamed reply.
    def return_multiple_text_parts(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        return ModelResponse(
            parts=[
                PydanticAITextPart(content='First part of response'),
                PydanticAITextPart(content='Second part of response'),
                PydanticAITextPart(content='Final part'),
            ]
        )

    streaming_model = FunctionModel(return_multiple_text_parts)

    # Create agent with streaming enabled
    agent = Agent(model=streaming_model, output_type=str)
    storage = InMemoryStorage()
    broker = InMemoryBroker()

    # Spy on the broker's send_stream_event method to capture every streamed event
    mock_send: Any = mocker.spy(broker, 'send_stream_event')

    app = agent.to_a2a(enable_streaming=True, storage=storage, broker=broker)

    async with LifespanManager(app):
        transport = httpx.ASGITransport(app)
        async with httpx.AsyncClient(transport=transport) as http_client:
            a2a_client = A2AClient(http_client=http_client)

            message = Message(
                role='user',
                parts=[TextPart(text='Hello, world!', kind='text')],
                kind='message',
                message_id=str(uuid.uuid4()),
            )
            response = await a2a_client.send_message(message=message)
            assert 'error' not in response
            assert 'result' in response

            result = response['result']
            assert result['kind'] == 'task'
            task_id = result['id']

            # Poll for task completion instead of a single fixed sleep: a fixed
            # 0.1s wait races the background worker and is flaky on slow CI hosts.
            final_result: Any = None
            for _attempt in range(50):  # up to ~2.5s total
                await anyio.sleep(0.05)
                final_response = await a2a_client.get_task(task_id)
                assert 'result' in final_response
                final_result = final_response['result']
                if final_result['status']['state'] in ('completed', 'failed'):
                    break
            assert final_result is not None, 'Task never reached a terminal state'
            assert final_result['status']['state'] == 'completed'

            # Verify streaming events were captured
            assert mock_send.call_count > 0

            # Each spy call is (task_id, event); collect the events from all calls.
            captured_events: list[StreamEvent] = [call.args[1] for call in mock_send.call_args_list]

            # Check that we got different types of events
            event_kinds = [event.get('kind') for event in captured_events if event.get('kind')]

            # Incremental agent messages are only emitted when streaming is enabled.
            message_events: list[Message] = [
                cast(Message, event)
                for event in captured_events
                if event.get('kind') == 'message' and event.get('role') == 'agent'
            ]

            # Should have status updates at minimum
            assert 'status-update' in event_kinds

            # Verify we got at least one agent message during streaming
            assert len(message_events) > 0, f'Expected agent messages during streaming, got events: {event_kinds}'

            # Verify the agent message contains the model's text parts verbatim.
            agent_message = message_events[0]
            assert agent_message['role'] == 'agent'
            assert agent_message['kind'] == 'message'
            assert 'message_id' in agent_message
            assert 'parts' in agent_message
            parts = agent_message['parts']
            assert len(parts) == 3  # one A2A part per model text part
            first_part = parts[0]
            assert first_part.get('kind') == 'text'
            assert first_part.get('text') == 'First part of response'


async def test_streaming_disabled_sends_only_final_results(mocker: Any) -> None:
    """Verify enable_streaming=False sends only status updates and final results, no incremental messages."""
    from fasta2a.broker import InMemoryBroker

    # Same multi-part model as the streaming test, for a direct comparison.
    def return_multiple_text_parts(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        return ModelResponse(
            parts=[
                PydanticAITextPart(content='First part of response'),
                PydanticAITextPart(content='Second part of response'),
                PydanticAITextPart(content='Final part'),
            ]
        )

    streaming_model = FunctionModel(return_multiple_text_parts)

    # Create agent with streaming DISABLED (default behavior)
    agent = Agent(model=streaming_model, output_type=str)
    storage = InMemoryStorage()
    broker = InMemoryBroker()

    # Spy on the broker's send_stream_event method to capture calls
    mock_send: Any = mocker.spy(broker, 'send_stream_event')

    # Note: enable_streaming defaults to False, but being explicit for clarity
    app = agent.to_a2a(enable_streaming=False, storage=storage, broker=broker)

    async with LifespanManager(app):
        transport = httpx.ASGITransport(app)
        async with httpx.AsyncClient(transport=transport) as http_client:
            a2a_client = A2AClient(http_client=http_client)

            message = Message(
                role='user',
                parts=[TextPart(text='Hello, world!', kind='text')],
                kind='message',
                message_id=str(uuid.uuid4()),
            )
            response = await a2a_client.send_message(message=message)
            assert 'error' not in response
            assert 'result' in response

            result = response['result']
            assert result['kind'] == 'task'
            task_id = result['id']

            # Poll for task completion instead of a single fixed sleep: a fixed
            # 0.1s wait races the background worker and is flaky on slow CI hosts.
            final_result: Any = None
            for _attempt in range(50):  # up to ~2.5s total
                await anyio.sleep(0.05)
                final_response = await a2a_client.get_task(task_id)
                assert 'result' in final_response
                final_result = final_response['result']
                if final_result['status']['state'] in ('completed', 'failed'):
                    break
            assert final_result is not None, 'Task never reached a terminal state'
            assert final_result['status']['state'] == 'completed'

            # Verify streaming events were captured
            assert mock_send.call_count > 0

            # Each spy call is (task_id, event); collect the events from all calls.
            captured_events: list[StreamEvent] = [call.args[1] for call in mock_send.call_args_list]

            # Analyze event types
            event_kinds = [event.get('kind') for event in captured_events if event.get('kind')]

            # Look for any agent messages (should be NONE when streaming disabled)
            agent_messages: list[Message] = [
                cast(Message, event)
                for event in captured_events
                if event.get('kind') == 'message' and event.get('role') == 'agent'
            ]

            # Verify expected behavior: only status updates, NO incremental agent messages
            assert 'status-update' in event_kinds, 'Should have status updates'
            assert len(agent_messages) == 0, (
                f'Should have NO agent messages when streaming disabled, but got: {agent_messages}'
            )

            # Verify clean event stream: only status updates (working -> completed)
            status_events = [event for event in captured_events if event.get('kind') == 'status-update']
            assert len(status_events) >= 2, 'Should have at least working and completed status updates'

            # Verify final result is complete and correct
            assert 'artifacts' in final_result
            artifacts = final_result['artifacts']
            assert len(artifacts) == 1
            artifact = cast(dict[str, Any], artifacts[0])
            assert artifact['name'] == 'result'
            assert len(artifact['parts']) == 1
            artifact_part = cast(dict[str, Any], artifact['parts'][0])
            assert artifact_part['kind'] == 'text'
            # Final result should contain all text parts concatenated
            assert 'First part of response' in artifact_part['text']
Loading