Skip to content

Commit 6bd21d4

Browse files
authored
Set specific model for prompts and modes (microsoft#252893)
* add model to prompt and mode files * add diagnostics * handle model service not initialized * switch model on mode update * fix test * fix value completion range
1 parent 855c8cb commit 6bd21d4

20 files changed

+353
-121
lines changed

src/vs/workbench/contrib/chat/browser/chatInputPart.ts

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,16 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
510510
}
511511
}
512512

513+
public switchModelByName(modelName: string): boolean {
514+
const models = this.getModels();
515+
const model = models.find(m => m.metadata.name === modelName);
516+
if (model) {
517+
this.setCurrentLanguageModel(model);
518+
return true;
519+
}
520+
return false;
521+
}
522+
513523
public switchToNextModel(): void {
514524
const models = this.getModels();
515525
if (models.length > 0) {
@@ -552,6 +562,11 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
552562
this.chatModeKindKey.set(mode.kind);
553563
this._onDidChangeCurrentChatMode.fire();
554564

565+
const model = mode.model?.get();
566+
if (model) {
567+
this.switchModelByName(model);
568+
}
569+
555570
if (storeSelection) {
556571
this.storageService.store(GlobalLastChatModeKey, mode.kind, StorageScope.APPLICATION, StorageTarget.USER);
557572
}
@@ -560,10 +575,7 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
560575
private modelSupportedForDefaultAgent(model: ILanguageModelChatMetadataAndIdentifier): boolean {
561576
// Probably this logic could live in configuration on the agent, or somewhere else, if it gets more complex
562577
if (this.currentModeKind === ChatModeKind.Agent || (this.currentModeKind === ChatModeKind.Edit && this.configurationService.getValue(ChatConfiguration.Edits2Enabled))) {
563-
const supportsToolsAgent = typeof model.metadata.capabilities?.agentMode === 'undefined' || model.metadata.capabilities.agentMode;
564-
565-
// Filter out models that don't support tool calling, and models that don't support enough context to have a good experience with the tools agent
566-
return supportsToolsAgent && !!model.metadata.capabilities?.toolCalling;
578+
return ILanguageModelChatMetadata.suitableForAgentMode(model.metadata);
567579
}
568580

569581
return true;

src/vs/workbench/contrib/chat/browser/chatSelectedTools.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ export class ChatSelectedTools extends Disposable {
176176
return;
177177
}
178178
if (mode.kind === ChatModeKind.Agent && mode.customTools && mode.uri) {
179-
// apply directly to mode.
179+
// apply directly to mode file.
180180
this.updateCustomModeTools(mode.uri.get(), enablementMap);
181181
return;
182182
}

src/vs/workbench/contrib/chat/browser/chatWidget.ts

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,7 @@ export class ChatWidget extends Disposable implements IChatWidget {
18791879

18801880
private async _applyPromptMetadata(metadata: TPromptMetadata, requestInput: IChatRequestInputOptions): Promise<void> {
18811881

1882-
const { mode, tools } = metadata;
1882+
const { mode, tools, model } = metadata;
18831883

18841884
// switch to appropriate chat mode if needed
18851885
if (mode && mode !== this.input.currentModeKind) {
@@ -1894,17 +1894,21 @@ export class ChatWidget extends Disposable implements IChatWidget {
18941894
}
18951895

18961896
// if not tools to enable are present, we are done
1897-
if (tools === undefined) {
1898-
return;
1899-
}
1897+
if (tools !== undefined) {
1898+
19001899

1901-
// sanity check on the logic of the `getPromptFilesMetadata` method
1902-
// and the code above in case this block is moved around somewhere else:
1903-
// if we have some tools present, the mode must have been equal to `agent`
1904-
assert(this.input.currentModeKind === ChatModeKind.Agent, `Chat mode must be 'agent' when there are 'tools' defined, got ${this.input.currentModeKind}.`);
1900+
// sanity check on the logic of the `getPromptFilesMetadata` method
1901+
// and the code above in case this block is moved around somewhere else:
1902+
// if we have some tools present, the mode must have been equal to `agent`
1903+
assert(this.input.currentModeKind === ChatModeKind.Agent, `Chat mode must be 'agent' when there are 'tools' defined, got ${this.input.currentModeKind}.`);
19051904

1906-
const enablementMap = this.toolsService.toToolAndToolSetEnablementMap(new Set(tools));
1907-
this.input.selectedToolsModel.set(enablementMap, true);
1905+
const enablementMap = this.toolsService.toToolAndToolSetEnablementMap(new Set(tools));
1906+
this.input.selectedToolsModel.set(enablementMap, true);
1907+
}
1908+
1909+
if (model !== undefined) {
1910+
this.input.switchModelByName(model);
1911+
}
19081912
}
19091913

19101914
/**

src/vs/workbench/contrib/chat/common/chatModes.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export class ChatModeService extends Disposable implements IChatModeService {
9595
name: cachedMode.name,
9696
description: cachedMode.description,
9797
tools: cachedMode.customTools,
98+
model: cachedMode.model,
9899
body: cachedMode.body || ''
99100
};
100101
const instance = new CustomChatMode(customChatMode);
@@ -190,6 +191,7 @@ export interface IChatModeData {
190191
readonly description?: string;
191192
readonly kind: ChatModeKind;
192193
readonly customTools?: readonly string[];
194+
readonly model?: string;
193195
readonly body?: string;
194196
readonly uri?: URI;
195197
}
@@ -200,8 +202,10 @@ export interface IChatMode {
200202
readonly description: IObservable<string | undefined>;
201203
readonly kind: ChatModeKind;
202204
readonly customTools?: IObservable<readonly string[] | undefined>;
205+
readonly model?: IObservable<string | undefined>;
203206
readonly body?: IObservable<string>;
204207
readonly uri?: IObservable<URI>;
208+
205209
}
206210

207211
function isCachedChatModeData(data: unknown): data is IChatModeData {
@@ -216,6 +220,7 @@ function isCachedChatModeData(data: unknown): data is IChatModeData {
216220
(mode.description === undefined || typeof mode.description === 'string') &&
217221
(mode.customTools === undefined || Array.isArray(mode.customTools)) &&
218222
(mode.body === undefined || typeof mode.body === 'string') &&
223+
(mode.model === undefined || typeof mode.model === 'string') &&
219224
(mode.uri === undefined || (typeof mode.uri === 'object' && mode.uri !== null));
220225
}
221226

@@ -224,6 +229,7 @@ export class CustomChatMode implements IChatMode {
224229
private readonly _customToolsObservable: ISettableObservable<readonly string[] | undefined>;
225230
private readonly _bodyObservable: ISettableObservable<string>;
226231
private readonly _uriObservable: ISettableObservable<URI>;
232+
private readonly _modelObservable: ISettableObservable<string | undefined>;
227233

228234
public readonly id: string;
229235
public readonly name: string;
@@ -236,6 +242,10 @@ export class CustomChatMode implements IChatMode {
236242
return this._customToolsObservable;
237243
}
238244

245+
get model(): IObservable<string | undefined> {
246+
return this._modelObservable;
247+
}
248+
239249
get body(): IObservable<string> {
240250
return this._bodyObservable;
241251
}
@@ -253,6 +263,7 @@ export class CustomChatMode implements IChatMode {
253263
this.name = customChatMode.name;
254264
this._descriptionObservable = observableValue('description', customChatMode.description);
255265
this._customToolsObservable = observableValue('customTools', customChatMode.tools);
266+
this._modelObservable = observableValue('model', customChatMode.model);
256267
this._bodyObservable = observableValue('body', customChatMode.body);
257268
this._uriObservable = observableValue('uri', customChatMode.uri);
258269
}
@@ -265,6 +276,7 @@ export class CustomChatMode implements IChatMode {
265276
// Note- name is derived from ID, it can't change
266277
this._descriptionObservable.set(newData.description, tx);
267278
this._customToolsObservable.set(newData.tools, tx);
279+
this._modelObservable.set(newData.model, tx);
268280
this._bodyObservable.set(newData.body, tx);
269281
this._uriObservable.set(newData.uri, tx);
270282
});
@@ -277,6 +289,7 @@ export class CustomChatMode implements IChatMode {
277289
description: this.description.get(),
278290
kind: this.kind,
279291
customTools: this.customTools.get(),
292+
model: this.model.get(),
280293
body: this.body.get(),
281294
uri: this.uri.get()
282295
};

src/vs/workbench/contrib/chat/common/languageModels.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ export interface ILanguageModelChatMetadata {
146146
};
147147
}
148148

149+
export namespace ILanguageModelChatMetadata {
150+
export function suitableForAgentMode(metadata: ILanguageModelChatMetadata): boolean {
151+
const supportsToolsAgent = typeof metadata.capabilities?.agentMode === 'undefined' || metadata.capabilities.agentMode;
152+
return supportsToolsAgent && !!metadata.capabilities?.toolCalling;
153+
}
154+
}
155+
149156
export interface ILanguageModelChatResponse {
150157
stream: AsyncIterable<IChatResponseFragment | IChatResponseFragment[]>;
151158
result: Promise<any>;

src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderAutocompletion.ts

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import { CancellationToken } from '../../../../../../base/common/cancellation.js
1313
import { ILanguageFeaturesService } from '../../../../../../editor/common/services/languageFeatures.js';
1414
import { CompletionContext, CompletionItem, CompletionItemInsertTextRule, CompletionItemKind, CompletionItemProvider, CompletionList } from '../../../../../../editor/common/languages.js';
1515
import { Range } from '../../../../../../editor/common/core/range.js';
16+
import { ILanguageModelChatMetadata, ILanguageModelsService } from '../../languageModels.js';
1617

1718
export class PromptHeaderAutocompletion extends Disposable implements CompletionItemProvider {
1819
/**
@@ -28,6 +29,7 @@ export class PromptHeaderAutocompletion extends Disposable implements Completion
2829
constructor(
2930
@IPromptsService private readonly promptsService: IPromptsService,
3031
@ILanguageFeaturesService private readonly languageService: ILanguageFeaturesService,
32+
@ILanguageModelsService private readonly languageModelsService: ILanguageModelsService,
3133

3234
) {
3335
super();
@@ -143,13 +145,14 @@ export class PromptHeaderAutocompletion extends Disposable implements Completion
143145
return undefined;
144146
}
145147

148+
const whilespaceAfterColon = (lineContent.substring(colonPosition.column).match(/^\s*/)?.[0].length) ?? 0;
146149
const values = this.getValueSuggestions(promptType, property);
147150
for (const value of values) {
148151
const item: CompletionItem = {
149152
label: value,
150153
kind: CompletionItemKind.Value,
151154
insertText: value,
152-
range: new Range(position.lineNumber, position.column, position.lineNumber, model.getLineMaxColumn(position.lineNumber)),
155+
range: new Range(position.lineNumber, colonPosition.column + whilespaceAfterColon + 1, position.lineNumber, model.getLineMaxColumn(position.lineNumber)),
153156
};
154157
suggestions.push(item);
155158
}
@@ -161,9 +164,9 @@ export class PromptHeaderAutocompletion extends Disposable implements Completion
161164
case PromptsType.instructions:
162165
return new Set(['applyTo', 'description']);
163166
case PromptsType.prompt:
164-
return new Set(['mode', 'tools', 'description']);
167+
return new Set(['mode', 'tools', 'description', 'model']);
165168
default:
166-
return new Set(['tools', 'description']);
169+
return new Set(['tools', 'description', 'model']);
167170
}
168171
}
169172

@@ -190,6 +193,22 @@ export class PromptHeaderAutocompletion extends Disposable implements Completion
190193
if (property === 'tools' && (promptType === PromptsType.prompt || promptType === PromptsType.mode)) {
191194
return ['[]', `['codebase', 'editFiles', 'fetch']`];
192195
}
196+
if (property === 'model' && (promptType === PromptsType.prompt || promptType === PromptsType.mode)) {
197+
return this.getModelNames(promptType === PromptsType.mode);
198+
}
193199
return [];
194200
}
201+
202+
private getModelNames(agentModeOnly: boolean): string[] {
203+
const result = [];
204+
for (const model of this.languageModelsService.getLanguageModelIds()) {
205+
const metadata = this.languageModelsService.lookupLanguageModel(model);
206+
if (metadata && metadata.isUserSelectable !== false) {
207+
if (!agentModeOnly || ILanguageModelChatMetadata.suitableForAgentMode(metadata)) {
208+
result.push(metadata.name);
209+
}
210+
}
211+
}
212+
return result;
213+
}
195214
}

