@@ -6,7 +6,8 @@ import type { InferenceProvider } from "@huggingface/inference";
66import type {
77 ChatCompletionInputMessage ,
88 ChatCompletionInputTool ,
9- ChatCompletionOutput ,
9+ ChatCompletionStreamOutput ,
10+ ChatCompletionStreamOutputDeltaToolCall ,
1011} from "@huggingface/tasks/src/tasks/chat-completion/inference" ;
1112import { version as packageVersion } from "../package.json" ;
1213import { debug } from "./utils" ;
@@ -72,52 +73,79 @@ export class McpClient {
7273 async * processSingleTurnWithTools (
7374 messages : ChatCompletionInputMessage [ ] ,
7475 opts : { exitLoopTools ?: ChatCompletionInputTool [ ] ; exitIfNoTool ?: boolean } = { }
75- ) : AsyncGenerator < ChatCompletionOutput | ChatCompletionInputMessageTool > {
76+ ) : AsyncGenerator < ChatCompletionStreamOutput | ChatCompletionInputMessageTool > {
7677 debug ( "start of single turn" ) ;
7778
78- const response = await this . client . chatCompletion ( {
79+ const stream = this . client . chatCompletionStream ( {
7980 provider : this . provider ,
8081 model : this . model ,
8182 messages,
8283 tools : opts . exitLoopTools ? [ ...opts . exitLoopTools , ...this . availableTools ] : this . availableTools ,
8384 tool_choice : "auto" ,
8485 } ) ;
8586
86- const toolCalls = response . choices [ 0 ] . message . tool_calls ;
87- if ( ! toolCalls || toolCalls . length === 0 ) {
88- if ( opts . exitIfNoTool ) {
89- return ;
87+ const firstChunkResult = await stream . next ( ) ;
88+ if ( firstChunkResult . done ) {
89+ return ;
90+ }
91+ const firstChunk = firstChunkResult . value ;
92+ const firstToolCalls = firstChunk . choices [ 0 ] ?. delta . tool_calls ;
93+ if ( ( ! firstToolCalls || firstToolCalls . length === 0 ) && opts . exitIfNoTool ) {
94+ return ;
95+ }
96+ yield firstChunk ;
97+ const message = {
98+ role : firstChunk . choices [ 0 ] . delta . role ,
99+ content : firstChunk . choices [ 0 ] . delta . content ,
100+ } satisfies ChatCompletionInputMessage ;
101+
102+ const finalToolCalls : Record < number , ChatCompletionStreamOutputDeltaToolCall > = { } ;
103+
104+ for await ( const chunk of stream ) {
105+ yield chunk ;
106+ const delta = chunk . choices [ 0 ] ?. delta ;
107+ if ( ! delta ) {
108+ continue ;
109+ }
110+ if ( delta . content ) {
111+ message . content += delta . content ;
112+ }
113+ for ( const toolCall of delta . tool_calls ?? [ ] ) {
114+ // aggregating chunks into an encoded arguments JSON object
115+ if ( ! finalToolCalls [ toolCall . index ] ) {
116+ finalToolCalls [ toolCall . index ] = toolCall ;
117+ }
118+ finalToolCalls [ toolCall . index ] . function . arguments += toolCall . function . arguments ;
90119 }
91- messages . push ( {
92- role : response . choices [ 0 ] . message . role ,
93- content : response . choices [ 0 ] . message . content ,
94- } ) ;
95- return yield response ;
96120 }
97- for ( const toolCall of toolCalls ) {
98- const toolName = toolCall . function . name ;
121+
122+ messages . push ( message ) ;
123+
124+ for ( const toolCall of Object . values ( finalToolCalls ) ) {
125+ const toolName = toolCall . function . name ?? "" ;
126+ /// TODO(Fix upstream type so this is always a string)^
99127 const toolArgs = JSON . parse ( toolCall . function . arguments ) ;
100128
101- const message : ChatCompletionInputMessageTool = {
129+ const toolMessage : ChatCompletionInputMessageTool = {
102130 role : "tool" ,
103131 tool_call_id : toolCall . id ,
104132 content : "" ,
105133 name : toolName ,
106134 } ;
107135 if ( opts . exitLoopTools ?. map ( ( t ) => t . function . name ) . includes ( toolName ) ) {
108- messages . push ( message ) ;
109- return yield message ;
136+ messages . push ( toolMessage ) ;
137+ return yield toolMessage ;
110138 }
111139 /// Get the appropriate session for this tool
112140 const client = this . clients . get ( toolName ) ;
113141 if ( client ) {
114142 const result = await client . callTool ( { name : toolName , arguments : toolArgs } ) ;
115- message . content = ( result . content as Array < { text : string } > ) [ 0 ] . text ;
143+ toolMessage . content = ( result . content as Array < { text : string } > ) [ 0 ] . text ;
116144 } else {
117- message . content = `Error: No session found for tool: ${ toolName } ` ;
145+ toolMessage . content = `Error: No session found for tool: ${ toolName } ` ;
118146 }
119- messages . push ( message ) ;
120- yield message ;
147+ messages . push ( toolMessage ) ;
148+ yield toolMessage ;
121149 }
122150 }
123151
0 commit comments