Skip to content

Commit 77221eb

Browse files
authored
Fix enablement of MCP tool sets and tools (#255747)
fix enablement of mcp tool sets and tools
1 parent c50533b commit 77221eb

File tree

3 files changed

+47
-25
lines changed

3 files changed

+47
-25
lines changed

src/vs/workbench/contrib/chat/browser/actions/chatToolPicker.ts

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import { assertNever } from '../../../../../base/common/assert.js';
66
import { Codicon } from '../../../../../base/common/codicons.js';
77
import { diffSets } from '../../../../../base/common/collections.js';
88
import { Event } from '../../../../../base/common/event.js';
9-
import { Iterable } from '../../../../../base/common/iterator.js';
109
import { DisposableStore } from '../../../../../base/common/lifecycle.js';
1110
import { ThemeIcon } from '../../../../../base/common/themables.js';
1211
import { assertType } from '../../../../../base/common/types.js';
@@ -108,7 +107,9 @@ export async function showToolsPicker(
108107
if (!toolsEntries) {
109108
const defaultEntries = new Map();
110109
for (const tool of toolsService.getTools()) {
111-
defaultEntries.set(tool, false);
110+
if (tool.canBeReferencedInPrompt) {
111+
defaultEntries.set(tool, false);
112+
}
112113
}
113114
for (const toolSet of toolsService.toolSets.get()) {
114115
defaultEntries.set(toolSet, false);
@@ -203,8 +204,8 @@ export async function showToolsPicker(
203204
} else {
204205
// stash the MCP toolset into the bucket item
205206
bucket.toolset = toolSetOrTool;
207+
bucket.picked = picked;
206208
}
207-
208209
} else if (toolSetOrTool.canBeReferencedInPrompt) {
209210
bucket.children.push({
210211
parent: bucket,
@@ -216,10 +217,6 @@ export async function showToolsPicker(
216217
indented: true,
217218
});
218219
}
219-
220-
if (picked) {
221-
bucket.picked = true;
222-
}
223220
}
224221

225222
for (const bucket of [builtinBucket, userBucket]) {
@@ -228,6 +225,21 @@ export async function showToolsPicker(
228225
}
229226
}
230227

228+
// set the checkmarks in the UI:
229+
// bucket is checked if at least one of the children is checked
230+
// tool is checked if the bucket is checked or the tool itself is checked
231+
for (const bucket of toolBuckets.values()) {
232+
if (bucket.picked) {
233+
// check all children if the bucket is checked
234+
for (const child of bucket.children) {
235+
child.picked = true;
236+
}
237+
} else {
238+
// check the bucket if one of the children is checked
239+
bucket.picked = bucket.children.some(child => child.picked);
240+
}
241+
}
242+
231243
const store = new DisposableStore();
232244

233245
const picks: (MyPick | IQuickPickSeparator)[] = [];
@@ -385,22 +397,12 @@ export async function showToolsPicker(
385397

386398
store.dispose();
387399

388-
const mcpToolSets = new Set<ToolSet>();
389-
400+
// in the result, a MCP toolset is only enabled if all tools in the toolset are enabled
390401
for (const item of toolsService.toolSets.get()) {
391402
if (item.source.type === 'mcp') {
392-
mcpToolSets.add(item);
393-
394-
if (Iterable.every(item.getTools(), tool => result.get(tool))) {
395-
// ALL tools from the MCP tool set are here, replace them with just the toolset
396-
// but only when computing the final result
397-
for (const tool of item.getTools()) {
398-
result.delete(tool);
399-
}
400-
result.set(item, true);
401-
}
403+
const toolsInSet = Array.from(item.getTools());
404+
result.set(item, toolsInSet.every(tool => result.get(tool)));
402405
}
403406
}
404-
405407
return didAccept ? result : undefined;
406408
}

src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,23 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
492492
toToolAndToolSetEnablementMap(enabledToolOrToolSetNames: readonly string[] | undefined): Map<ToolSet | IToolData, boolean> {
493493
const toolOrToolSetNames = enabledToolOrToolSetNames ? new Set(enabledToolOrToolSetNames) : undefined;
494494
const result = new Map<ToolSet | IToolData, boolean>();
495-
for (const tool of this._tools.values()) {
496-
result.set(tool.data, tool.data.toolReferenceName !== undefined && (toolOrToolSetNames === undefined || toolOrToolSetNames.has(tool.data.toolReferenceName)));
495+
for (const tool of this.getTools()) {
496+
if (tool.canBeReferencedInPrompt) {
497+
result.set(tool, toolOrToolSetNames === undefined || toolOrToolSetNames.has(tool.toolReferenceName ?? tool.displayName));
498+
}
497499
}
498500
for (const toolSet of this._toolSets) {
499-
result.set(toolSet, (toolOrToolSetNames === undefined || toolOrToolSetNames.has(toolSet.referenceName)));
501+
const enabled = toolOrToolSetNames === undefined || toolOrToolSetNames.has(toolSet.referenceName);
502+
result.set(toolSet, enabled);
503+
504+
// if a mcp toolset is enabled, all tools in it are enabled
505+
if (enabled && toolSet.source.type === 'mcp') {
506+
for (const tool of toolSet.getTools()) {
507+
if (tool.canBeReferencedInPrompt) {
508+
result.set(tool, enabled);
509+
}
510+
}
511+
}
500512
}
501513
return result;
502514
}

src/vs/workbench/contrib/chat/browser/promptSyntax/promptFileRewriter.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { ICodeEditorService } from '../../../../../editor/browser/services/codeE
99
import { EditOperation } from '../../../../../editor/common/core/editOperation.js';
1010
import { Range } from '../../../../../editor/common/core/range.js';
1111
import { ITextModel } from '../../../../../editor/common/model.js';
12-
import { IToolAndToolSetEnablementMap, ToolSet } from '../../common/languageModelToolsService.js';
12+
import { IToolAndToolSetEnablementMap, IToolData, ToolSet } from '../../common/languageModelToolsService.js';
1313
import { IPromptsService } from '../../common/promptSyntax/service/promptsService.js';
1414

1515
export class PromptFileRewriter {
@@ -59,11 +59,19 @@ export class PromptFileRewriter {
5959
model.pushStackElement();
6060
return;
6161
}
62+
const toolsCoveredBySets = new Set<IToolData>();
63+
for (const [item, picked] of newTools) {
64+
if (picked && item instanceof ToolSet) {
65+
for (const tool of item.getTools()) {
66+
toolsCoveredBySets.add(tool);
67+
}
68+
}
69+
}
6270
for (const [item, picked] of newTools) {
6371
if (picked) {
6472
if (item instanceof ToolSet) {
6573
newToolNames.push(item.referenceName);
66-
} else {
74+
} else if (!toolsCoveredBySets.has(item)) {
6775
newToolNames.push(item.toolReferenceName ?? item.displayName);
6876
}
6977
}

0 commit comments

Comments
 (0)