Skip to content

Commit 15a8ec7

Browse files
authored
edits: add edit tool 'learning' for byok (#1154)
* edits: add edit tool 'learning' for byok This implements the following for extension-contributed models: - OAI/Sonnet-looking models will get apply_patch or replace_string exclusively, following the latest logic we have for 1p models. - Other models start with replace_string and insert_edit. If they are successful at replace_string, they get it exclusively. If they are very bad at replace_string, then they only get insert_edit. - Models that get replace_string will then also get multi_replace_string. Again, depending on whether they are good at it or not, they get to keep it. - Learnings are kept in a global 50-entry LRU cache. * cleanup * wip * fix types
1 parent 5fa55a3 commit 15a8ec7

File tree

13 files changed

+774
-38
lines changed

13 files changed

+774
-38
lines changed

src/extension/extension/vscode/services.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ import { IMergeConflictService } from '../../git/common/mergeConflictService';
9494
import { MergeConflictServiceImpl } from '../../git/vscode/mergeConflictServiceImpl';
9595
import { ILaunchConfigService } from '../../onboardDebug/common/launchConfigService';
9696
import { LaunchConfigService } from '../../onboardDebug/vscode/launchConfigService';
97+
import { EditToolLearningService, IEditToolLearningService } from '../../tools/common/editToolLearningService';
9798
import { ToolGroupingService } from '../../tools/common/virtualTools/toolGroupingService';
9899
import { ToolGroupingCache } from '../../tools/common/virtualTools/virtualToolGroupCache';
99100
import { IToolGroupingCache, IToolGroupingService } from '../../tools/common/virtualTools/virtualToolTypes';
@@ -170,4 +171,5 @@ export function registerServices(builder: IInstantiationServiceBuilder, extensio
170171
builder.define(IToolGroupingService, new SyncDescriptor(ToolGroupingService));
171172
builder.define(IToolGroupingCache, new SyncDescriptor(ToolGroupingCache));
172173
builder.define(IMergeConflictService, new SyncDescriptor(MergeConflictServiceImpl));
174+
builder.define(IEditToolLearningService, new SyncDescriptor(EditToolLearningService));
173175
}

src/extension/intents/node/agentIntent.ts

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import { ICodeMapperService } from '../../prompts/node/codeMapper/codeMapperServ
4242
import { TemporalContextStats } from '../../prompts/node/inline/temporalContext';
4343
import { EditCodePrompt2 } from '../../prompts/node/panel/editCodePrompt2';
4444
import { ToolResultMetadata } from '../../prompts/node/panel/toolCalling';
45+
import { IEditToolLearningService } from '../../tools/common/editToolLearningService';
4546
import { ContributedToolName, ToolName } from '../../tools/common/toolNames';
4647
import { IToolsService } from '../../tools/common/toolsService';
4748
import { VirtualTool } from '../../tools/common/virtualTools/virtualTool';
@@ -59,45 +60,55 @@ export const getAgentTools = (instaService: IInstantiationService, request: vsco
5960
const configurationService = accessor.get<IConfigurationService>(IConfigurationService);
6061
const experimentationService = accessor.get<IExperimentationService>(IExperimentationService);
6162
const endpointProvider = accessor.get<IEndpointProvider>(IEndpointProvider);
63+
const editToolLearningService = accessor.get<IEditToolLearningService>(IEditToolLearningService);
6264
const model = await endpointProvider.getChatEndpoint(request);
6365

6466
const allowTools: Record<string, boolean> = {};
65-
allowTools[ToolName.EditFile] = true;
66-
allowTools[ToolName.ReplaceString] = await modelSupportsReplaceString(model);
67-
allowTools[ToolName.ApplyPatch] = await modelSupportsApplyPatch(model) && !!toolsService.getTool(ToolName.ApplyPatch);
6867

69-
if (allowTools[ToolName.ApplyPatch] && modelCanUseApplyPatchExclusively(model) && configurationService.getExperimentBasedConfig(ConfigKey.Internal.Gpt5ApplyPatchExclusively, experimentationService)) {
70-
allowTools[ToolName.EditFile] = false;
71-
}
68+
const learned = editToolLearningService.getPreferredEndpointEditTool(model);
69+
if (learned) { // a learning-enabled (BYOK) model, we should go with what it prefers
70+
allowTools[ToolName.EditFile] = learned.includes(ToolName.EditFile);
71+
allowTools[ToolName.ReplaceString] = learned.includes(ToolName.ReplaceString);
72+
allowTools[ToolName.MultiReplaceString] = learned.includes(ToolName.MultiReplaceString);
73+
allowTools[ToolName.ApplyPatch] = learned.includes(ToolName.ApplyPatch);
74+
} else {
75+
allowTools[ToolName.EditFile] = true;
76+
allowTools[ToolName.ReplaceString] = await modelSupportsReplaceString(model);
77+
allowTools[ToolName.ApplyPatch] = await modelSupportsApplyPatch(model) && !!toolsService.getTool(ToolName.ApplyPatch);
78+
79+
if (allowTools[ToolName.ApplyPatch] && modelCanUseApplyPatchExclusively(model) && configurationService.getExperimentBasedConfig(ConfigKey.Internal.Gpt5ApplyPatchExclusively, experimentationService)) {
80+
allowTools[ToolName.EditFile] = false;
81+
}
7282

73-
if (model.family === 'grok-code') {
74-
const treatment = experimentationService.getTreatmentVariable<string>('copilotchat.hiddenModelBEditTool');
75-
switch (treatment) {
76-
case 'with_replace_string':
77-
allowTools[ToolName.ReplaceString] = true;
78-
allowTools[ToolName.MultiReplaceString] = configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceStringGrok, experimentationService);
79-
allowTools[ToolName.EditFile] = true;
80-
break;
81-
case 'only_replace_string':
82-
allowTools[ToolName.ReplaceString] = true;
83-
allowTools[ToolName.MultiReplaceString] = configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceStringGrok, experimentationService);
84-
allowTools[ToolName.EditFile] = false;
85-
break;
86-
case 'control':
87-
default:
88-
allowTools[ToolName.ReplaceString] = false;
89-
allowTools[ToolName.EditFile] = true;
83+
if (model.family === 'grok-code') {
84+
const treatment = experimentationService.getTreatmentVariable<string>('copilotchat.hiddenModelBEditTool');
85+
switch (treatment) {
86+
case 'with_replace_string':
87+
allowTools[ToolName.ReplaceString] = true;
88+
allowTools[ToolName.MultiReplaceString] = configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceStringGrok, experimentationService);
89+
allowTools[ToolName.EditFile] = true;
90+
break;
91+
case 'only_replace_string':
92+
allowTools[ToolName.ReplaceString] = true;
93+
allowTools[ToolName.MultiReplaceString] = configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceStringGrok, experimentationService);
94+
allowTools[ToolName.EditFile] = false;
95+
break;
96+
case 'control':
97+
default:
98+
allowTools[ToolName.ReplaceString] = false;
99+
allowTools[ToolName.EditFile] = true;
100+
}
90101
}
91-
}
92102