src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderDiagnosticsProvider.ts

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ import { CancellationToken } from '../../../../../../base/common/cancellation.js
1111
import { ProviderInstanceManagerBase, TProviderClass } from './providerInstanceManagerBase.js';
1212
import { TDiagnostic, PromptMetadataError, PromptMetadataWarning } from '../parsers/promptHeader/diagnostics.js';
1313
import { IMarkerData, IMarkerService, MarkerSeverity } from '../../../../../../platform/markers/common/markers.js';
14+
import { PromptHeader } from '../parsers/promptHeader/promptHeader.js';
15+
import { PromptToolsMetadata } from '../parsers/promptHeader/metadata/tools.js';
16+
import { PromptModelMetadata } from '../parsers/promptHeader/metadata/model.js';
17+
import { ModeHeader } from '../parsers/promptHeader/modeHeader.js';
18+
import { ILanguageModelChatMetadata, ILanguageModelsService } from '../../languageModels.js';
19+
import { ILanguageModelToolsService } from '../../languageModelToolsService.js';
20+
import { localize } from '../../../../../../nls.js';
21+
import { ChatModeKind } from '../../constants.js';
1422

1523
/**
1624
* Unique ID of the markers provider class.
@@ -26,8 +34,16 @@ class PromptHeaderDiagnosticsProvider extends ProviderInstanceBase {
2634
model: ITextModel,
2735
@IPromptsService promptsService: IPromptsService,
2836
@IMarkerService private readonly markerService: IMarkerService,
37+
@ILanguageModelsService private readonly languageModelsService: ILanguageModelsService,
38+
@ILanguageModelToolsService private readonly languageModelToolsService: ILanguageModelToolsService,
2939
) {
3040
super(model, promptsService);
41+
this._register(languageModelsService.onDidChangeLanguageModels(() => {
42+
this.onPromptSettled(undefined, CancellationToken.None);
43+
}));
44+
this._register(languageModelToolsService.onDidChangeTools(() => {
45+
this.onPromptSettled(undefined, CancellationToken.None);
46+
}));
3147
}
3248

3349
/**
@@ -47,27 +63,95 @@ class PromptHeaderDiagnosticsProvider extends ProviderInstanceBase {
4763

4864
// header parsing process is separate from the prompt parsing one, hence
4965
// apply markers only after the header is settled and so has diagnostics
50-
header.settled.then(() => {
51-
// by the time the promise finishes, the token might have been cancelled
52-
// already due to a new 'onSettle' event, hence don't apply outdated markers
53-
if (token.isCancellationRequested) {
54-
return;
55-
}
66+
await header.settled;
67+
// by the time the promise finishes, the token might have been cancelled
68+
// already due to a new 'onSettle' event, hence don't apply outdated markers
69+
if (token.isCancellationRequested) {
70+
return;
71+
}
5672

57-
const markers: IMarkerData[] = [];
58-
for (const diagnostic of header.diagnostics) {
59-
markers.push(toMarker(diagnostic));
60-
}
73+
const markers: IMarkerData[] = [];
74+
for (const diagnostic of header.diagnostics) {
75+
markers.push(toMarker(diagnostic));
76+
}
77+
78+
if (header instanceof PromptHeader) {
79+
this.validateTools(header.metadataUtility.tools, header.metadata.mode, markers);
80+
this.validateModel(header.metadataUtility.model, header.metadata.mode, markers);
81+
} else if (header instanceof ModeHeader) {
82+
this.validateTools(header.metadataUtility.tools, ChatModeKind.Agent, markers);
83+
this.validateModel(header.metadataUtility.model, ChatModeKind.Agent, markers);
6184

62-
this.markerService.changeOne(
63-
MARKERS_OWNER_ID,
64-
this.model.uri,
65-
markers,
66-
);
67-
});
85+
}
6886

87+
this.markerService.changeOne(
88+
MARKERS_OWNER_ID,
89+
this.model.uri,
90+
markers,
91+
);
6992
return;
7093
}
94+
validateModel(modelNode: PromptModelMetadata | undefined, modeKind: ChatModeKind | undefined, markers: IMarkerData[]) {
95+
if (!modelNode || modelNode.value === undefined) {
96+
return;
97+
}
98+
const languageModes = this.languageModelsService.getLanguageModelIds();
99+
if (languageModes.length === 0) {
100+
// likely the service is not initialized yet
101+
return;
102+
}
103+
const modelMetadata = this.findModelByName(languageModes, modelNode.value);
104+
if (!modelMetadata) {
105+
markers.push({
106+
message: localize('promptHeaderDiagnosticsProvider.modelNotFound', "Unknown model '{0}'", modelNode.value),
107+
severity: MarkerSeverity.Warning,
108+
...modelNode.range,
109+
});
110+
} else if (modeKind === ChatModeKind.Agent && !ILanguageModelChatMetadata.suitableForAgentMode(modelMetadata)) {
111+
markers.push({
112+
message: localize('promptHeaderDiagnosticsProvider.modelNotSuited', "Model '{0}' is not suited for agent mode", modelNode.value),
113+
severity: MarkerSeverity.Warning,
114+
...modelNode.range,
115+
});
116+
}
117+
118+
}
119+
findModelByName(languageModes: string[], modelName: string): ILanguageModelChatMetadata | undefined {
120+
for (const model of languageModes) {
121+
const metadata = this.languageModelsService.lookupLanguageModel(model);
122+
if (metadata && metadata.isUserSelectable !== false && metadata.name === modelName) {
123+
return metadata;
124+
}
125+
}
126+
return undefined;
127+
}
128+
129+
validateTools(tools: PromptToolsMetadata | undefined, modeKind: ChatModeKind | undefined, markers: IMarkerData[]) {
130+
if (!tools || tools.value === undefined || modeKind === ChatModeKind.Ask || modeKind === ChatModeKind.Edit) {
131+
return;
132+
}
133+
const toolNames = new Set(tools.value);
134+
if (toolNames.size === 0) {
135+
return;
136+
}
137+
for (const tool of this.languageModelToolsService.getTools()) {
138+
toolNames.delete(tool.toolReferenceName ?? tool.displayName);
139+
}
140+
for (const toolSet of this.languageModelToolsService.toolSets.get()) {
141+
toolNames.delete(toolSet.referenceName);
142+
}
143+
144+
for (const toolName of toolNames) {
145+
const range = tools.getToolRange(toolName);
146+
if (range) {
147+
markers.push({
148+
message: localize('promptHeaderDiagnosticsProvider.toolNotFound', "Unknown tool '{0}'", toolName),
149+
severity: MarkerSeverity.Warning,
150+
...range,
151+
});
152+
}
153+
}
154+
}
71155

72156
/**
73157
* Returns a string representation of this object.

0 commit comments

Comments
 (0)