Skip to content

Commit aa49d3d

Browse files
committed
improvements
1 parent d936c51 commit aa49d3d

File tree

4 files changed

+51
-14
lines changed

4 files changed

+51
-14
lines changed

packages/mcp-client/cli.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@ import { join } from "node:path";
44
import { homedir } from "node:os";
55
import { Agent } from "./src";
66
import type { StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js";
7+
import { ANSI } from "./src/utils";
8+
import type { InferenceProvider } from "@huggingface/inference";
79

8-
const ANSI = {
9-
BLUE: "\x1b[34m",
10-
GREEN: "\x1b[32m",
11-
RED: "\x1b[31m",
12-
RESET: "\x1b[0m",
13-
};
10+
const MODEL_ID = process.env.MODEL_ID ?? "Qwen/Qwen2.5-72B-Instruct";
11+
const PROVIDER = (process.env.PROVIDER as InferenceProvider) ?? "together";
1412

1513
const SERVERS: StdioServerParameters[] = [
1614
{
@@ -35,8 +33,8 @@ async function main() {
3533
}
3634

3735
const agent = new Agent({
38-
provider: "together",
39-
model: "Qwen/Qwen2.5-72B-Instruct",
36+
provider: PROVIDER,
37+
model: MODEL_ID,
4038
apiKey: process.env.HF_TOKEN,
4139
servers: SERVERS,
4240
});
@@ -58,7 +56,7 @@ async function main() {
5856

5957
while (true) {
6058
const input = await rl.question("> ");
61-
for await (const response of agent.processSingleTurn(input)) {
59+
for await (const response of agent.run(input)) {
6260
if ("choices" in response) {
6361
stdout.write(response.choices[0].message.content ?? "");
6462
stdout.write("\n\n");

packages/mcp-client/src/Agent.ts

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ import type { InferenceProvider } from "@huggingface/inference";
22
import type { ChatCompletionInputMessageTool } from "./McpClient";
33
import { McpClient } from "./McpClient";
44
import type { ChatCompletionInputMessage, ChatCompletionOutput } from "@huggingface/tasks";
5+
import type { ChatCompletionInputTool } from "@huggingface/tasks/src/tasks/chat-completion/inference";
56
import type { StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio";
7+
import { debug } from "./utils";
68

79
const DEFAULT_SYSTEM_PROMPT = `
810
You are an agent - please keep going until the user’s query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved, or if you need more info from the user to solve the problem.
@@ -12,6 +14,23 @@ If you are not sure about anything pertaining to the user’s request, use your
1214
You MUST plan extensively before each function call, and reflect extensively on the outcomes of the previous function calls. DO NOT do this entire process by making function calls only, as this can impair your ability to solve the problem and think insightfully.
1315
`.trim();
1416

17+
/**
18+
* Max number of tool calling + chat completion steps in response to a single user query.
19+
*/
20+
const MAX_NUM_TURNS = 10;
21+
22+
const taskCompletionTool: ChatCompletionInputTool = {
23+
type: "function",
24+
function: {
25+
name: "task_complete",
26+
description: "Call this tool when the task given by the user is complete",
27+
parameters: {
28+
type: "object",
29+
properties: {},
30+
},
31+
},
32+
};
33+
1534
export class Agent extends McpClient {
1635
private readonly servers: StdioServerParameters[];
1736
protected messages: ChatCompletionInputMessage[];
@@ -41,14 +60,22 @@ export class Agent extends McpClient {
4160
return this.addMcpServers(this.servers);
4261
}
4362

44-
async *processSingleTurn(input: string): AsyncGenerator<ChatCompletionOutput | ChatCompletionInputMessageTool> {
63+
async *run(input: string): AsyncGenerator<ChatCompletionOutput | ChatCompletionInputMessageTool> {
4564
this.messages.push({
4665
role: "user",
4766
content: input,
4867
});
4968

50-
while (this.messages.at(-1)?.role !== "assistant") {
51-
yield* this.processSingleTurnWithTools(this.messages);
69+
let i = 0;
70+
while (
71+
!(this.messages.at(-1)?.role === "tool" && this.messages.at(-1)?.name === taskCompletionTool.function.name) &&
72+
i < MAX_NUM_TURNS
73+
) {
74+
yield* this.processSingleTurnWithTools(this.messages, {
75+
taskCompletionTool,
76+
});
77+
i++;
78+
debug("current role", this.messages.at(-1)?.role);
5279
}
5380
}
5481
}

packages/mcp-client/src/McpClient.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,16 @@ export class McpClient {
7070
}
7171

7272
async *processSingleTurnWithTools(
73-
messages: ChatCompletionInputMessage[]
73+
messages: ChatCompletionInputMessage[],
74+
opts: { taskCompletionTool?: ChatCompletionInputTool } = {}
7475
): AsyncGenerator<ChatCompletionOutput | ChatCompletionInputMessageTool> {
7576
debug("start of single turn");
7677

7778
const response = await this.client.chatCompletion({
7879
provider: this.provider,
7980
model: this.model,
8081
messages,
81-
tools: this.availableTools,
82+
tools: opts.taskCompletionTool ? [...this.availableTools, opts.taskCompletionTool] : this.availableTools,
8283
tool_choice: "auto",
8384
});
8485

@@ -100,6 +101,10 @@ export class McpClient {
100101
content: "",
101102
name: toolName,
102103
};
104+
if (toolName === opts.taskCompletionTool?.function.name) {
105+
messages.push(message);
106+
return yield message;
107+
}
103108
/// Get the appropriate session for this tool
104109
const client = this.clients.get(toolName);
105110
if (client) {

packages/mcp-client/src/utils.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,10 @@ export function debug(...args: unknown[]): void {
33
console.debug(args);
44
}
55
}
6+
7+
export const ANSI = {
8+
BLUE: "\x1b[34m",
9+
GREEN: "\x1b[32m",
10+
RED: "\x1b[31m",
11+
RESET: "\x1b[0m",
12+
};

0 commit comments

Comments
 (0)