|
14 | 14 | FileWithUri, |
15 | 15 | Part, |
16 | 16 | Task, |
| 17 | + TaskArtifactUpdateEvent, |
17 | 18 | TaskState, |
18 | 19 | TaskStatus, |
| 20 | + TaskStatusUpdateEvent, |
19 | 21 | TextPart, |
20 | 22 | ) |
21 | 23 | from a2a.types import Message as A2AMessage |
@@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped( |
1189 | 1191 | assert updates[0].contents[0].text == "Result" |
1190 | 1192 |
|
1191 | 1193 |
|
| 1194 | +async def test_streaming_artifact_update_event_yields_content( |
| 1195 | + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient |
| 1196 | +) -> None: |
| 1197 | + """Test that streaming artifact update events yield incremental content.""" |
| 1198 | + task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None)) |
| 1199 | + artifact = Artifact( |
| 1200 | + artifact_id="artifact-1", |
| 1201 | + parts=[Part(root=TextPart(text="Hello"))], |
| 1202 | + ) |
| 1203 | + update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False) |
| 1204 | + mock_a2a_client.responses.append((task, update_event)) |
| 1205 | + |
| 1206 | + updates: list[AgentResponseUpdate] = [] |
| 1207 | + async for update in a2a_agent.run("Hello", stream=True): |
| 1208 | + updates.append(update) |
| 1209 | + |
| 1210 | + assert len(updates) == 1 |
| 1211 | + assert updates[0].text == "Hello" |
| 1212 | + assert updates[0].message_id == "artifact-1" |
| 1213 | + assert updates[0].raw_representation == update_event |
| 1214 | + |
| 1215 | + |
| 1216 | +async def test_streaming_status_update_event_yields_content( |
| 1217 | + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient |
| 1218 | +) -> None: |
| 1219 | + """Test that streaming status update events surface message content directly from the update event.""" |
| 1220 | + update_event = TaskStatusUpdateEvent( |
| 1221 | + task_id="task-status", |
| 1222 | + context_id="ctx-status", |
| 1223 | + status=TaskStatus( |
| 1224 | + state=TaskState.working, |
| 1225 | + message=A2AMessage( |
| 1226 | + message_id=str(uuid4()), |
| 1227 | + role=A2ARole.agent, |
| 1228 | + parts=[Part(root=TextPart(text="Still working"))], |
| 1229 | + ), |
| 1230 | + ), |
| 1231 | + final=False, |
| 1232 | + ) |
| 1233 | + task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None)) |
| 1234 | + mock_a2a_client.responses.append((task, update_event)) |
| 1235 | + |
| 1236 | + updates: list[AgentResponseUpdate] = [] |
| 1237 | + async for update in a2a_agent.run("Hello", stream=True): |
| 1238 | + updates.append(update) |
| 1239 | + |
| 1240 | + assert len(updates) == 1 |
| 1241 | + assert updates[0].text == "Still working" |
| 1242 | + assert updates[0].role == "assistant" |
| 1243 | + assert updates[0].raw_representation == update_event |
| 1244 | + |
| 1245 | + |
| 1246 | +async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts( |
| 1247 | + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient |
| 1248 | +) -> None: |
| 1249 | + """Test that streamed artifact chunks are not re-emitted from the final terminal task.""" |
| 1250 | + working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working)) |
| 1251 | + first_chunk = TaskArtifactUpdateEvent( |
| 1252 | + task_id="task-art-dup", |
| 1253 | + context_id="ctx-art-dup", |
| 1254 | + artifact=Artifact( |
| 1255 | + artifact_id="artifact-dup", |
| 1256 | + parts=[Part(root=TextPart(text="Hello "))], |
| 1257 | + ), |
| 1258 | + append=False, |
| 1259 | + ) |
| 1260 | + second_chunk = TaskArtifactUpdateEvent( |
| 1261 | + task_id="task-art-dup", |
| 1262 | + context_id="ctx-art-dup", |
| 1263 | + artifact=Artifact( |
| 1264 | + artifact_id="artifact-dup", |
| 1265 | + parts=[Part(root=TextPart(text="world"))], |
| 1266 | + ), |
| 1267 | + append=True, |
| 1268 | + ) |
| 1269 | + terminal_task = Task( |
| 1270 | + id="task-art-dup", |
| 1271 | + context_id="ctx-art-dup", |
| 1272 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1273 | + artifacts=[ |
| 1274 | + Artifact( |
| 1275 | + artifact_id="artifact-dup", |
| 1276 | + parts=[Part(root=TextPart(text="Hello world"))], |
| 1277 | + ) |
| 1278 | + ], |
| 1279 | + ) |
| 1280 | + terminal_event = TaskStatusUpdateEvent( |
| 1281 | + task_id="task-art-dup", |
| 1282 | + context_id="ctx-art-dup", |
| 1283 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1284 | + final=True, |
| 1285 | + ) |
| 1286 | + |
| 1287 | + mock_a2a_client.responses.extend( |
| 1288 | + [ |
| 1289 | + (working_task, first_chunk), |
| 1290 | + (working_task, second_chunk), |
| 1291 | + (terminal_task, terminal_event), |
| 1292 | + ] |
| 1293 | + ) |
| 1294 | + |
| 1295 | + stream = a2a_agent.run("Hello", stream=True) |
| 1296 | + updates: list[AgentResponseUpdate] = [] |
| 1297 | + async for update in stream: |
| 1298 | + updates.append(update) |
| 1299 | + response = await stream.get_final_response() |
| 1300 | + |
| 1301 | + assert [update.text for update in updates] == ["Hello ", "world"] |
| 1302 | + assert response.text == "Hello world" |
| 1303 | + assert len(response.messages) == 1 |
| 1304 | + |
| 1305 | + |
| 1306 | +async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content( |
| 1307 | + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient |
| 1308 | +) -> None: |
| 1309 | + """Test that terminal task artifacts are still emitted when the final status event has no message.""" |
| 1310 | + terminal_task = Task( |
| 1311 | + id="task-art-final", |
| 1312 | + context_id="ctx-art-final", |
| 1313 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1314 | + artifacts=[ |
| 1315 | + Artifact( |
| 1316 | + artifact_id="artifact-final", |
| 1317 | + parts=[Part(root=TextPart(text="Final artifact"))], |
| 1318 | + ) |
| 1319 | + ], |
| 1320 | + ) |
| 1321 | + terminal_event = TaskStatusUpdateEvent( |
| 1322 | + task_id="task-art-final", |
| 1323 | + context_id="ctx-art-final", |
| 1324 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1325 | + final=True, |
| 1326 | + ) |
| 1327 | + mock_a2a_client.responses.append((terminal_task, terminal_event)) |
| 1328 | + |
| 1329 | + updates: list[AgentResponseUpdate] = [] |
| 1330 | + async for update in a2a_agent.run("Hello", stream=True): |
| 1331 | + updates.append(update) |
| 1332 | + |
| 1333 | + assert len(updates) == 1 |
| 1334 | + assert updates[0].text == "Final artifact" |
| 1335 | + assert updates[0].message_id == "artifact-final" |
| 1336 | + |
| 1337 | + |
| 1338 | +async def test_streaming_terminal_task_only_emits_unstreamed_artifacts( |
| 1339 | + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient |
| 1340 | +) -> None: |
| 1341 | + """Test that the terminal task only emits artifacts that were not already streamed incrementally.""" |
| 1342 | + working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working)) |
| 1343 | + streamed_chunk = TaskArtifactUpdateEvent( |
| 1344 | + task_id="task-art-mixed", |
| 1345 | + context_id="ctx-art-mixed", |
| 1346 | + artifact=Artifact( |
| 1347 | + artifact_id="artifact-streamed", |
| 1348 | + parts=[Part(root=TextPart(text="Hello"))], |
| 1349 | + ), |
| 1350 | + append=False, |
| 1351 | + ) |
| 1352 | + terminal_task = Task( |
| 1353 | + id="task-art-mixed", |
| 1354 | + context_id="ctx-art-mixed", |
| 1355 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1356 | + artifacts=[ |
| 1357 | + Artifact( |
| 1358 | + artifact_id="artifact-streamed", |
| 1359 | + parts=[Part(root=TextPart(text="Hello"))], |
| 1360 | + ), |
| 1361 | + Artifact( |
| 1362 | + artifact_id="artifact-final", |
| 1363 | + parts=[Part(root=TextPart(text="Goodbye"))], |
| 1364 | + ), |
| 1365 | + ], |
| 1366 | + ) |
| 1367 | + terminal_event = TaskStatusUpdateEvent( |
| 1368 | + task_id="task-art-mixed", |
| 1369 | + context_id="ctx-art-mixed", |
| 1370 | + status=TaskStatus(state=TaskState.completed, message=None), |
| 1371 | + final=True, |
| 1372 | + ) |
| 1373 | + |
| 1374 | + mock_a2a_client.responses.extend( |
| 1375 | + [ |
| 1376 | + (working_task, streamed_chunk), |
| 1377 | + (terminal_task, terminal_event), |
| 1378 | + ] |
| 1379 | + ) |
| 1380 | + |
| 1381 | + stream = a2a_agent.run("Hello", stream=True) |
| 1382 | + updates: list[AgentResponseUpdate] = [] |
| 1383 | + async for update in stream: |
| 1384 | + updates.append(update) |
| 1385 | + response = await stream.get_final_response() |
| 1386 | + |
| 1387 | + assert [update.text for update in updates] == ["Hello", "Goodbye"] |
| 1388 | + assert [message.text for message in response.messages] == ["Hello", "Goodbye"] |
| 1389 | + |
| 1390 | + |
1192 | 1391 | # endregion |
0 commit comments