Skip to content

Commit a3af347

Browse files
committed
Switch to streaming mode
1 parent e3d39ee commit a3af347

File tree

4 files changed

+75
-29
lines changed

4 files changed

+75
-29
lines changed

packages/mcp-client/cli.ts

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -65,15 +65,32 @@ async function main() {
6565

6666
while (true) {
6767
const input = await rl.question("> ");
68-
for await (const response of agent.run(input)) {
69-
if ("choices" in response) {
70-
stdout.write(response.choices[0].message.content ?? "");
71-
stdout.write("\n\n");
68+
for await (const chunk of agent.run(input)) {
69+
if ("choices" in chunk) {
70+
const delta = chunk.choices[0]?.delta;
71+
if (delta.content) {
72+
stdout.write(delta.content);
73+
}
74+
if (delta.tool_calls) {
75+
stdout.write(ANSI.GRAY);
76+
for (const deltaToolCall of delta.tool_calls) {
77+
if (deltaToolCall.id) {
78+
stdout.write(deltaToolCall.id);
79+
}
80+
if (deltaToolCall.function.name) {
81+
stdout.write(deltaToolCall.function.name);
82+
}
83+
if (deltaToolCall.function.arguments) {
84+
stdout.write(deltaToolCall.function.arguments);
85+
}
86+
}
87+
stdout.write(ANSI.RESET);
88+
}
7289
} else {
7390
/// Tool call info
7491
stdout.write(ANSI.GREEN);
75-
stdout.write(`Tool[${response.name}] ${response.tool_call_id}\n`);
76-
stdout.write(response.content);
92+
stdout.write(`Tool[${chunk.name}] ${chunk.tool_call_id}\n`);
93+
stdout.write(chunk.content);
7794
stdout.write(ANSI.RESET);
7895
stdout.write("\n\n");
7996
}

packages/mcp-client/src/Agent.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import type { InferenceProvider } from "@huggingface/inference";
22
import type { ChatCompletionInputMessageTool } from "./McpClient";
33
import { McpClient } from "./McpClient";
4-
import type { ChatCompletionInputMessage, ChatCompletionOutput } from "@huggingface/tasks";
4+
import type { ChatCompletionInputMessage, ChatCompletionStreamOutput } from "@huggingface/tasks";
55
import type { ChatCompletionInputTool } from "@huggingface/tasks/src/tasks/chat-completion/inference";
66
import type { StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio";
77
import { debug } from "./utils";
@@ -72,7 +72,7 @@ export class Agent extends McpClient {
7272
return this.addMcpServers(this.servers);
7373
}
7474

75-
async *run(input: string): AsyncGenerator<ChatCompletionOutput | ChatCompletionInputMessageTool> {
75+
async *run(input: string): AsyncGenerator<ChatCompletionStreamOutput | ChatCompletionInputMessageTool> {
7676
this.messages.push({
7777
role: "user",
7878
content: input,

packages/mcp-client/src/McpClient.ts

Lines changed: 49 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@ import type { InferenceProvider } from "@huggingface/inference";
66
import type {
77
ChatCompletionInputMessage,
88
ChatCompletionInputTool,
9-
ChatCompletionOutput,
9+
ChatCompletionStreamOutput,
10+
ChatCompletionStreamOutputDeltaToolCall,
1011
} from "@huggingface/tasks/src/tasks/chat-completion/inference";
1112
import { version as packageVersion } from "../package.json";
1213
import { debug } from "./utils";
@@ -72,52 +73,79 @@ export class McpClient {
7273
async *processSingleTurnWithTools(
7374
messages: ChatCompletionInputMessage[],
7475
opts: { exitLoopTools?: ChatCompletionInputTool[]; exitIfNoTool?: boolean } = {}
75-
): AsyncGenerator<ChatCompletionOutput | ChatCompletionInputMessageTool> {
76+
): AsyncGenerator<ChatCompletionStreamOutput | ChatCompletionInputMessageTool> {
7677
debug("start of single turn");
7778

78-
const response = await this.client.chatCompletion({
79+
const stream = this.client.chatCompletionStream({
7980
provider: this.provider,
8081
model: this.model,
8182
messages,
8283
tools: opts.exitLoopTools ? [...opts.exitLoopTools, ...this.availableTools] : this.availableTools,
8384
tool_choice: "auto",
8485
});
8586

86-
const toolCalls = response.choices[0].message.tool_calls;
87-
if (!toolCalls || toolCalls.length === 0) {
88-
if (opts.exitIfNoTool) {
89-
return;
87+
const firstChunkResult = await stream.next();
88+
if (firstChunkResult.done) {
89+
return;
90+
}
91+
const firstChunk = firstChunkResult.value;
92+
const firstToolCalls = firstChunk.choices[0]?.delta.tool_calls;
93+
if ((!firstToolCalls || firstToolCalls.length === 0) && opts.exitIfNoTool) {
94+
return;
95+
}
96+
yield firstChunk;
97+
const message = {
98+
role: firstChunk.choices[0].delta.role,
99+
content: firstChunk.choices[0].delta.content,
100+
} satisfies ChatCompletionInputMessage;
101+
102+
const finalToolCalls: Record<number, ChatCompletionStreamOutputDeltaToolCall> = {};
103+
104+
for await (const chunk of stream) {
105+
yield chunk;
106+
const delta = chunk.choices[0]?.delta;
107+
if (!delta) {
108+
continue;
109+
}
110+
if (delta.content) {
111+
message.content += delta.content;
112+
}
113+
for (const toolCall of delta.tool_calls ?? []) {
114+
// aggregating chunks into an encoded arguments JSON object
115+
if (!finalToolCalls[toolCall.index]) {
116+
finalToolCalls[toolCall.index] = toolCall;
117+
}
118+
finalToolCalls[toolCall.index].function.arguments += toolCall.function.arguments;
90119
}
91-
messages.push({
92-
role: response.choices[0].message.role,
93-
content: response.choices[0].message.content,
94-
});
95-
return yield response;
96120
}
97-
for (const toolCall of toolCalls) {
98-
const toolName = toolCall.function.name;
121+
122+
messages.push(message);
123+
124+
for (const toolCall of Object.values(finalToolCalls)) {
125+
const toolName = toolCall.function.name ?? "";
126+
/// TODO(Fix upstream type so this is always a string)^
99127
const toolArgs = JSON.parse(toolCall.function.arguments);
100128

101-
const message: ChatCompletionInputMessageTool = {
129+
const toolMessage: ChatCompletionInputMessageTool = {
102130
role: "tool",
103131
tool_call_id: toolCall.id,
104132
content: "",
105133
name: toolName,
106134
};
107135
if (opts.exitLoopTools?.map((t) => t.function.name).includes(toolName)) {
108-
messages.push(message);
109-
return yield message;
136+
messages.push(toolMessage);
137+
return yield toolMessage;
110138
}
111139
/// Get the appropriate session for this tool
112140
const client = this.clients.get(toolName);
113141
if (client) {
114142
const result = await client.callTool({ name: toolName, arguments: toolArgs });
115-
message.content = (result.content as Array<{ text: string }>)[0].text;
143+
toolMessage.content = (result.content as Array<{ text: string }>)[0].text;
116144
} else {
117-
message.content = `Error: No session found for tool: ${toolName}`;
145+
toolMessage.content = `Error: No session found for tool: ${toolName}`;
118146
}
119-
messages.push(message);
120-
yield message;
147+
messages.push(toolMessage);
148+
yield toolMessage;
121149
}
122150
}
123151

packages/mcp-client/src/utils.ts

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ export function debug(...args: unknown[]): void {
66

77
export const ANSI = {
88
BLUE: "\x1b[34m",
9+
GRAY: "\x1b[90m",
910
GREEN: "\x1b[32m",
1011
RED: "\x1b[31m",
1112
RESET: "\x1b[0m",

0 commit comments

Comments
 (0)