Skip to content

Commit 5f06b68

Browse files
Python: handle streamed A2A update events (#4919)
* Python: handle streamed A2A update events * Python: preserve terminal A2A artifacts during streaming * Python: harden streamed A2A update event handling * Python: simplify streamed A2A update guard --------- Co-authored-by: sztoplover-bit <253473756+sztoplover-bit@users.noreply.github.com> Co-authored-by: Giles Odigwe <79032838+giles17@users.noreply.github.com>
1 parent 524c021 commit 5f06b68

2 files changed

Lines changed: 266 additions & 2 deletions

File tree

python/packages/a2a/agent_framework_a2a/_agent.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ async def _map_a2a_stream(
365365
)
366366

367367
all_updates: list[AgentResponseUpdate] = []
368+
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
368369
async for item in a2a_stream:
369370
if isinstance(item, A2AMessage):
370371
# Process A2A Message
@@ -378,12 +379,21 @@ async def _map_a2a_stream(
378379
all_updates.append(update)
379380
yield update
380381
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
381-
task, _update_event = item
382-
for update in self._updates_from_task(
382+
task, update_event = item
383+
updates = self._updates_from_task(
383384
task,
385+
update_event=update_event,
384386
background=background,
385387
emit_intermediate=emit_intermediate,
388+
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
389+
)
390+
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
391+
update.raw_representation is update_event for update in updates
386392
):
393+
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
394+
if task.status.state in TERMINAL_TASK_STATES:
395+
streamed_artifact_ids_by_task.pop(task.id, None)
396+
for update in updates:
387397
all_updates.append(update)
388398
yield update
389399
else:
@@ -403,8 +413,10 @@ def _updates_from_task(
403413
self,
404414
task: Task,
405415
*,
416+
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
406417
background: bool = False,
407418
emit_intermediate: bool = False,
419+
streamed_artifact_ids: set[str] | None = None,
408420
) -> list[AgentResponseUpdate]:
409421
"""Convert an A2A Task into AgentResponseUpdate(s).
410422
@@ -418,8 +430,21 @@ def _updates_from_task(
418430
"""
419431
status = task.status
420432

433+
if (
434+
emit_intermediate
435+
and update_event is not None
436+
and (event_updates := self._updates_from_task_update_event(update_event))
437+
):
438+
return event_updates
439+
421440
if status.state in TERMINAL_TASK_STATES:
422441
task_messages = self._parse_messages_from_task(task)
442+
if task.artifacts is not None and streamed_artifact_ids:
443+
task_messages = [
444+
message
445+
for message in task_messages
446+
if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids
447+
]
423448
if task_messages:
424449
return [
425450
AgentResponseUpdate(
@@ -431,6 +456,8 @@ def _updates_from_task(
431456
)
432457
for message in task_messages
433458
]
459+
if task.artifacts is not None:
460+
return []
434461
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
435462

436463
if background and status.state in IN_PROGRESS_TASK_STATES:
@@ -467,6 +494,44 @@ def _updates_from_task(
467494

468495
return []
469496

497+
def _updates_from_task_update_event(
498+
self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent
499+
) -> list[AgentResponseUpdate]:
500+
"""Convert A2A task update events into streaming AgentResponseUpdates."""
501+
if isinstance(update_event, TaskArtifactUpdateEvent):
502+
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
503+
if not contents:
504+
return []
505+
return [
506+
AgentResponseUpdate(
507+
contents=contents,
508+
role="assistant",
509+
response_id=update_event.task_id,
510+
message_id=update_event.artifact.artifact_id,
511+
raw_representation=update_event,
512+
)
513+
]
514+
515+
if not isinstance(update_event, TaskStatusUpdateEvent):
516+
return []
517+
518+
message = update_event.status.message
519+
if message is None or not message.parts:
520+
return []
521+
522+
contents = self._parse_contents_from_a2a(message.parts)
523+
if not contents:
524+
return []
525+
526+
return [
527+
AgentResponseUpdate(
528+
contents=contents,
529+
role="assistant" if message.role == A2ARole.agent else "user",
530+
response_id=update_event.task_id,
531+
raw_representation=update_event,
532+
)
533+
]
534+
470535
@staticmethod
471536
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
472537
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""

python/packages/a2a/tests/test_a2a_agent.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
FileWithUri,
1515
Part,
1616
Task,
17+
TaskArtifactUpdateEvent,
1718
TaskState,
1819
TaskStatus,
20+
TaskStatusUpdateEvent,
1921
TextPart,
2022
)
2123
from a2a.types import Message as A2AMessage
@@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped(
11891191
assert updates[0].contents[0].text == "Result"
11901192

11911193

1194+
async def test_streaming_artifact_update_event_yields_content(
1195+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1196+
) -> None:
1197+
"""Test that streaming artifact update events yield incremental content."""
1198+
task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None))
1199+
artifact = Artifact(
1200+
artifact_id="artifact-1",
1201+
parts=[Part(root=TextPart(text="Hello"))],
1202+
)
1203+
update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False)
1204+
mock_a2a_client.responses.append((task, update_event))
1205+
1206+
updates: list[AgentResponseUpdate] = []
1207+
async for update in a2a_agent.run("Hello", stream=True):
1208+
updates.append(update)
1209+
1210+
assert len(updates) == 1
1211+
assert updates[0].text == "Hello"
1212+
assert updates[0].message_id == "artifact-1"
1213+
assert updates[0].raw_representation == update_event
1214+
1215+
1216+
async def test_streaming_status_update_event_yields_content(
1217+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1218+
) -> None:
1219+
"""Test that streaming status update events surface message content directly from the update event."""
1220+
update_event = TaskStatusUpdateEvent(
1221+
task_id="task-status",
1222+
context_id="ctx-status",
1223+
status=TaskStatus(
1224+
state=TaskState.working,
1225+
message=A2AMessage(
1226+
message_id=str(uuid4()),
1227+
role=A2ARole.agent,
1228+
parts=[Part(root=TextPart(text="Still working"))],
1229+
),
1230+
),
1231+
final=False,
1232+
)
1233+
task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None))
1234+
mock_a2a_client.responses.append((task, update_event))
1235+
1236+
updates: list[AgentResponseUpdate] = []
1237+
async for update in a2a_agent.run("Hello", stream=True):
1238+
updates.append(update)
1239+
1240+
assert len(updates) == 1
1241+
assert updates[0].text == "Still working"
1242+
assert updates[0].role == "assistant"
1243+
assert updates[0].raw_representation == update_event
1244+
1245+
1246+
async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
1247+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1248+
) -> None:
1249+
"""Test that streamed artifact chunks are not re-emitted from the final terminal task."""
1250+
working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working))
1251+
first_chunk = TaskArtifactUpdateEvent(
1252+
task_id="task-art-dup",
1253+
context_id="ctx-art-dup",
1254+
artifact=Artifact(
1255+
artifact_id="artifact-dup",
1256+
parts=[Part(root=TextPart(text="Hello "))],
1257+
),
1258+
append=False,
1259+
)
1260+
second_chunk = TaskArtifactUpdateEvent(
1261+
task_id="task-art-dup",
1262+
context_id="ctx-art-dup",
1263+
artifact=Artifact(
1264+
artifact_id="artifact-dup",
1265+
parts=[Part(root=TextPart(text="world"))],
1266+
),
1267+
append=True,
1268+
)
1269+
terminal_task = Task(
1270+
id="task-art-dup",
1271+
context_id="ctx-art-dup",
1272+
status=TaskStatus(state=TaskState.completed, message=None),
1273+
artifacts=[
1274+
Artifact(
1275+
artifact_id="artifact-dup",
1276+
parts=[Part(root=TextPart(text="Hello world"))],
1277+
)
1278+
],
1279+
)
1280+
terminal_event = TaskStatusUpdateEvent(
1281+
task_id="task-art-dup",
1282+
context_id="ctx-art-dup",
1283+
status=TaskStatus(state=TaskState.completed, message=None),
1284+
final=True,
1285+
)
1286+
1287+
mock_a2a_client.responses.extend(
1288+
[
1289+
(working_task, first_chunk),
1290+
(working_task, second_chunk),
1291+
(terminal_task, terminal_event),
1292+
]
1293+
)
1294+
1295+
stream = a2a_agent.run("Hello", stream=True)
1296+
updates: list[AgentResponseUpdate] = []
1297+
async for update in stream:
1298+
updates.append(update)
1299+
response = await stream.get_final_response()
1300+
1301+
assert [update.text for update in updates] == ["Hello ", "world"]
1302+
assert response.text == "Hello world"
1303+
assert len(response.messages) == 1
1304+
1305+
1306+
async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content(
1307+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1308+
) -> None:
1309+
"""Test that terminal task artifacts are still emitted when the final status event has no message."""
1310+
terminal_task = Task(
1311+
id="task-art-final",
1312+
context_id="ctx-art-final",
1313+
status=TaskStatus(state=TaskState.completed, message=None),
1314+
artifacts=[
1315+
Artifact(
1316+
artifact_id="artifact-final",
1317+
parts=[Part(root=TextPart(text="Final artifact"))],
1318+
)
1319+
],
1320+
)
1321+
terminal_event = TaskStatusUpdateEvent(
1322+
task_id="task-art-final",
1323+
context_id="ctx-art-final",
1324+
status=TaskStatus(state=TaskState.completed, message=None),
1325+
final=True,
1326+
)
1327+
mock_a2a_client.responses.append((terminal_task, terminal_event))
1328+
1329+
updates: list[AgentResponseUpdate] = []
1330+
async for update in a2a_agent.run("Hello", stream=True):
1331+
updates.append(update)
1332+
1333+
assert len(updates) == 1
1334+
assert updates[0].text == "Final artifact"
1335+
assert updates[0].message_id == "artifact-final"
1336+
1337+
1338+
async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
1339+
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
1340+
) -> None:
1341+
"""Test that the terminal task only emits artifacts that were not already streamed incrementally."""
1342+
working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working))
1343+
streamed_chunk = TaskArtifactUpdateEvent(
1344+
task_id="task-art-mixed",
1345+
context_id="ctx-art-mixed",
1346+
artifact=Artifact(
1347+
artifact_id="artifact-streamed",
1348+
parts=[Part(root=TextPart(text="Hello"))],
1349+
),
1350+
append=False,
1351+
)
1352+
terminal_task = Task(
1353+
id="task-art-mixed",
1354+
context_id="ctx-art-mixed",
1355+
status=TaskStatus(state=TaskState.completed, message=None),
1356+
artifacts=[
1357+
Artifact(
1358+
artifact_id="artifact-streamed",
1359+
parts=[Part(root=TextPart(text="Hello"))],
1360+
),
1361+
Artifact(
1362+
artifact_id="artifact-final",
1363+
parts=[Part(root=TextPart(text="Goodbye"))],
1364+
),
1365+
],
1366+
)
1367+
terminal_event = TaskStatusUpdateEvent(
1368+
task_id="task-art-mixed",
1369+
context_id="ctx-art-mixed",
1370+
status=TaskStatus(state=TaskState.completed, message=None),
1371+
final=True,
1372+
)
1373+
1374+
mock_a2a_client.responses.extend(
1375+
[
1376+
(working_task, streamed_chunk),
1377+
(terminal_task, terminal_event),
1378+
]
1379+
)
1380+
1381+
stream = a2a_agent.run("Hello", stream=True)
1382+
updates: list[AgentResponseUpdate] = []
1383+
async for update in stream:
1384+
updates.append(update)
1385+
response = await stream.get_final_response()
1386+
1387+
assert [update.text for update in updates] == ["Hello", "Goodbye"]
1388+
assert [message.text for message in response.messages] == ["Hello", "Goodbye"]
1389+
1390+
11921391
# endregion

0 commit comments

Comments
 (0)