93-
if (await modelCanUseReplaceStringExclusively(model)) {
94-
allowTools[ToolName.ReplaceString] = true;
95-
allowTools[ToolName.EditFile] = false;
96-
}
103+
if (await modelCanUseReplaceStringExclusively(model)) {
104+
allowTools[ToolName.ReplaceString] = true;
105+
allowTools[ToolName.EditFile] = false;
106+
}
97107

98-
if (allowTools[ToolName.ReplaceString]) {
99-
if (await modelSupportsMultiReplaceString(model) && configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceString, experimentationService)) {
100-
allowTools[ToolName.MultiReplaceString] = true;
108+
if (allowTools[ToolName.ReplaceString]) {
109+
if (await modelSupportsMultiReplaceString(model) && configurationService.getExperimentBasedConfig(ConfigKey.Internal.MultiReplaceString, experimentationService)) {
110+
allowTools[ToolName.MultiReplaceString] = true;
111+
}
101112
}
102113
}
103114

src/extension/test/node/services.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import { IPromptVariablesService, NullPromptVariablesService } from '../../promp
4444
import { ITodoListContextProvider, TodoListContextProvider } from '../../prompt/node/todoListContextProvider';
4545
import { CodeMapperService, ICodeMapperService } from '../../prompts/node/codeMapper/codeMapperService';
4646
import { FixCookbookService, IFixCookbookService } from '../../prompts/node/inline/fixCookbookService';
47+
import { EditToolLearningService, IEditToolLearningService } from '../../tools/common/editToolLearningService';
4748
import { IToolsService } from '../../tools/common/toolsService';
4849
import { ToolGroupingService } from '../../tools/common/virtualTools/toolGroupingService';
4950
import '../../tools/node/allTools';
@@ -100,5 +101,6 @@ export function createExtensionUnitTestingServices(disposables: Pick<DisposableS
100101
testingServiceCollection.define(IEmbeddingsComputer, new SyncDescriptor(RemoteEmbeddingsComputer));
101102
testingServiceCollection.define(ITodoListContextProvider, new SyncDescriptor(TodoListContextProvider));
102103
testingServiceCollection.define(ILanguageModelServer, new SyncDescriptor(MockLanguageModelServer));
104+
testingServiceCollection.define(IEditToolLearningService, new SyncDescriptor(EditToolLearningService));
103105
return testingServiceCollection;
104106
}

