Skip to content

Commit 82bd4f3

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Extract and propagate task_id in RemoteA2aAgent
The RemoteA2aAgent now extracts a "task_id" from the custom metadata of the last agent event in the session, alongside the existing "context_id". This task_id is then included in the A2AMessage sent to the remote A2A service. Close #3765 PiperOrigin-RevId: 840375992
1 parent c557b0a commit 82bd4f3

File tree

2 files changed

+34
-88
lines changed

2 files changed

+34
-88
lines changed

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,18 +343,18 @@ def _create_a2a_request_for_user_function_response(
343343

344344
def _construct_message_parts_from_session(
345345
self, ctx: InvocationContext
346-
) -> tuple[list[A2APart], Optional[str], Optional[str]]:
346+
) -> tuple[list[A2APart], Optional[str]]:
347347
"""Construct A2A message parts from session events.
348348
349349
Args:
350350
ctx: The invocation context
351351
352352
Returns:
353-
List of A2A parts extracted from session events, context ID, task ID
353+
List of A2A parts extracted from session events, context ID,
354+
request metadata
354355
"""
355356
message_parts: list[A2APart] = []
356357
context_id = None
357-
task_id = None
358358

359359
events_to_process = []
360360
for event in reversed(ctx.session.events):
@@ -364,7 +364,6 @@ def _construct_message_parts_from_session(
364364
if event.custom_metadata:
365365
metadata = event.custom_metadata
366366
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
367-
task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
368367
break
369368
events_to_process.append(event)
370369

@@ -385,7 +384,7 @@ def _construct_message_parts_from_session(
385384
else:
386385
logger.warning("Failed to convert part to A2A format: %s", part)
387386

388-
return message_parts, context_id, task_id
387+
return message_parts, context_id
389388

390389
async def _handle_a2a_response(
391390
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
@@ -501,8 +500,8 @@ async def _run_async_impl(
501500
# Create A2A request for function response or regular message
502501
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
503502
if not a2a_request:
504-
message_parts, context_id, task_id = (
505-
self._construct_message_parts_from_session(ctx)
503+
message_parts, context_id = self._construct_message_parts_from_session(
504+
ctx
506505
)
507506

508507
if not message_parts:
@@ -522,7 +521,6 @@ async def _run_async_impl(
522521
parts=message_parts,
523522
role="user",
524523
context_id=context_id,
525-
task_id=task_id,
526524
)
527525

528526
logger.debug(build_a2a_request_log(a2a_request))

tests/unittests/agents/test_remote_a2a_agent.py

Lines changed: 28 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import json
1616
from pathlib import Path
1717
import tempfile
18-
from unittest import mock
1918
from unittest.mock import AsyncMock
2019
from unittest.mock import create_autospec
2120
from unittest.mock import Mock
@@ -613,14 +612,13 @@ def test_construct_message_parts_from_session_success(self):
613612
mock_a2a_part = Mock()
614613
self.mock_genai_part_converter.return_value = mock_a2a_part
615614

616-
parts, context_id, task_id = (
617-
self.agent._construct_message_parts_from_session(self.mock_context)
615+
parts, context_id = self.agent._construct_message_parts_from_session(
616+
self.mock_context
618617
)
619618

620619
assert len(parts) == 1
621620
assert parts[0] == mock_a2a_part
622621
assert context_id is None
623-
assert task_id is None
624622

625623
def test_construct_message_parts_from_session_success_multiple_parts(self):
626624
"""Test successful message parts construction from session."""
@@ -648,54 +646,23 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
648646
mock_a2a_part2,
649647
]
650648

651-
parts, context_id, task_id = (
652-
self.agent._construct_message_parts_from_session(self.mock_context)
649+
parts, context_id = self.agent._construct_message_parts_from_session(
650+
self.mock_context
653651
)
654652

655653
assert parts == [mock_a2a_part1, mock_a2a_part2]
656654
assert context_id is None
657-
assert task_id is None
658655

659656
def test_construct_message_parts_from_session_empty_events(self):
660657
"""Test message parts construction with empty events."""
661658
self.mock_session.events = []
662659

663-
parts, context_id, task_id = (
664-
self.agent._construct_message_parts_from_session(self.mock_context)
660+
parts, context_id = self.agent._construct_message_parts_from_session(
661+
self.mock_context
665662
)
666663

667664
assert parts == []
668665
assert context_id is None
669-
assert task_id is None
670-
671-
def test_construct_message_parts_from_session_reads_ids_from_metadata(self):
672-
"""Metadata from last agent event is reused for context and task IDs."""
673-
mock_part = Mock()
674-
mock_part.text = "User message"
675-
mock_content = Mock()
676-
mock_content.parts = [mock_part]
677-
user_event = Mock()
678-
user_event.content = mock_content
679-
user_event.author = "user"
680-
681-
agent_event = Mock()
682-
agent_event.author = self.agent.name
683-
agent_event.custom_metadata = {
684-
A2A_METADATA_PREFIX + "context_id": "context-xyz",
685-
A2A_METADATA_PREFIX + "task_id": "task-abc",
686-
}
687-
688-
# Agent reply is before the latest user message (chronological order).
689-
self.mock_session.events = [agent_event, user_event]
690-
self.mock_genai_part_converter.return_value = Mock()
691-
692-
parts, context_id, task_id = (
693-
self.agent._construct_message_parts_from_session(self.mock_context)
694-
)
695-
696-
assert len(parts) == 1 # the latest user message
697-
assert context_id == "context-xyz"
698-
assert task_id == "task-abc"
699666

700667
@pytest.mark.asyncio
701668
async def test_handle_a2a_response_success_with_message(self):
@@ -819,14 +786,13 @@ def mock_converter(part):
819786

820787
self.mock_genai_part_converter.side_effect = mock_converter
821788

822-
parts, context_id, task_id = (
823-
self.agent._construct_message_parts_from_session(self.mock_context)
789+
parts, context_id = self.agent._construct_message_parts_from_session(
790+
self.mock_context
824791
)
825792

826793
# Verify the parts are in correct order
827794
assert len(parts) == 3 # 1 user part + 2 other agent parts
828795
assert context_id is None
829-
assert task_id is None
830796

831797
# Verify order: user part, then "For context:", then agent message
832798
assert converted_parts[0].original_text == "User question"
@@ -1143,14 +1109,24 @@ def test_construct_message_parts_from_session_success(self):
11431109
mock_a2a_part = Mock()
11441110
mock_convert_part.return_value = mock_a2a_part
11451111

1146-
parts, context_id, task_id = (
1147-
self.agent._construct_message_parts_from_session(self.mock_context)
1112+
parts, context_id = self.agent._construct_message_parts_from_session(
1113+
self.mock_context
11481114
)
11491115

11501116
assert len(parts) == 1
11511117
assert parts[0] == mock_a2a_part
11521118
assert context_id is None
1153-
assert task_id is None
1119+
1120+
def test_construct_message_parts_from_session_empty_events(self):
1121+
"""Test message parts construction with empty events."""
1122+
self.mock_session.events = []
1123+
1124+
parts, context_id = self.agent._construct_message_parts_from_session(
1125+
self.mock_context
1126+
)
1127+
1128+
assert parts == []
1129+
assert context_id is None
11541130

11551131
@pytest.mark.asyncio
11561132
async def test_handle_a2a_response_success_with_message(self):
@@ -1487,8 +1463,7 @@ async def test_run_async_impl_no_message_parts(self):
14871463
mock_construct.return_value = (
14881464
[],
14891465
None,
1490-
None,
1491-
) # Tuple with empty parts and no context/task ids
1466+
) # Tuple with empty parts and no context_id
14921467

14931468
events = []
14941469
async for event in self.agent._run_async_impl(self.mock_context):
@@ -1518,8 +1493,7 @@ async def test_run_async_impl_successful_request(self):
15181493
mock_construct.return_value = (
15191494
[mock_a2a_part],
15201495
"context-123",
1521-
"task-789",
1522-
) # Tuple with parts and context/task ids
1496+
) # Tuple with parts and context_id
15231497

15241498
# Mock A2A client
15251499
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1571,13 +1545,6 @@ async def test_run_async_impl_successful_request(self):
15711545
A2A_METADATA_PREFIX + "request"
15721546
in mock_event.custom_metadata
15731547
)
1574-
mock_message_class.assert_called_once_with(
1575-
message_id=mock.ANY,
1576-
parts=[mock_a2a_part],
1577-
role="user",
1578-
context_id="context-123",
1579-
task_id="task-789",
1580-
)
15811548

15821549
@pytest.mark.asyncio
15831550
async def test_run_async_impl_a2a_client_error(self):
@@ -1598,8 +1565,7 @@ async def test_run_async_impl_a2a_client_error(self):
15981565
mock_construct.return_value = (
15991566
[mock_a2a_part],
16001567
"context-123",
1601-
"task-789",
1602-
) # Tuple with parts and context/task ids
1568+
) # Tuple with parts and context_id
16031569

16041570
# Mock A2A client that throws an exception
16051571
mock_a2a_client = AsyncMock()
@@ -1666,8 +1632,7 @@ async def test_run_async_impl_with_meta_provider(self):
16661632
mock_construct.return_value = (
16671633
[mock_a2a_part],
16681634
"context-123",
1669-
"task-789",
1670-
) # Tuple with parts and context/task ids
1635+
) # Tuple with parts and context_id
16711636

16721637
# Mock A2A client
16731638
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1718,13 +1683,6 @@ async def test_run_async_impl_with_meta_provider(self):
17181683
request=mock_message,
17191684
request_metadata=request_metadata,
17201685
)
1721-
mock_message_class.assert_called_once_with(
1722-
message_id=mock.ANY,
1723-
parts=[mock_a2a_part],
1724-
role="user",
1725-
context_id="context-123",
1726-
task_id="task-789",
1727-
)
17281686

17291687

17301688
class TestRemoteA2aAgentExecutionFromFactory:
@@ -1779,8 +1737,7 @@ async def test_run_async_impl_no_message_parts(self):
17791737
mock_construct.return_value = (
17801738
[],
17811739
None,
1782-
None,
1783-
) # Tuple with empty parts and no context/task ids
1740+
) # Tuple with empty parts and no context_id
17841741

17851742
events = []
17861743
async for event in self.agent._run_async_impl(self.mock_context):
@@ -1810,8 +1767,7 @@ async def test_run_async_impl_successful_request(self):
18101767
mock_construct.return_value = (
18111768
[mock_a2a_part],
18121769
"context-123",
1813-
"task-789",
1814-
) # Tuple with parts and context/task ids
1770+
) # Tuple with parts and context_id
18151771

18161772
# Mock A2A client
18171773
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1865,13 +1821,6 @@ async def test_run_async_impl_successful_request(self):
18651821
A2A_METADATA_PREFIX + "request"
18661822
in mock_event.custom_metadata
18671823
)
1868-
mock_message_class.assert_called_once_with(
1869-
message_id=mock.ANY,
1870-
parts=[mock_a2a_part],
1871-
role="user",
1872-
context_id="context-123",
1873-
task_id="task-789",
1874-
)
18751824

18761825
@pytest.mark.asyncio
18771826
async def test_run_async_impl_a2a_client_error(self):
@@ -1892,8 +1841,7 @@ async def test_run_async_impl_a2a_client_error(self):
18921841
mock_construct.return_value = (
18931842
[mock_a2a_part],
18941843
"context-123",
1895-
"task-789",
1896-
) # Tuple with parts and context/task ids
1844+
) # Tuple with parts and context_id
18971845

18981846
# Mock A2A client that throws an exception
18991847
mock_a2a_client = AsyncMock()

0 commit comments

Comments
 (0)