|
19 | 19 | from google.adk.agents.llm_agent import Agent
|
20 | 20 | from google.adk.events.event import Event
|
21 | 21 | from google.adk.flows.llm_flows.functions import find_matching_function_call
|
| 22 | +from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events |
22 | 23 | from google.adk.tools.function_tool import FunctionTool
|
23 | 24 | from google.adk.tools.tool_context import ToolContext
|
24 | 25 | from google.genai import types
|
@@ -931,6 +932,121 @@ async def yielding_async_function() -> dict:
|
931 | 932 | assert execution_order == ['non_yield_A', 'non_yield_B', 'yield_C', 'yield_D']
|
932 | 933 |
|
933 | 934 |
|
| 935 | +def test_merge_parallel_function_response_events_preserves_invocation_id(): |
| 936 | + """Test that merge_parallel_function_response_events preserves the base event's invocation_id.""" |
| 937 | + # Create multiple function response events with different invocation IDs |
| 938 | + invocation_id = 'base_invocation_123' |
| 939 | + |
| 940 | + function_response1 = types.FunctionResponse( |
| 941 | + id='func_123', name='test_function1', response={'result': 'success1'} |
| 942 | + ) |
| 943 | + |
| 944 | + function_response2 = types.FunctionResponse( |
| 945 | + id='func_456', name='test_function2', response={'result': 'success2'} |
| 946 | + ) |
| 947 | + |
| 948 | + event1 = Event( |
| 949 | + invocation_id=invocation_id, |
| 950 | + author='test_agent', |
| 951 | + content=types.Content( |
| 952 | + role='user', parts=[types.Part(function_response=function_response1)] |
| 953 | + ), |
| 954 | + ) |
| 955 | + |
| 956 | + event2 = Event( |
| 957 | + invocation_id='different_invocation_456', # Different invocation ID |
| 958 | + author='test_agent', |
| 959 | + content=types.Content( |
| 960 | + role='user', parts=[types.Part(function_response=function_response2)] |
| 961 | + ), |
| 962 | + ) |
| 963 | + |
| 964 | + # Merge the events |
| 965 | + merged_event = merge_parallel_function_response_events([event1, event2]) |
| 966 | + |
| 967 | + # Should preserve the base event's (first event's) invocation_id |
| 968 | + assert merged_event.invocation_id == invocation_id |
| 969 | + assert merged_event.invocation_id != 'different_invocation_456' |
| 970 | + |
| 971 | + # Should contain both function responses |
| 972 | + assert len(merged_event.content.parts) == 2 |
| 973 | + |
| 974 | + # Verify the responses are preserved |
| 975 | + response_ids = { |
| 976 | + part.function_response.id for part in merged_event.content.parts |
| 977 | + } |
| 978 | + assert 'func_123' in response_ids |
| 979 | + assert 'func_456' in response_ids |
| 980 | + |
| 981 | + |
| 982 | +def test_merge_parallel_function_response_events_single_event(): |
| 983 | + """Test that merge_parallel_function_response_events returns single event unchanged.""" |
| 984 | + invocation_id = 'single_invocation_123' |
| 985 | + |
| 986 | + function_response = types.FunctionResponse( |
| 987 | + id='func_123', name='test_function', response={'result': 'success'} |
| 988 | + ) |
| 989 | + |
| 990 | + event = Event( |
| 991 | + invocation_id=invocation_id, |
| 992 | + author='test_agent', |
| 993 | + content=types.Content( |
| 994 | + role='user', parts=[types.Part(function_response=function_response)] |
| 995 | + ), |
| 996 | + ) |
| 997 | + |
| 998 | + # Merge single event |
| 999 | + merged_event = merge_parallel_function_response_events([event]) |
| 1000 | + |
| 1001 | + # Should return the same event object |
| 1002 | + assert merged_event is event |
| 1003 | + assert merged_event.invocation_id == invocation_id |
| 1004 | + |
| 1005 | + |
| 1006 | +def test_merge_parallel_function_response_events_preserves_other_attributes(): |
| 1007 | + """Test that merge_parallel_function_response_events preserves other attributes from base event.""" |
| 1008 | + invocation_id = 'base_invocation_123' |
| 1009 | + base_author = 'base_agent' |
| 1010 | + base_branch = 'main_branch' |
| 1011 | + |
| 1012 | + function_response1 = types.FunctionResponse( |
| 1013 | + id='func_123', name='test_function1', response={'result': 'success1'} |
| 1014 | + ) |
| 1015 | + |
| 1016 | + function_response2 = types.FunctionResponse( |
| 1017 | + id='func_456', name='test_function2', response={'result': 'success2'} |
| 1018 | + ) |
| 1019 | + |
| 1020 | + event1 = Event( |
| 1021 | + invocation_id=invocation_id, |
| 1022 | + author=base_author, |
| 1023 | + branch=base_branch, |
| 1024 | + content=types.Content( |
| 1025 | + role='user', parts=[types.Part(function_response=function_response1)] |
| 1026 | + ), |
| 1027 | + ) |
| 1028 | + |
| 1029 | + event2 = Event( |
| 1030 | + invocation_id='different_invocation_456', |
| 1031 | + author='different_agent', # Different author |
| 1032 | + branch='different_branch', # Different branch |
| 1033 | + content=types.Content( |
| 1034 | + role='user', parts=[types.Part(function_response=function_response2)] |
| 1035 | + ), |
| 1036 | + ) |
| 1037 | + |
| 1038 | + # Merge the events |
| 1039 | + merged_event = merge_parallel_function_response_events([event1, event2]) |
| 1040 | + |
| 1041 | + # Should preserve base event's attributes |
| 1042 | + assert merged_event.invocation_id == invocation_id |
| 1043 | + assert merged_event.author == base_author |
| 1044 | + assert merged_event.branch == base_branch |
| 1045 | + |
| 1046 | + # Should contain both function responses |
| 1047 | + assert len(merged_event.content.parts) == 2 |
| 1048 | + |
| 1049 | + |
934 | 1050 | @pytest.mark.asyncio
|
935 | 1051 | async def test_yielding_async_functions_run_concurrently():
|
936 | 1052 | """Test that async functions with proper yields run concurrently."""
|
|
0 commit comments