Skip to content

Commit 6da6c2a

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: using async lock for accessing shared object in parallel executions and update tests for testing various type of functions
1. given we are running parallel functions in one event loop (one thread) , we should use async lock instead of thread lock 2. test three kind of functions: a. sync function b. async function that doesn't yield c. async function that yield PiperOrigin-RevId: 791255012
1 parent 8ef2177 commit 6da6c2a

File tree

2 files changed

+167
-40
lines changed

2 files changed

+167
-40
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ async def handle_function_calls_live(
320320
if not function_calls:
321321
return None
322322

323-
# Create thread-safe lock for active_streaming_tools modifications
324-
streaming_lock = threading.Lock()
323+
# Create async lock for active_streaming_tools modifications
324+
streaming_lock = asyncio.Lock()
325325

326326
# Create tasks for parallel execution
327327
tasks = [
@@ -368,7 +368,7 @@ async def _execute_single_function_call_live(
368368
function_call: types.FunctionCall,
369369
tools_dict: dict[str, BaseTool],
370370
agent: LlmAgent,
371-
streaming_lock: threading.Lock,
371+
streaming_lock: asyncio.Lock,
372372
) -> Optional[Event]:
373373
"""Execute a single function call for live mode with thread safety."""
374374
tool, tool_context = _get_tool_and_context(
@@ -448,7 +448,7 @@ async def _process_function_live_helper(
448448
function_call,
449449
function_args,
450450
invocation_context,
451-
streaming_lock: threading.Lock,
451+
streaming_lock: asyncio.Lock,
452452
):
453453
function_response = None
454454
# Check if this is a stop_streaming function call
@@ -458,7 +458,7 @@ async def _process_function_live_helper(
458458
):
459459
function_name = function_args['function_name']
460460
# Thread-safe access to active_streaming_tools
461-
with streaming_lock:
461+
async with streaming_lock:
462462
active_tasks = invocation_context.active_streaming_tools
463463
if (
464464
active_tasks
@@ -491,7 +491,7 @@ async def _process_function_live_helper(
491491
}
492492
if not function_response:
493493
# Clean up the reference under lock
494-
with streaming_lock:
494+
async with streaming_lock:
495495
if (
496496
invocation_context.active_streaming_tools
497497
and function_name in invocation_context.active_streaming_tools
@@ -533,7 +533,7 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
533533
)
534534

535535
# Register streaming tool using original logic
536-
with streaming_lock:
536+
async with streaming_lock:
537537
if invocation_context.active_streaming_tools is None:
538538
invocation_context.active_streaming_tools = {}
539539

tests/unittests/flows/llm_flows/test_functions_simple.py

Lines changed: 160 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -843,30 +843,77 @@ def modify_state_2(tool_context: ToolContext) -> dict:
843843

844844

845845
@pytest.mark.asyncio
846-
async def test_parallel_mixed_sync_async_functions():
847-
"""Test parallel execution with mix of sync and async functions."""
848-
execution_log = []
846+
async def test_sync_function_blocks_async_functions():
847+
"""Test that sync functions block async functions from running concurrently."""
848+
execution_order = []
849+
850+
def blocking_sync_function() -> dict:
851+
execution_order.append('sync_A')
852+
# Simulate CPU-intensive work that blocks the event loop
853+
result = 0
854+
for i in range(1000000): # This blocks the event loop
855+
result += i
856+
execution_order.append('sync_B')
857+
return {'result': 'sync_done'}
858+
859+
async def yielding_async_function() -> dict:
860+
execution_order.append('async_C')
861+
await asyncio.sleep(
862+
0.001
863+
) # This should yield, but can't if event loop is blocked
864+
execution_order.append('async_D')
865+
return {'result': 'async_done'}
866+
867+
# Create function calls - these should run "in parallel"
868+
function_calls = [
869+
types.Part.from_function_call(name='blocking_sync_function', args={}),
870+
types.Part.from_function_call(name='yielding_async_function', args={}),
871+
]
872+
873+
responses: list[types.Content] = [function_calls, 'response1']
874+
mock_model = testing_utils.MockModel.create(responses=responses)
875+
876+
agent = Agent(
877+
name='test_agent',
878+
model=mock_model,
879+
tools=[blocking_sync_function, yielding_async_function],
880+
)
881+
runner = testing_utils.TestInMemoryRunner(agent)
882+
events = await runner.run_async_with_new_session('test')
849883

850-
def sync_function(value: int) -> dict:
851-
execution_log.append(f'sync_start_{value}')
852-
# Simulate some work
853-
import time
884+
# With blocking sync function, execution should be sequential: A, B, C, D
885+
# The sync function blocks, preventing the async function from yielding properly
886+
assert execution_order == ['sync_A', 'sync_B', 'async_C', 'async_D']
854887

855-
time.sleep(0.05) # 50ms
856-
execution_log.append(f'sync_end_{value}')
857-
return {'result': f'sync_{value}'}
858888

859-
async def async_function(value: int) -> dict:
860-
execution_log.append(f'async_start_{value}')
861-
await asyncio.sleep(0.05) # 50ms
862-
execution_log.append(f'async_end_{value}')
863-
return {'result': f'async_{value}'}
889+
@pytest.mark.asyncio
890+
async def test_async_function_without_yield_blocks_others():
891+
"""Test that async functions without yield statements block other functions."""
892+
execution_order = []
893+
894+
async def non_yielding_async_function() -> dict:
895+
execution_order.append('non_yield_A')
896+
# CPU-intensive work without any await statements - blocks like sync function
897+
result = 0
898+
for i in range(1000000): # No await here, so this blocks the event loop
899+
result += i
900+
execution_order.append('non_yield_B')
901+
return {'result': 'non_yielding_done'}
902+
903+
async def yielding_async_function() -> dict:
904+
execution_order.append('yield_C')
905+
await asyncio.sleep(
906+
0.001
907+
) # This should yield, but can't if event loop is blocked
908+
execution_order.append('yield_D')
909+
return {'result': 'yielding_done'}
864910

865911
# Create function calls
866912
function_calls = [
867-
types.Part.from_function_call(name='sync_function', args={'value': 1}),
868-
types.Part.from_function_call(name='async_function', args={'value': 2}),
869-
types.Part.from_function_call(name='sync_function', args={'value': 3}),
913+
types.Part.from_function_call(
914+
name='non_yielding_async_function', args={}
915+
),
916+
types.Part.from_function_call(name='yielding_async_function', args={}),
870917
]
871918

872919
responses: list[types.Content] = [function_calls, 'response1']
@@ -875,24 +922,104 @@ async def async_function(value: int) -> dict:
875922
agent = Agent(
876923
name='test_agent',
877924
model=mock_model,
878-
tools=[sync_function, async_function],
925+
tools=[non_yielding_async_function, yielding_async_function],
879926
)
880927
runner = testing_utils.TestInMemoryRunner(agent)
928+
events = await runner.run_async_with_new_session('test')
881929

882-
import time
930+
# Non-yielding async function blocks, so execution is sequential: A, B, C, D
931+
assert execution_order == ['non_yield_A', 'non_yield_B', 'yield_C', 'yield_D']
883932

884-
start_time = time.time()
933+
934+
@pytest.mark.asyncio
935+
async def test_yielding_async_functions_run_concurrently():
936+
"""Test that async functions with proper yields run concurrently."""
937+
execution_order = []
938+
939+
async def yielding_async_function_1() -> dict:
940+
execution_order.append('func1_A')
941+
await asyncio.sleep(0.001) # Yield control
942+
execution_order.append('func1_B')
943+
return {'result': 'func1_done'}
944+
945+
async def yielding_async_function_2() -> dict:
946+
execution_order.append('func2_C')
947+
await asyncio.sleep(0.001) # Yield control
948+
execution_order.append('func2_D')
949+
return {'result': 'func2_done'}
950+
951+
# Create function calls
952+
function_calls = [
953+
types.Part.from_function_call(name='yielding_async_function_1', args={}),
954+
types.Part.from_function_call(name='yielding_async_function_2', args={}),
955+
]
956+
957+
responses: list[types.Content] = [function_calls, 'response1']
958+
mock_model = testing_utils.MockModel.create(responses=responses)
959+
960+
agent = Agent(
961+
name='test_agent',
962+
model=mock_model,
963+
tools=[yielding_async_function_1, yielding_async_function_2],
964+
)
965+
runner = testing_utils.TestInMemoryRunner(agent)
885966
events = await runner.run_async_with_new_session('test')
886-
total_time = time.time() - start_time
887967

888-
# Should complete in less than 120ms (parallel) rather than 150ms (sequential)
889-
# Allow for overhead from task creation and synchronization
890-
assert total_time < 0.12, f'Execution took {total_time}s, expected < 0.12s'
891-
892-
# Verify all functions were called
893-
assert 'sync_start_1' in execution_log
894-
assert 'sync_end_1' in execution_log
895-
assert 'async_start_2' in execution_log
896-
assert 'async_end_2' in execution_log
897-
assert 'sync_start_3' in execution_log
898-
assert 'sync_end_3' in execution_log
968+
# With proper yielding, execution should interleave: A, C, B, D
969+
# Both functions start, yield, then complete
970+
assert execution_order == ['func1_A', 'func2_C', 'func1_B', 'func2_D']
971+
972+
973+
@pytest.mark.asyncio
974+
async def test_mixed_function_types_execution_order():
975+
"""Test execution order with all three types of functions."""
976+
execution_order = []
977+
978+
def sync_function() -> dict:
979+
execution_order.append('sync_A')
980+
# Small amount of blocking work
981+
result = sum(range(100000))
982+
execution_order.append('sync_B')
983+
return {'result': 'sync_done'}
984+
985+
async def non_yielding_async() -> dict:
986+
execution_order.append('non_yield_C')
987+
# CPU work without yield
988+
result = sum(range(100000))
989+
execution_order.append('non_yield_D')
990+
return {'result': 'non_yield_done'}
991+
992+
async def yielding_async() -> dict:
993+
execution_order.append('yield_E')
994+
await asyncio.sleep(0.001) # Proper yield
995+
execution_order.append('yield_F')
996+
return {'result': 'yield_done'}
997+
998+
# Create function calls
999+
function_calls = [
1000+
types.Part.from_function_call(name='sync_function', args={}),
1001+
types.Part.from_function_call(name='non_yielding_async', args={}),
1002+
types.Part.from_function_call(name='yielding_async', args={}),
1003+
]
1004+
1005+
responses: list[types.Content] = [function_calls, 'response1']
1006+
mock_model = testing_utils.MockModel.create(responses=responses)
1007+
1008+
agent = Agent(
1009+
name='test_agent',
1010+
model=mock_model,
1011+
tools=[sync_function, non_yielding_async, yielding_async],
1012+
)
1013+
runner = testing_utils.TestInMemoryRunner(agent)
1014+
events = await runner.run_async_with_new_session('test')
1015+
1016+
# All blocking functions run sequentially, then the yielding one
1017+
# Expected order: sync_A, sync_B, non_yield_C, non_yield_D, yield_E, yield_F
1018+
assert execution_order == [
1019+
'sync_A',
1020+
'sync_B',
1021+
'non_yield_C',
1022+
'non_yield_D',
1023+
'yield_E',
1024+
'yield_F',
1025+
]

0 commit comments

Comments
 (0)