Skip to content

Commit 5b32e75

Browse files
authored
Merge pull request #673 from cnoe-io/feat/a2a-source-agent-tracking
feat(a2a): add source agent tracking for sub-agent message grouping
2 parents 9df0ae3 + 278c704 commit 5b32e75

File tree

13 files changed

+1011
-141
lines changed

13 files changed

+1011
-141
lines changed

ai_platform_engineering/multi_agents/platform_engineer/protocol_bindings/a2a/agent.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
268268
# Format: {tool_name: response_content}
269269
accumulated_subagent_responses = {}
270270

271+
# Track current active agent for sub-agent message grouping
272+
# This is used by the executor to add sourceAgent metadata to artifacts
273+
current_agent: str | None = None
274+
271275
# Check if token-by-token streaming is enabled (default: true)
272276
# When disabled, uses 'values' mode which waits for complete messages
273277
enable_streaming = os.getenv("ENABLE_STREAMING", "true").lower() == "true"
@@ -343,6 +347,8 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
343347
logging.debug("Skipping tool call with empty name (streaming chunk)")
344348
continue
345349

350+
# Track current agent for sub-agent message grouping
351+
current_agent = tool_name
346352
logging.debug(f"Tool call started (from AIMessageChunk): {tool_name}")
347353

348354
# Stream tool start notification to client with metadata
@@ -351,6 +357,7 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
351357
"is_task_complete": False,
352358
"require_user_input": False,
353359
"content": f"🔧 Supervisor: Calling Agent {tool_name_formatted}...\n",
360+
"source_agent": tool_name,
354361
"tool_call": {
355362
"name": tool_name,
356363
"status": "started",
@@ -388,12 +395,14 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
388395
if match:
389396
agent_name = match.group(1)
390397
purpose = match.group(2)
398+
current_agent = agent_name.lower() # Update current agent
391399
logging.debug(f"Tool update detected: {agent_name} - {purpose}")
392400
# Emit as tool_update event
393401
yield {
394402
"is_task_complete": False,
395403
"require_user_input": False,
396404
"content": content,
405+
"source_agent": current_agent,
397406
"tool_update": {
398407
"name": agent_name.lower(),
399408
"purpose": purpose,
@@ -402,11 +411,12 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
402411
}
403412
}
404413
else:
405-
# Regular content - no special handling
414+
# Regular content - include source_agent for grouping
406415
yield {
407416
"is_task_complete": False,
408417
"require_user_input": False,
409418
"content": content,
419+
"source_agent": current_agent or "supervisor",
410420
}
411421

412422
# Handle AIMessage with tool calls (tool start indicators)
@@ -425,6 +435,8 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
425435
pending_tool_calls[tool_call_id] = tool_name
426436
logging.debug(f"Tracked tool call: {tool_call_id} -> {tool_name}")
427437

438+
# Track current agent for sub-agent message grouping
439+
current_agent = tool_name
428440
logging.info(f"Tool call started: {tool_name}")
429441

430442
# Stream tool start notification to client with metadata
@@ -433,6 +445,7 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
433445
"is_task_complete": False,
434446
"require_user_input": False,
435447
"content": f"🔧 Supervisor: Calling Agent {tool_name_formatted}...\n",
448+
"source_agent": tool_name,
436449
"tool_call": {
437450
"name": tool_name,
438451
"status": "started",
@@ -494,6 +507,7 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
494507
yield {
495508
"is_task_complete": False,
496509
"require_user_input": False,
510+
"source_agent": "supervisor",
497511
"artifact": {
498512
"name": "execution_plan_update",
499513
"description": "TODO-based execution plan",
@@ -507,6 +521,7 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
507521
yield {
508522
"is_task_complete": False,
509523
"require_user_input": False,
524+
"source_agent": "supervisor",
510525
"artifact": {
511526
"name": "execution_plan_status_update",
512527
"description": "TODO progress update",
@@ -557,17 +572,19 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
557572
logging.warning(f"Failed to parse request_user_input content: {e}")
558573
# Fall through to normal handling if parsing fails
559574
elif tool_name in rag_tool_names:
560-
# For RAG tools, we don't want to stream the content, as its a LOT of text
575+
# For RAG tools, we don't want to stream the content, as it's a LOT of text
561576
yield {
562577
"is_task_complete": False,
563578
"require_user_input": False,
579+
"source_agent": tool_name,
564580
"content": f"🔍 {tool_name}...",
565581
}
566582
# Stream other tool content normally (actual results for user)
567583
elif tool_content and tool_content.strip():
568584
yield {
569585
"is_task_complete": False,
570586
"require_user_input": False,
587+
"source_agent": tool_name,
571588
"content": tool_content + "\n",
572589
}
573590

@@ -576,6 +593,7 @@ async def stream(self, query, context_id, trace_id=None) -> AsyncIterable[dict[s
576593
yield {
577594
"is_task_complete": False,
578595
"require_user_input": False,
596+
"source_agent": tool_name,
579597
"content": f"✅ Supervisor: Agent task {tool_name_formatted} completed\n",
580598
"tool_result": {
581599
"name": tool_name,

ai_platform_engineering/multi_agents/platform_engineer/protocol_bindings/a2a/agent_executor.py

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class StreamState:
5959
task_complete: bool = False
6060
user_input_required: bool = False
6161

62+
# Source agent tracking for sub-agent message grouping
63+
current_agent: Optional[str] = None
64+
agent_streaming_artifact_ids: Dict[str, str] = field(default_factory=dict)
65+
6266

6367
class AIPlatformEngineerA2AExecutor(AgentExecutor):
6468
"""AI Platform Engineer A2A Executor."""
@@ -158,7 +162,7 @@ def _extract_final_answer(self, content: str) -> str:
158162
# Extract everything after the marker
159163
idx = content.find(marker)
160164
final_content = content[idx + len(marker):].strip()
161-
logger.debug(f"Extracted final answer: {len(final_content)} chars (marker found at pos {idx})")
165+
logger.info(f"Extracted final answer: {len(final_content)} chars (marker found at pos {idx})")
162166
return final_content
163167
return content
164168

@@ -259,19 +263,44 @@ def _normalize_content(self, content) -> str:
259263
async def _send_artifact(self, event_queue: EventQueue, task: A2ATask,
260264
artifact: Artifact, append: bool, last_chunk: bool = False):
261265
"""Send an artifact update event."""
266+
# Debug: Log artifact being sent
267+
artifact_name = getattr(artifact, 'name', 'unknown')
268+
# A2A stores text in parts[0].text, not in a top-level text attribute
269+
parts = getattr(artifact, 'parts', [])
270+
parts_text = None
271+
if parts:
272+
# Try different ways to access text
273+
first_part = parts[0]
274+
if hasattr(first_part, 'text'):
275+
parts_text = first_part.text
276+
elif hasattr(first_part, 'root') and hasattr(first_part.root, 'text'):
277+
parts_text = first_part.root.text
278+
elif isinstance(first_part, dict):
279+
parts_text = first_part.get('text')
280+
text_preview = parts_text[:100] if parts_text else '(no parts.text)'
281+
text_len = len(parts_text) if parts_text else 0
282+
283+
if artifact_name in ('final_result', 'partial_result'):
284+
logger.info(f"📤 FINAL ARTIFACT: parts_count={len(parts)}, text_len={text_len}")
285+
logger.info(f"📤 FINAL ARTIFACT preview: {text_preview}...")
286+
# Debug: Log the actual artifact structure
287+
logger.info(f"📤 FINAL ARTIFACT parts[0] type: {type(parts[0]) if parts else 'NO_PARTS'}")
288+
logger.info(f"📤 FINAL ARTIFACT parts[0] attrs: {dir(parts[0]) if parts else 'NO_PARTS'}")
289+
262290
await self._safe_enqueue_event(
263291
event_queue,
264292
TaskArtifactUpdateEvent(
265293
append=append,
266294
context_id=task.context_id,
267295
task_id=task.id,
268-
last_chunk=last_chunk,
296+
lastChunk=last_chunk,
269297
artifact=artifact,
270298
)
271299
)
272300

273301
async def _send_completion(self, event_queue: EventQueue, task: A2ATask):
274302
"""Send task completion status."""
303+
logger.info(f"📤 Sending completion status for task {task.id}")
275304
await self._safe_enqueue_event(
276305
event_queue,
277306
TaskStatusUpdateEvent(
@@ -281,6 +310,7 @@ async def _send_completion(self, event_queue: EventQueue, task: A2ATask):
281310
task_id=task.id,
282311
)
283312
)
313+
logger.info(f"📤 Completion status enqueued for task {task.id}")
284314

285315
async def _send_error(self, event_queue: EventQueue, task: A2ATask, error_msg: str):
286316
"""Send task failure status."""
@@ -312,10 +342,20 @@ async def _handle_sub_agent_artifact(self, event: dict, state: StreamState,
312342
artifact_name = artifact_data.get('name', 'streaming_result')
313343
parts = artifact_data.get('parts', [])
314344

345+
# Extract sourceAgent from artifact metadata, event, or current state
346+
existing_metadata = artifact_data.get('metadata', {})
347+
source_agent = (
348+
existing_metadata.get('sourceAgent') or
349+
event.get('source_agent') or
350+
state.current_agent or
351+
'sub-agent'
352+
)
353+
logger.debug(f"📦 Sub-agent artifact from: {source_agent}")
354+
315355
# Accumulate final results (complete_result, final_result, partial_result)
316356
if artifact_name in ('complete_result', 'final_result', 'partial_result'):
317357
state.sub_agents_completed += 1
318-
logger.debug(f"Sub-agent completed with {artifact_name} (total completed: {state.sub_agents_completed})")
358+
logger.info(f"Sub-agent completed with {artifact_name} (total completed: {state.sub_agents_completed})")
319359

320360
for part in parts:
321361
if isinstance(part, dict):
@@ -335,11 +375,17 @@ async def _handle_sub_agent_artifact(self, event: dict, state: StreamState,
335375
elif part.get('data'):
336376
artifact_parts.append(Part(root=DataPart(data=part['data'])))
337377

378+
# Create artifact with sourceAgent metadata for sub-agent message grouping
338379
artifact = Artifact(
339380
artifactId=artifact_data.get('artifactId'),
340381
name=artifact_name,
341-
description=artifact_data.get('description', 'From sub-agent'),
342-
parts=artifact_parts
382+
description=artifact_data.get('description', f'From {source_agent}'),
383+
parts=artifact_parts,
384+
metadata={
385+
'sourceAgent': source_agent,
386+
'agentType': 'sub-agent',
387+
**existing_metadata # Preserve any other metadata
388+
}
343389
)
344390

345391
# Track artifact ID for append logic
@@ -474,14 +520,28 @@ async def _handle_streaming_chunk(self, event: dict, state: StreamState,
474520

475521
is_tool_notification = self._is_tool_notification(content, event)
476522

523+
# Track current agent from tool_call events for sub-agent message grouping
524+
if 'tool_call' in event:
525+
tool_name = event['tool_call'].get('name', 'unknown')
526+
state.current_agent = tool_name
527+
logger.info(f"🎯 Current agent set to: {tool_name}")
528+
elif 'tool_result' in event:
529+
# Tool completed - keep current agent for any remaining content
530+
tool_name = event['tool_result'].get('name', state.current_agent)
531+
logger.info(f"✅ Tool completed: {tool_name}")
532+
533+
# Also detect agent from event metadata if provided
534+
source_agent = event.get('source_agent') or state.current_agent or 'supervisor'
535+
477536
# Accumulate non-notification content (unless DataPart already received)
478537
if not is_tool_notification and not state.sub_agent_datapart:
479538
state.supervisor_content.append(content)
480539

481-
# Create artifact
540+
# Create artifact with sourceAgent metadata
482541
if is_tool_notification:
483542
artifact_name, description = self._get_artifact_name_for_notification(content, event)
484543
artifact = new_text_artifact(name=artifact_name, description=description, text=content)
544+
artifact.metadata = {'sourceAgent': source_agent, 'agentType': 'notification'}
485545
use_append = False
486546
state.seen_artifact_ids.add(artifact.artifact_id)
487547
elif state.streaming_artifact_id is None:
@@ -491,6 +551,7 @@ async def _handle_streaming_chunk(self, event: dict, state: StreamState,
491551
description='Streaming result',
492552
text=content,
493553
)
554+
artifact.metadata = {'sourceAgent': source_agent, 'agentType': 'streaming'}
494555
state.streaming_artifact_id = artifact.artifact_id
495556
state.seen_artifact_ids.add(artifact.artifact_id)
496557
state.first_artifact_sent = True
@@ -503,6 +564,7 @@ async def _handle_streaming_chunk(self, event: dict, state: StreamState,
503564
text=content,
504565
)
505566
artifact.artifact_id = state.streaming_artifact_id
567+
artifact.metadata = {'sourceAgent': source_agent, 'agentType': 'streaming'}
506568
use_append = True
507569

508570
await self._send_artifact(event_queue, task, artifact, append=use_append)
@@ -514,7 +576,13 @@ async def _handle_stream_end(self, state: StreamState, task: A2ATask,
514576
# For single-agent scenarios where sub-agent already sent complete_result,
515577
# we just need to send the completion status (content already forwarded)
516578

579+
# Debug: Log accumulated content before getting final
580+
logger.info(f"📦 Stream end - supervisor_content: {len(state.supervisor_content)} items, {sum(len(c) for c in state.supervisor_content)} chars")
581+
logger.info(f"📦 Stream end - sub_agent_content: {len(state.sub_agent_content)} items, {sum(len(c) for c in state.sub_agent_content)} chars")
582+
logger.info(f"📦 Stream end - sub_agents_completed: {state.sub_agents_completed}")
583+
517584
final_content, is_datapart = self._get_final_content(state)
585+
logger.info(f"📦 Final content for UI: {len(final_content) if isinstance(final_content, str) else 'datapart'} chars, is_datapart={is_datapart}")
518586

519587
# If we have accumulated content (supervisor synthesis or sub-agent content), send it
520588
if final_content or is_datapart:
@@ -602,6 +670,29 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
602670
artifact=event.artifact
603671
)
604672
await self._safe_enqueue_event(event_queue, transformed)
673+
674+
# 🔧 CRITICAL FIX: Accumulate content from typed artifacts for final_result
675+
# Without this, _get_final_content returns empty and UI never gets final render
676+
artifact = event.artifact
677+
if artifact and hasattr(artifact, 'parts') and artifact.parts:
678+
artifact_name = getattr(artifact, 'name', 'streaming_result')
679+
is_final_artifact = artifact_name in ('complete_result', 'final_result', 'partial_result')
680+
681+
for part in artifact.parts:
682+
part_root = getattr(part, 'root', None)
683+
if part_root and hasattr(part_root, 'text') and part_root.text:
684+
# Accumulate streaming content
685+
if artifact_name == 'streaming_result':
686+
if not self._is_tool_notification(part_root.text, {}):
687+
state.supervisor_content.append(part_root.text)
688+
# Accumulate final results from sub-agents
689+
elif is_final_artifact:
690+
state.sub_agent_content.append(part_root.text)
691+
692+
# Increment sub_agents_completed once per final artifact
693+
if is_final_artifact:
694+
state.sub_agents_completed += 1
695+
logger.info(f"Sub-agent completed via typed event with {artifact_name} (total: {state.sub_agents_completed})")
605696
else:
606697
corrected = TaskStatusUpdateEvent(
607698
context_id=event.context_id,

ai_platform_engineering/multi_agents/platform_engineer/protocol_bindings/a2a/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)