Fix RunUsage.tool_calls being undercounted due to race condition when running tools in parallel
#3133
Changes from 2 commits (PR commits: bea4e20, 6f8cfc7, 12b102c, 34e6d2f, 969353f)
```diff
@@ -1178,7 +1178,7 @@ async def _responses_create(
         truncation=model_settings.get('openai_truncation', NOT_GIVEN),
         timeout=model_settings.get('timeout', NOT_GIVEN),
         service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
-        previous_response_id=previous_response_id or NOT_GIVEN,
+        previous_response_id=previous_response_id,
         reasoning=reasoning,
         user=model_settings.get('openai_user', NOT_GIVEN),
         text=text or NOT_GIVEN,
```

Collaborator: This looks like a broken merge conflict resolution! Please remove it from the diff to make sure we don't accidentally merge this into `main`.
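The reviewer's point is easier to see with the sentinel convention spelled out: in the OpenAI SDK style, `NOT_GIVEN` means the field is omitted from the request entirely, while an explicit `None` is serialized as null. A minimal sketch of that pattern, with `NOT_GIVEN` and `build_params` as illustrative stand-ins rather than the SDK's actual internals:

```python
# Illustrative sketch of the NOT_GIVEN sentinel pattern; `NOT_GIVEN` and
# `build_params` are stand-ins, not the real SDK internals.
from typing import Any

NOT_GIVEN = object()  # sentinel meaning "the caller did not pass this argument"


def build_params(previous_response_id: Any = NOT_GIVEN) -> dict[str, Any]:
    params: dict[str, Any] = {}
    if previous_response_id is not NOT_GIVEN:
        # An explicit None ends up in the payload as null.
        params['previous_response_id'] = previous_response_id
    return params


assert build_params() == {}  # field omitted entirely
assert build_params(None) == {'previous_response_id': None}  # explicit null
# `previous_response_id or NOT_GIVEN` maps falsy values (None, '') back to
# "omitted", which is exactly what the plain assignment no longer does.
```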
```diff
@@ -3346,7 +3346,7 @@ class Result(BaseModel):
     assert response_stream.usage() == snapshot(
         RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
     )
-    assert run.usage() == snapshot(RunUsage(requests=1))
+    assert run.usage() == snapshot(RunUsage())
     assert run.usage() == snapshot(
         RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1)
     )
```

Collaborator: This looks like a breaking change we shouldn't make.

Collaborator: As mentioned above, this would be a breaking change, so I'd rather ensure at the call site where we call …
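To make the objection concrete, here is a hedged sketch, with `billing_guard` as a hypothetical downstream consumer, of code that the changed assertion implies would break:

```python
# Hypothetical downstream code relying on the old contract, where a run's
# usage counted the request as soon as it was made rather than only after
# the response was processed.
from pydantic_ai.usage import RunUsage


def billing_guard(usage: RunUsage) -> None:
    # A consumer inspecting usage mid-stream under the old behavior:
    assert usage.requests >= 1, 'expected the in-flight request to be counted'


billing_guard(RunUsage(requests=1))  # old mid-stream snapshot: passes
# billing_guard(RunUsage())          # new mid-stream snapshot: would raise
```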
```diff
@@ -17,6 +17,7 @@
     ModelRequest,
     ModelResponse,
     RunContext,
+    TextPart,
     ToolCallPart,
     ToolReturnPart,
     UsageLimitExceeded,
```
```diff
@@ -200,7 +201,7 @@ async def test_multi_agent_usage_sync():
     controller_agent = Agent(TestModel())
 
     @controller_agent.tool
-    def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
+    async def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
        new_usage = RunUsage(requests=5, input_tokens=2, output_tokens=3)
        ctx.usage.incr(new_usage)
        return 0
```
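The `def` to `async def` change is worth a note: agent frameworks commonly execute synchronous tools on a worker thread and async tools on the event loop, so the switch changes how the tool's `ctx.usage.incr(...)` interleaves with other tasks. A generic illustration of that execution difference (plain asyncio, not pydantic-ai internals):

```python
# Generic illustration: sync callables are commonly offloaded to a worker
# thread, async ones stay on the event loop. Shared-state mutations from a
# worker thread are subject to thread-level races, not just task interleaving.
import asyncio
import threading


def sync_tool() -> str:
    return f'sync tool ran on thread {threading.current_thread().name}'


async def async_tool() -> str:
    return f'async tool ran on thread {threading.current_thread().name}'


async def main() -> None:
    print(await asyncio.to_thread(sync_tool))  # worker thread
    print(await async_tool())  # event loop thread


asyncio.run(main())
```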
```diff
@@ -248,6 +249,41 @@ def test_add_usages():
     assert RunUsage() + RunUsage() == RunUsage()
 
 
+def test_run_usage_with_request_usage():
+    """Test RunUsage operations with RequestUsage to ensure coverage of RequestUsage branches."""
+    run_usage = RunUsage(requests=1, input_tokens=10, output_tokens=20, tool_calls=1)
+    request_usage = RequestUsage(input_tokens=5, output_tokens=10)
+
+    # Test __add__ with RequestUsage
+    result = run_usage + request_usage
+    assert result.requests == 2  # 1 + 1 (RequestUsage.requests property returns 1)
+    assert result.input_tokens == 15
+    assert result.output_tokens == 30
+    assert result.tool_calls == 1  # RequestUsage doesn't have tool_calls
+
+    # Test incr with RequestUsage (covers elif isinstance(incr_usage, RequestUsage) branch)
+    run_usage2 = RunUsage(requests=2, input_tokens=20, output_tokens=30, tool_calls=2)
+    run_usage2.incr(request_usage)
+    assert run_usage2.requests == 3  # 2 + 1
+    assert run_usage2.input_tokens == 25  # 20 + 5
+    assert run_usage2.output_tokens == 40  # 30 + 10
+    assert run_usage2.tool_calls == 2  # Unchanged
+
+    # Test incr with empty details dict (covers empty for loop branch in _incr_usage_tokens)
+    run_usage3 = RunUsage(requests=0, tool_calls=0)
+    request_usage_no_details = RequestUsage(input_tokens=5, output_tokens=10)
+    assert request_usage_no_details.details == {}  # Ensure details is empty
+    run_usage3.incr(request_usage_no_details)
+    assert run_usage3.requests == 1
+    assert run_usage3.details == {}
+
+    # Test incr with non-empty details dict
+    run_usage4 = RunUsage(requests=0, tool_calls=0, details={'reasoning_tokens': 10})
+    request_usage_with_details = RequestUsage(input_tokens=5, output_tokens=10, details={'reasoning_tokens': 5})
+    run_usage4.incr(request_usage_with_details)
+    assert run_usage4.details == {'reasoning_tokens': 15}
+
+
 async def test_tool_call_limit() -> None:
     test_agent = Agent(TestModel())
```
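As a reading aid for the test above, the following sketch restates the accumulation semantics it asserts. The two dataclasses are simplified stand-ins, assumed from the test's expectations rather than taken from the library source:

```python
# Simplified stand-ins (assumptions, not the library source) sketching the
# semantics the test asserts: a RequestUsage counts as exactly one request,
# has no tool_calls, and its details dict is merged by summing per key.
from dataclasses import dataclass, field


@dataclass
class RequestUsage:
    input_tokens: int = 0
    output_tokens: int = 0
    details: dict[str, int] = field(default_factory=dict)

    @property
    def requests(self) -> int:
        return 1  # one RequestUsage always represents one model request


@dataclass
class RunUsage:
    requests: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    tool_calls: int = 0
    details: dict[str, int] = field(default_factory=dict)

    def incr(self, other: 'RunUsage | RequestUsage') -> None:
        self.requests += other.requests
        self.input_tokens += other.input_tokens
        self.output_tokens += other.output_tokens
        if isinstance(other, RunUsage):
            self.tool_calls += other.tool_calls  # RequestUsage has none
        for key, value in other.details.items():
            self.details[key] = self.details.get(key, 0) + value


run = RunUsage(requests=2, input_tokens=20, output_tokens=30, tool_calls=2)
run.incr(RequestUsage(input_tokens=5, output_tokens=10, details={'reasoning_tokens': 5}))
assert (run.requests, run.input_tokens, run.output_tokens, run.tool_calls) == (3, 25, 40, 2)
assert run.details == {'reasoning_tokens': 5}
```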
```diff
@@ -355,6 +391,41 @@ def test_deprecated_usage_limits():
     assert UsageLimits(output_tokens_limit=100).response_tokens_limit == 100  # type: ignore
 
 
+async def test_race_condition_parallel_tool_calls():
+    """Test that demonstrates race condition in parallel tool execution.
+
+    This test would fail intermittently on main without the fix because multiple
+    asyncio tasks calling usage.incr() can interleave their read-modify-write operations.
+    """
+    # Run multiple iterations to increase chance of catching race condition
+    for iteration in range(20):
+        call_count = 0
+
+        def parallel_tools_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                # Return 10 parallel tool calls for more contention
+                return ModelResponse(parts=[ToolCallPart('tool_a', {}, f'call_{i}') for i in range(10)])
+            else:
+                # Return final text response
+                return ModelResponse(parts=[TextPart(content='done')])
+
+        agent = Agent(FunctionModel(parallel_tools_model))
+
+        @agent.tool_plain
+        async def tool_a() -> str:
+            # Add multiple await points to increase chance of task interleaving
+            await asyncio.sleep(0.0001)
+            await asyncio.sleep(0.0001)
+            return 'result'
+
+        result = await agent.run('test')
+        # Without proper synchronization, tool_calls might be undercounted
+        actual = result.usage().tool_calls
+        assert actual == 10, f'Iteration {iteration}: Expected 10 tool calls, got {actual}'
+
+
 async def test_parallel_tool_calls_limit_enforced():
     """Parallel tool calls must not exceed the limit and should raise immediately."""
     executed_tools: list[str] = []
```
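The failure mode this test hunts for can be reproduced in isolation. Below is a minimal sketch, not pydantic-ai's actual `incr`, showing how a read-modify-write with a suspension point in the middle loses updates under `asyncio.gather`, and why an update with no await in the middle (or one guarded by an `asyncio.Lock`) does not:

```python
# Lost-update demo: ten tasks each try to add 1, but every task reads the
# counter before any of them writes back, so increments are lost.
import asyncio


class Usage:
    def __init__(self) -> None:
        self.tool_calls = 0

    async def incr_racy(self) -> None:
        current = self.tool_calls  # read
        await asyncio.sleep(0)  # suspension point: other tasks run here
        self.tool_calls = current + 1  # write back a now-stale value

    def incr_safe(self) -> None:
        self.tool_calls += 1  # no await between read and write: atomic w.r.t. the event loop


async def main() -> None:
    racy, safe = Usage(), Usage()
    await asyncio.gather(*(racy.incr_racy() for _ in range(10)))
    for _ in range(10):
        safe.incr_safe()
    print(racy.tool_calls)  # 1: nine increments were lost
    print(safe.tool_calls)  # 10


asyncio.run(main())
```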