src/extension/test/vscode-node/services.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ import { PromptVariablesServiceImpl } from '../../prompt/vscode-node/promptVaria
101101
import { CodeMapperService, ICodeMapperService } from '../../prompts/node/codeMapper/codeMapperService';
102102
import { FixCookbookService, IFixCookbookService } from '../../prompts/node/inline/fixCookbookService';
103103
import { WorkspaceMutationManager } from '../../testing/node/setupTestsFileManager';
104+
import { EditToolLearningService, IEditToolLearningService } from '../../tools/common/editToolLearningService';
104105
import { IToolsService, NullToolsService } from '../../tools/common/toolsService';
105106
import { ToolGroupingService } from '../../tools/common/virtualTools/toolGroupingService';
106107
import { ToolGroupingCache } from '../../tools/common/virtualTools/virtualToolGroupCache';
@@ -148,6 +149,7 @@ export function createExtensionTestingServices(): TestingServiceCollection {
148149
testingServiceCollection.define(INaiveChunkingService, new SyncDescriptor(NaiveChunkingService));
149150
testingServiceCollection.define(ILinkifyService, new SyncDescriptor(LinkifyService));
150151
testingServiceCollection.define(ITestGenInfoStorage, new SyncDescriptor(TestGenInfoStorage));
152+
testingServiceCollection.define(IEditToolLearningService, new SyncDescriptor(EditToolLearningService));
151153
testingServiceCollection.define(IDebugCommandToConfigConverter, new SyncDescriptor(DebugCommandToConfigConverter));
152154
testingServiceCollection.define(ILaunchConfigService, new SyncDescriptor(LaunchConfigService));
153155
testingServiceCollection.define(IDebuggableCommandIdentifier, new SyncDescriptor(DebuggableCommandIdentifier));
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
* Licensed under the MIT License. See License.txt in the project root for license information.
4+
*--------------------------------------------------------------------------------------------*/
5+
6+
import type { LanguageModelChat } from 'vscode';
7+
import { IEndpointProvider } from '../../../platform/endpoint/common/endpointProvider';
8+
import { IVSCodeExtensionContext } from '../../../platform/extContext/common/extensionContext';
9+
import { IChatEndpoint } from '../../../platform/networking/common/networking';
10+
import { createServiceIdentifier } from '../../../util/common/services';
11+
import { LRUCache } from '../../../util/vs/base/common/map';
12+
import { mapValues } from '../../../util/vs/base/common/objects';
13+
import { EditTools as _EditTools, EDIT_TOOL_LEARNING_STATES, IEditToolLearningData, LearningConfig, State } from './editToolLearningStates';
14+
import { ToolName } from './toolNames';
15+
16+
/** Re-exported alias of `EditTools` from './editToolLearningStates'. */
export type EditTools = _EditTools;

/** Key in the extension's `globalState` under which the learning cache is persisted. */
const CACHE_STORAGE_KEY = 'editToolLearning_cache';
19+
20+
/**
 * Applies `fn` to every entry of a per-tool record and returns a record with
 * the same tool keys and the mapped values.
 *
 * @param record Record keyed by edit tool name; entries are optional.
 * @param fn Mapper invoked with each entry's value and its tool name.
 * @returns A per-tool record holding the results of `fn`.
 */
function mapToolsRecord<I, O>(record: { [K in EditTools]?: I }, fn: (input: I, tool: EditTools) => O) {
	const mapped = mapValues(record, (entry, toolName) => {
		// Values are asserted non-null here, mirroring how the record is populated by callers.
		return fn(entry!, toolName as EditTools);
	});
	return mapped as { [K in EditTools]?: O };
}
23+
24+
interface IStoredToolData {
25+
state: State;
26+
tools: { [K in EditTools]?: { successBitset: string; attempts: number } };
27+
}
28+
29+
/** Service identifier used to register and resolve {@link IEditToolLearningService}. */
export const IEditToolLearningService = createServiceIdentifier<IEditToolLearningService>('IEditToolLearningService');

/**
 * Decides which edit tools a model should be offered and records the outcome
 * of edits so that decision can adapt over time.
 */
export interface IEditToolLearningService {
	readonly _serviceBrand: undefined;

	/**
	 * Resolves the chat endpoint for `model` and returns its preferred edit tools.
	 * @returns The allowed edit tools, or `undefined` when no preference applies.
	 */
	getPreferredEditTool(model: LanguageModelChat): Promise<EditTools[] | undefined>;

	/**
	 * Synchronous variant of {@link getPreferredEditTool} for an already-resolved endpoint.
	 * @returns The allowed edit tools, or `undefined` when no preference applies.
	 */
	getPreferredEndpointEditTool(model: IChatEndpoint): EditTools[] | undefined;

	/**
	 * Reports that `model` attempted an edit with `tool` and whether it succeeded.
	 */
	didMakeEdit(model: LanguageModelChat, tool: EditTools, success: boolean): void;
}
37+
38+
/**
 * Pushes one outcome bit into a fixed-width sliding window stored in a bigint.
 *
 * @param window Current window contents; the newest outcome is the lowest bit.
 * @param bit The bit to append (callers pass `0n` or `1n`).
 * @returns The window shifted left by one with `bit` appended, truncated to
 *          the lowest `LearningConfig.WINDOW_SIZE` bits.
 */
function addToWindow(window: bigint, bit: bigint): bigint {
	const width = BigInt(LearningConfig.WINDOW_SIZE);
	// All-ones mask covering exactly `width` low bits; older bits fall off the top.
	const keepLowBits = (1n << width) - 1n;
	const appended = (window << 1n) | bit;
	return appended & keepLowBits;
}
43+
44+
export class EditToolLearningService implements IEditToolLearningService {
45+
readonly _serviceBrand: undefined;
46+
47+
private _cache?: LRUCache<string, IEditToolLearningData>;
48+
49+
constructor(
50+
@IVSCodeExtensionContext private readonly _context: IVSCodeExtensionContext,
51+
@IEndpointProvider private readonly _endpointProvider: IEndpointProvider,
52+
) { }
53+
54+
async getPreferredEditTool(model: LanguageModelChat): Promise<EditTools[] | undefined> {
55+
const endpoint = await this._endpointProvider.getChatEndpoint(model);
56+
return this.getPreferredEndpointEditTool(endpoint);
57+
}
58+
59+
getPreferredEndpointEditTool(endpoint: IChatEndpoint): EditTools[] | undefined {
60+
if (!endpoint.isExtensionContributed) {
61+
return undefined;
62+
}
63+
64+
const hardcoded = this._getHardcodedPreferences(endpoint.model);
65+
if (hardcoded) {
66+
return hardcoded;
67+
}
68+
69+
const learningData = this._getModelLearningData(endpoint.model);
70+
return this._computePreferences(learningData);
71+
}
72+
73+
async didMakeEdit(model: LanguageModelChat, tool: EditTools, success: boolean): Promise<void> {
74+
const endpoint = await this._endpointProvider.getChatEndpoint(model);
75+
76+
if (!endpoint.isExtensionContributed || this._getHardcodedPreferences(endpoint.family)) {
77+
return;
78+
}
79+
80+
const learningData = this._getModelLearningData(model.id);
81+
this._recordEdit(learningData, tool, success);
82+
await this._saveModelLearningData(model.id, learningData);
83+
}
84+
85+
private _getHardcodedPreferences(family: string): EditTools[] | undefined {
86+
const lowerFamily = family.toLowerCase();
87+
88+
if (lowerFamily.includes('gpt') || lowerFamily.includes('openai')) {
89+
return [ToolName.ApplyPatch];
90+
}
91+
92+
if (lowerFamily.includes('sonnet')) {
93+
return [ToolName.ReplaceString, ToolName.MultiReplaceString];
94+
}
95+
96+
return undefined;
97+
}
98+
99+
private _computePreferences(data: IEditToolLearningData): EditTools[] | undefined {
100+
return EDIT_TOOL_LEARNING_STATES[data.state].allowedTools;
101+
}
102+
103+
private _checkStateTransitions(data: IEditToolLearningData): State {
104+
const currentConfig = EDIT_TOOL_LEARNING_STATES[data.state];
105+
106+
for (const [targetState, condition] of Object.entries(currentConfig.transitions)) {
107+
if (condition(data)) {
108+
return Number(targetState) as State;
109+
}
110+
}
111+
112+
return data.state; // No transition
113+
}
114+
115+
private _recordEdit(data: IEditToolLearningData, tool: EditTools, success: boolean): void {
116+
const successBit = success ? 1n : 0n;
117+
const toolData = (data.tools[tool] ??= { successBitset: 0n, attempts: 0 });
118+
toolData.successBitset = addToWindow(toolData.successBitset, successBit);
119+
toolData.attempts++;
120+
121+
const newState = this._checkStateTransitions(data);
122+
if (newState !== data.state) {
123+
data.state = newState;
124+
data.tools = {};
125+
}
126+
}
127+
128+
private _getCache(): LRUCache<string, IEditToolLearningData> {
129+
if (!this._cache) {
130+
this._cache = this._loadCacheFromStorage();
131+
}
132+
return this._cache;
133+
}
134+
135+
private _loadCacheFromStorage(): LRUCache<string, IEditToolLearningData> {
136+
const cache = new LRUCache<string, IEditToolLearningData>(LearningConfig.CACHE_SIZE);
137+
const storedCacheData = this._context.globalState.get<{ entries: [string, IStoredToolData][] }>(CACHE_STORAGE_KEY);
138+
139+
if (!storedCacheData?.entries) {
140+
return cache;
141+
}
142+
143+
for (const [modelId, storedData] of storedCacheData.entries) {
144+
const data: IEditToolLearningData = {
145+
state: storedData.state,
146+
tools: mapToolsRecord(storedData.tools, r => ({
147+
successBitset: BigInt(r.successBitset),
148+
attempts: r.attempts,
149+
})),
150+
};
151+
cache.set(modelId, data);
152+
}
153+
154+
return cache;
155+
}
156+
157+
private async _saveCacheToStorage(): Promise<void> {
158+
if (!this._cache) {
159+
return;
160+
}
161+
162+
const entries: [string, IStoredToolData][] = Array.from(this._cache.entries(), ([modelId, data]) => {
163+
const storedData = {
164+
state: data.state,
165+
tools: mapToolsRecord(data.tools, r => ({
166+
successBitset: '0x' + r.successBitset.toString(16),
167+
attempts: r.attempts
168+
})),
169+
};
170+
return [modelId, storedData];
171+
});
172+
173+
await this._context.globalState.update(CACHE_STORAGE_KEY, { entries });
174+
}
175+
176+
private async _saveModelLearningData(modelId: string, data: IEditToolLearningData): Promise<void> {
177+
const cache = this._getCache();
178+
cache.set(modelId, data);
179+
await this._saveCacheToStorage();
180+
}
181+
182+
private _getModelLearningData(modelId: string): IEditToolLearningData {
183+
const cache = this._getCache();
184+
185+
let data = cache.get(modelId);
186+
if (!data) {
187+
data = { state: State.Initial, tools: {} };
188+
cache.set(modelId, data);
189+
}
190+
return data;
191+
}
192+
}

0 commit comments

Comments
 (0)