@@ -408,7 +408,6 @@ async def stream(
408408 message_history , model_settings , model_request_parameters , run_context
409409 ) as streamed_response :
410410 self ._did_stream = True
411- # Request count is incremented in _finish_handling via response.usage
412411 agent_stream = result .AgentStream [DepsT , T ](
413412 _raw_stream_response = streamed_response ,
414413 _output_schema = ctx .deps .output_schema ,
@@ -419,8 +418,6 @@ async def stream(
419418 _tool_manager = ctx .deps .tool_manager ,
420419 )
421420 yield agent_stream
422- # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
423- # otherwise usage won't be properly counted:
424421 async for _ in agent_stream :
425422 pass
426423
@@ -437,7 +434,6 @@ async def _make_request(
437434
438435 model_settings , model_request_parameters , message_history , _ = await self ._prepare_request (ctx )
439436 model_response = await ctx .deps .model .request (message_history , model_settings , model_request_parameters )
440- # Request count is incremented in _finish_handling via response.usage
441437
442438 return await self ._finish_handling (ctx , model_response )
443439
@@ -895,8 +891,6 @@ async def _call_tools(
895891 tool_parts_by_index : dict [int , _messages .ModelRequestPart ] = {}
896892 user_parts_by_index : dict [int , _messages .UserPromptPart ] = {}
897893 deferred_calls_by_index : dict [int , Literal ['external' , 'unapproved' ]] = {}
898- # Lock to prevent race conditions when incrementing usage.tool_calls from concurrent tool executions
899- usage_lock = asyncio .Lock ()
900894
901895 if usage_limits .tool_calls_limit is not None :
902896 projected_usage = deepcopy (usage )
@@ -906,85 +900,76 @@ async def _call_tools(
906900 for call in tool_calls :
907901 yield _messages .FunctionToolCallEvent (call )
908902
909- # Import and set the usage lock context variable for parallel tool execution
910- from ._tool_manager import _usage_increment_lock_ctx_var # pyright: ignore[reportPrivateUsage]
911-
912- token = _usage_increment_lock_ctx_var .set (usage_lock )
913-
914- try :
915- with tracer .start_as_current_span (
916- 'running tools' ,
917- attributes = {
918- 'tools' : [call .tool_name for call in tool_calls ],
919- 'logfire.msg' : f'running { len (tool_calls )} tool{ "" if len (tool_calls ) == 1 else "s" } ' ,
920- },
921- ):
903+ with tracer .start_as_current_span (
904+ 'running tools' ,
905+ attributes = {
906+ 'tools' : [call .tool_name for call in tool_calls ],
907+ 'logfire.msg' : f'running { len (tool_calls )} tool{ "" if len (tool_calls ) == 1 else "s" } ' ,
908+ },
909+ ):
922910
923- async def handle_call_or_result (
924- coro_or_task : Awaitable [
925- tuple [
926- _messages .ToolReturnPart | _messages .RetryPromptPart ,
927- str | Sequence [_messages .UserContent ] | None ,
928- ]
911+ async def handle_call_or_result (
912+ coro_or_task : Awaitable [
913+ tuple [
914+ _messages .ToolReturnPart | _messages .RetryPromptPart ,
915+ str | Sequence [_messages .UserContent ] | None ,
929916 ]
930- | Task [
931- tuple [
932- _messages .ToolReturnPart | _messages .RetryPromptPart ,
933- str | Sequence [_messages .UserContent ] | None ,
934- ]
935- ],
936- index : int ,
937- ) -> _messages .HandleResponseEvent | None :
938- try :
939- tool_part , tool_user_content = (
940- (await coro_or_task ) if inspect .isawaitable (coro_or_task ) else coro_or_task .result ()
941- )
942- except exceptions .CallDeferred :
943- deferred_calls_by_index [index ] = 'external'
944- except exceptions .ApprovalRequired :
945- deferred_calls_by_index [index ] = 'unapproved'
946- else :
947- tool_parts_by_index [index ] = tool_part
948- if tool_user_content :
949- user_parts_by_index [index ] = _messages .UserPromptPart (content = tool_user_content )
950-
951- return _messages .FunctionToolResultEvent (tool_part , content = tool_user_content )
952-
953- if tool_manager .should_call_sequentially (tool_calls ):
954- for index , call in enumerate (tool_calls ):
955- if event := await handle_call_or_result (
956- _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
957- index ,
958- ):
959- yield event
960-
961- else :
962- tasks = [
963- asyncio .create_task (
964- _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
965- name = call .tool_name ,
966- )
967- for call in tool_calls
917+ ]
918+ | Task [
919+ tuple [
920+ _messages .ToolReturnPart | _messages .RetryPromptPart ,
921+ str | Sequence [_messages .UserContent ] | None ,
968922 ]
923+ ],
924+ index : int ,
925+ ) -> _messages .HandleResponseEvent | None :
926+ try :
927+ tool_part , tool_user_content = (
928+ (await coro_or_task ) if inspect .isawaitable (coro_or_task ) else coro_or_task .result ()
929+ )
930+ except exceptions .CallDeferred :
931+ deferred_calls_by_index [index ] = 'external'
932+ except exceptions .ApprovalRequired :
933+ deferred_calls_by_index [index ] = 'unapproved'
934+ else :
935+ tool_parts_by_index [index ] = tool_part
936+ if tool_user_content :
937+ user_parts_by_index [index ] = _messages .UserPromptPart (content = tool_user_content )
969938
970- pending = tasks
971- while pending :
972- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
973- for task in done :
974- index = tasks .index (task )
975- if event := await handle_call_or_result (coro_or_task = task , index = index ):
976- yield event
939+ return _messages .FunctionToolResultEvent (tool_part , content = tool_user_content )
977940
978- # We append the results at the end, rather than as they are received, to retain a consistent ordering
979- # This is mostly just to simplify testing
980- output_parts .extend ([tool_parts_by_index [k ] for k in sorted (tool_parts_by_index )])
981- output_parts .extend ([user_parts_by_index [k ] for k in sorted (user_parts_by_index )])
941+ if tool_manager .should_call_sequentially (tool_calls ):
942+ for index , call in enumerate (tool_calls ):
943+ if event := await handle_call_or_result (
944+ _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
945+ index ,
946+ ):
947+ yield event
982948
983- for k in sorted (deferred_calls_by_index ):
984- output_deferred_calls [deferred_calls_by_index [k ]].append (tool_calls [k ])
985- finally :
986- # Reset the context variable
987- _usage_increment_lock_ctx_var .reset (token )
949+ else :
950+ tasks = [
951+ asyncio .create_task (
952+ _call_tool (tool_manager , call , tool_call_results .get (call .tool_call_id )),
953+ name = call .tool_name ,
954+ )
955+ for call in tool_calls
956+ ]
957+
958+ pending = tasks
959+ while pending :
960+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
961+ for task in done :
962+ index = tasks .index (task )
963+ if event := await handle_call_or_result (coro_or_task = task , index = index ):
964+ yield event
965+
966+ # We append the results at the end, rather than as they are received, to retain a consistent ordering
967+ # This is mostly just to simplify testing
968+ output_parts .extend ([tool_parts_by_index [k ] for k in sorted (tool_parts_by_index )])
969+ output_parts .extend ([user_parts_by_index [k ] for k in sorted (user_parts_by_index )])
970+
971+ for k in sorted (deferred_calls_by_index ):
972+ output_deferred_calls [deferred_calls_by_index [k ]].append (tool_calls [k ])
988973
989974
990975async def _call_tool (
0 commit comments