-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Fix RunUsage.tool_calls being undercounted due to race condition when running tools in parallel
#3133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix RunUsage.tool_calls being undercounted due to race condition when running tools in parallel
#3133
Changes from 1 commit
bea4e20
6f8cfc7
12b102c
34e6d2f
969353f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -408,7 +408,7 @@ async def stream( | |
| message_history, model_settings, model_request_parameters, run_context | ||
| ) as streamed_response: | ||
| self._did_stream = True | ||
| ctx.state.usage.requests += 1 | ||
| # Request count is incremented in _finish_handling via response.usage | ||
| agent_stream = result.AgentStream[DepsT, T]( | ||
| _raw_stream_response=streamed_response, | ||
| _output_schema=ctx.deps.output_schema, | ||
|
|
@@ -426,7 +426,7 @@ async def stream( | |
|
|
||
| model_response = streamed_response.get() | ||
|
|
||
| self._finish_handling(ctx, model_response) | ||
| await self._finish_handling(ctx, model_response) | ||
| assert self._result is not None # this should be set by the previous line | ||
|
|
||
| async def _make_request( | ||
|
|
@@ -437,9 +437,9 @@ async def _make_request( | |
|
|
||
| model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) | ||
| model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) | ||
| ctx.state.usage.requests += 1 | ||
| # Request count is incremented in _finish_handling via response.usage | ||
|
|
||
| return self._finish_handling(ctx, model_response) | ||
| return await self._finish_handling(ctx, model_response) | ||
|
|
||
| async def _prepare_request( | ||
| self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] | ||
|
|
@@ -481,7 +481,7 @@ async def _prepare_request( | |
|
|
||
| return model_settings, model_request_parameters, message_history, run_context | ||
|
|
||
| def _finish_handling( | ||
| async def _finish_handling( | ||
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| response: _messages.ModelResponse, | ||
|
|
@@ -895,6 +895,8 @@ async def _call_tools( | |
| tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} | ||
| user_parts_by_index: dict[int, _messages.UserPromptPart] = {} | ||
| deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {} | ||
| # Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions | ||
| usage_lock = asyncio.Lock() | ||
|
||
|
|
||
| if usage_limits.tool_calls_limit is not None: | ||
| projected_usage = deepcopy(usage) | ||
|
|
@@ -904,74 +906,85 @@ async def _call_tools( | |
| for call in tool_calls: | ||
| yield _messages.FunctionToolCallEvent(call) | ||
|
|
||
| with tracer.start_as_current_span( | ||
| 'running tools', | ||
| attributes={ | ||
| 'tools': [call.tool_name for call in tool_calls], | ||
| 'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}', | ||
| }, | ||
| ): | ||
| # Import and set the usage lock context variable for parallel tool execution | ||
| from ._tool_manager import _usage_increment_lock_ctx_var # pyright: ignore[reportPrivateUsage] | ||
|
|
||
| async def handle_call_or_result( | ||
| coro_or_task: Awaitable[ | ||
| tuple[ | ||
| _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None | ||
| ] | ||
| ] | ||
| | Task[ | ||
| tuple[ | ||
| _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None | ||
| ] | ||
| ], | ||
| index: int, | ||
| ) -> _messages.HandleResponseEvent | None: | ||
| try: | ||
| tool_part, tool_user_content = ( | ||
| (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() | ||
| ) | ||
| except exceptions.CallDeferred: | ||
| deferred_calls_by_index[index] = 'external' | ||
| except exceptions.ApprovalRequired: | ||
| deferred_calls_by_index[index] = 'unapproved' | ||
| else: | ||
| tool_parts_by_index[index] = tool_part | ||
| if tool_user_content: | ||
| user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content) | ||
| token = _usage_increment_lock_ctx_var.set(usage_lock) | ||
|
|
||
| return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content) | ||
|
|
||
| if tool_manager.should_call_sequentially(tool_calls): | ||
| for index, call in enumerate(tool_calls): | ||
| if event := await handle_call_or_result( | ||
| _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), | ||
| index, | ||
| ): | ||
| yield event | ||
| try: | ||
| with tracer.start_as_current_span( | ||
| 'running tools', | ||
| attributes={ | ||
| 'tools': [call.tool_name for call in tool_calls], | ||
| 'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}', | ||
| }, | ||
| ): | ||
|
|
||
| else: | ||
| tasks = [ | ||
| asyncio.create_task( | ||
| _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), | ||
| name=call.tool_name, | ||
| ) | ||
| for call in tool_calls | ||
| ] | ||
|
|
||
| pending = tasks | ||
| while pending: | ||
| done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) | ||
| for task in done: | ||
| index = tasks.index(task) | ||
| if event := await handle_call_or_result(coro_or_task=task, index=index): | ||
| async def handle_call_or_result( | ||
| coro_or_task: Awaitable[ | ||
| tuple[ | ||
| _messages.ToolReturnPart | _messages.RetryPromptPart, | ||
| str | Sequence[_messages.UserContent] | None, | ||
| ] | ||
| ] | ||
| | Task[ | ||
| tuple[ | ||
| _messages.ToolReturnPart | _messages.RetryPromptPart, | ||
| str | Sequence[_messages.UserContent] | None, | ||
| ] | ||
| ], | ||
| index: int, | ||
| ) -> _messages.HandleResponseEvent | None: | ||
| try: | ||
| tool_part, tool_user_content = ( | ||
| (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() | ||
| ) | ||
| except exceptions.CallDeferred: | ||
| deferred_calls_by_index[index] = 'external' | ||
| except exceptions.ApprovalRequired: | ||
| deferred_calls_by_index[index] = 'unapproved' | ||
| else: | ||
| tool_parts_by_index[index] = tool_part | ||
| if tool_user_content: | ||
| user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content) | ||
|
|
||
| return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content) | ||
|
|
||
| if tool_manager.should_call_sequentially(tool_calls): | ||
| for index, call in enumerate(tool_calls): | ||
| if event := await handle_call_or_result( | ||
| _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), | ||
| index, | ||
| ): | ||
| yield event | ||
|
|
||
| # We append the results at the end, rather than as they are received, to retain a consistent ordering | ||
| # This is mostly just to simplify testing | ||
| output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)]) | ||
| output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)]) | ||
| else: | ||
| tasks = [ | ||
| asyncio.create_task( | ||
| _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), | ||
| name=call.tool_name, | ||
| ) | ||
| for call in tool_calls | ||
| ] | ||
|
|
||
| pending = tasks | ||
| while pending: | ||
| done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) | ||
| for task in done: | ||
| index = tasks.index(task) | ||
| if event := await handle_call_or_result(coro_or_task=task, index=index): | ||
| yield event | ||
|
|
||
| for k in sorted(deferred_calls_by_index): | ||
| output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k]) | ||
| # We append the results at the end, rather than as they are received, to retain a consistent ordering | ||
| # This is mostly just to simplify testing | ||
| output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)]) | ||
| output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)]) | ||
|
|
||
| for k in sorted(deferred_calls_by_index): | ||
| output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k]) | ||
| finally: | ||
| # Reset the context variable | ||
| _usage_increment_lock_ctx_var.reset(token) | ||
|
|
||
|
|
||
| async def _call_tool( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1178,7 +1178,7 @@ async def _responses_create( | |
| truncation=model_settings.get('openai_truncation', NOT_GIVEN), | ||
| timeout=model_settings.get('timeout', NOT_GIVEN), | ||
| service_tier=model_settings.get('openai_service_tier', NOT_GIVEN), | ||
| previous_response_id=previous_response_id or NOT_GIVEN, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like a broken merge conflict resolution! Please remove it from the diff to make sure we don't accidentally merge this into |
||
| previous_response_id=previous_response_id, | ||
| reasoning=reasoning, | ||
| user=model_settings.get('openai_user', NOT_GIVEN), | ||
| text=text or NOT_GIVEN, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -198,12 +198,17 @@ def incr(self, incr_usage: RunUsage | RequestUsage) -> None: | |
| if isinstance(incr_usage, RunUsage): | ||
| self.requests += incr_usage.requests | ||
| self.tool_calls += incr_usage.tool_calls | ||
| else: | ||
| # RequestUsage: requests is a property that returns 1 | ||
| self.requests += incr_usage.requests | ||
|
||
| return _incr_usage_tokens(self, incr_usage) | ||
|
|
||
| def __add__(self, other: RunUsage | RequestUsage) -> RunUsage: | ||
| """Add two RunUsages together. | ||
|
|
||
| This is provided so it's trivial to sum usage information from multiple runs. | ||
|
|
||
| **WARNING:** this CANNOT be used to sum multiple requests without breaking some pricing calculations. | ||
DouweM marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
| new_usage = copy(self) | ||
| new_usage.incr(other) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3346,7 +3346,7 @@ class Result(BaseModel): | |
| assert response_stream.usage() == snapshot( | ||
| RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1) | ||
| ) | ||
| assert run.usage() == snapshot(RunUsage(requests=1)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like a breaking change we shouldn't make
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned above, this would be a breaking change, so I'd rather ensure at the call site where we call |
||
| assert run.usage() == snapshot(RunUsage()) | ||
| assert run.usage() == snapshot( | ||
| RunUsage(input_tokens=53, output_tokens=469, details={'reasoning_tokens': 448}, requests=1) | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to include this comment, or the next identical one