
Commit 07ebbd0 (1 parent: 7879cf9)

feat(participant): filter message history when it goes over maxInputTokens VSCODE-653 (#894)

5 files changed (+214, −73 lines)
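At a high level, the commit budgets chat history against the model's context window: it counts the tokens of the fixed assistant prompt and the user's request, treats whatever remains under maxInputTokens as the limit for history, and drops the oldest turns first when that limit is exceeded. A minimal TypeScript sketch of the idea, using the stock VS Code language-model API (the function name and shape are illustrative, not the extension's actual helpers):

import * as vscode from 'vscode';

// Illustrative sketch: keep only as much history as fits in the token
// budget left over after the assistant prompt and the user's request.
async function trimHistoryToBudget(
  model: vscode.LanguageModelChat,
  assistantPrompt: vscode.LanguageModelChatMessage,
  requestPrompt: string,
  history: vscode.LanguageModelChatMessage[]
): Promise<vscode.LanguageModelChatMessage[]> {
  const [assistantTokens, requestTokens] = await Promise.all([
    model.countTokens(assistantPrompt),
    model.countTokens(requestPrompt),
  ]);
  let remaining = model.maxInputTokens - (assistantTokens + requestTokens);

  const kept: vscode.LanguageModelChatMessage[] = [];
  // Walk newest-to-oldest so the most recent turns survive trimming.
  for (let i = history.length - 1; i >= 0; i--) {
    remaining -= await model.countTokens(history[i]);
    if (remaining < 0) {
      break;
    }
    kept.push(history[i]);
  }
  // Restore chronological order before sending the messages to the model.
  return kept.reverse();
}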

src/participant/participant.ts

Lines changed: 1 addition & 1 deletion

@@ -1577,7 +1577,7 @@ export default class ParticipantController {
       log.info('Docs chatbot created for chatId', chatId);
     }

-    const history = PromptHistory.getFilteredHistoryForDocs({
+    const history = await PromptHistory.getFilteredHistoryForDocs({
       connectionNames: this._getConnectionNames(),
       context: context,
     });

src/participant/prompts/promptBase.ts

Lines changed: 62 additions & 24 deletions

@@ -5,6 +5,7 @@ import type {
   ParticipantPromptProperties,
 } from '../../telemetry/telemetryService';
 import { PromptHistory } from './promptHistory';
+import { getCopilotModel } from '../model';
 import type { ParticipantCommandType } from '../participantTypes';

 export interface PromptArgsBase {

@@ -94,34 +95,76 @@ export function isContentEmpty(
   return true;
 }

-export abstract class PromptBase<TArgs extends PromptArgsBase> {
-  protected abstract getAssistantPrompt(args: TArgs): string;
+export abstract class PromptBase<PromptArgs extends PromptArgsBase> {
+  protected abstract getAssistantPrompt(args: PromptArgs): string;

   protected get internalPurposeForTelemetry(): InternalPromptPurpose {
     return undefined;
   }

-  protected getUserPrompt(args: TArgs): Promise<UserPromptResponse> {
+  protected getUserPrompt({
+    request,
+  }: PromptArgs): Promise<UserPromptResponse> {
     return Promise.resolve({
-      prompt: args.request.prompt,
+      prompt: request.prompt,
       hasSampleDocs: false,
     });
   }

-  async buildMessages(args: TArgs): Promise<ModelInput> {
-    let historyMessages = PromptHistory.getFilteredHistory({
-      history: args.context?.history,
-      ...args,
+  private async _countRemainingTokens({
+    model,
+    assistantPrompt,
+    requestPrompt,
+  }: {
+    model: vscode.LanguageModelChat | undefined;
+    assistantPrompt: vscode.LanguageModelChatMessage;
+    requestPrompt: string;
+  }): Promise<number | undefined> {
+    if (model) {
+      const [assistantPromptTokens, userPromptTokens] = await Promise.all([
+        model.countTokens(assistantPrompt),
+        model.countTokens(requestPrompt),
+      ]);
+      return model.maxInputTokens - (assistantPromptTokens + userPromptTokens);
+    }
+    return undefined;
+  }
+
+  async buildMessages(args: PromptArgs): Promise<ModelInput> {
+    const { context, request, databaseName, collectionName, connectionNames } =
+      args;
+
+    const model = await getCopilotModel();
+
+    // eslint-disable-next-line new-cap
+    const assistantPrompt = vscode.LanguageModelChatMessage.Assistant(
+      this.getAssistantPrompt(args)
+    );
+
+    const tokenLimit = await this._countRemainingTokens({
+      model,
+      assistantPrompt,
+      requestPrompt: request.prompt,
+    });
+
+    let historyMessages = await PromptHistory.getFilteredHistory({
+      history: context?.history,
+      model,
+      tokenLimit,
+      namespaceIsKnown:
+        databaseName !== undefined && collectionName !== undefined,
+      connectionNames,
     });
+
     // If the current user's prompt is a connection name, and the last
     // message was to connect. We want to use the last
     // message they sent before the connection name as their prompt.
-    if (args.connectionNames?.includes(args.request.prompt)) {
-      const history = args.context?.history;
+    if (connectionNames?.includes(request.prompt)) {
+      const history = context?.history;
       if (!history) {
         return {
           messages: [],
-          stats: this.getStats([], args, false),
+          stats: this.getStats([], { request, context }, false),
         };
       }
       const previousResponse = history[

@@ -132,13 +175,11 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
     // Go through the history in reverse order to find the last user message.
     for (let i = history.length - 1; i >= 0; i--) {
       if (history[i] instanceof vscode.ChatRequestTurn) {
+        request.prompt = (history[i] as vscode.ChatRequestTurn).prompt;
         // Rewrite the arguments so that the prompt is the last user message from history
         args = {
           ...args,
-          request: {
-            ...args.request,
-            prompt: (history[i] as vscode.ChatRequestTurn).prompt,
-          },
+          request,
         };

         // Remove the item from the history messages array.

@@ -150,23 +191,20 @@ export abstract class PromptBase<TArgs extends PromptArgsBase> {
     }

     const { prompt, hasSampleDocs } = await this.getUserPrompt(args);
-    const messages = [
-      // eslint-disable-next-line new-cap
-      vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)),
-      ...historyMessages,
-      // eslint-disable-next-line new-cap
-      vscode.LanguageModelChatMessage.User(prompt),
-    ];
+    // eslint-disable-next-line new-cap
+    const userPrompt = vscode.LanguageModelChatMessage.User(prompt);
+
+    const messages = [assistantPrompt, ...historyMessages, userPrompt];

     return {
       messages,
-      stats: this.getStats(messages, args, hasSampleDocs),
+      stats: this.getStats(messages, { request, context }, hasSampleDocs),
     };
   }

   protected getStats(
     messages: vscode.LanguageModelChatMessage[],
-    { request, context }: TArgs,
+    { request, context }: Pick<PromptArgsBase, 'request' | 'context'>,
     hasSampleDocs: boolean
   ): ParticipantPromptProperties {
     return {
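The diff relies on getCopilotModel (imported from '../model') to obtain the vscode.LanguageModelChat whose countTokens and maxInputTokens drive the budget. That helper isn't shown in this commit; a plausible stand-in built on the public VS Code API might look like the following (an assumption, not the extension's actual implementation):

import * as vscode from 'vscode';

// Hypothetical stand-in for the extension's getCopilotModel() helper:
// select the first available Copilot chat model, or undefined when none
// is installed or authorized — which is why callers guard on `model`.
async function getCopilotModel(): Promise<
  vscode.LanguageModelChat | undefined
> {
  const [model] = await vscode.lm.selectChatModels({ vendor: 'copilot' });
  return model;
}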

src/participant/prompts/promptHistory.ts

Lines changed: 23 additions & 14 deletions

@@ -106,26 +106,28 @@ export class PromptHistory {
   /** When passing the history to the model we only want contextual messages
       to be passed. This function parses through the history and returns
       the messages that are valuable to keep. */
-  static getFilteredHistory({
+  static async getFilteredHistory({
+    model,
+    tokenLimit,
     connectionNames,
     history,
-    databaseName,
-    collectionName,
+    namespaceIsKnown,
   }: {
+    model?: vscode.LanguageModelChat | undefined;
+    tokenLimit?: number;
     connectionNames?: string[]; // Used to scrape the connecting messages from the history.
     history?: vscode.ChatContext['history'];
-    databaseName?: string;
-    collectionName?: string;
-  }): vscode.LanguageModelChatMessage[] {
+    namespaceIsKnown: boolean;
+  }): Promise<vscode.LanguageModelChatMessage[]> {
     const messages: vscode.LanguageModelChatMessage[] = [];

     if (!history) {
       return [];
     }

-    const namespaceIsKnown =
-      databaseName !== undefined && collectionName !== undefined;
-    for (let i = 0; i < history.length; i++) {
+    let totalUsedTokens = 0;
+
+    for (let i = history.length - 1; i >= 0; i--) {
       const currentTurn = history[i];

       let addedMessage: vscode.LanguageModelChatMessage | undefined;

@@ -147,16 +149,23 @@ export class PromptHistory {
       });
     }
     if (addedMessage) {
+      if (tokenLimit) {
+        totalUsedTokens += (await model?.countTokens(addedMessage)) || 0;
+        if (totalUsedTokens > tokenLimit) {
+          break;
+        }
+      }
+
       messages.push(addedMessage);
     }
   }

-    return messages;
+    return messages.reverse();
   }

   /** The docs chatbot keeps its own history so we avoid any
    * we need to include history only since last docs message. */
-  static getFilteredHistoryForDocs({
+  static async getFilteredHistoryForDocs({
     connectionNames,
     context,
     databaseName,

@@ -166,7 +175,7 @@ export class PromptHistory {
     context?: vscode.ChatContext;
     databaseName?: string;
     collectionName?: string;
-  }): vscode.LanguageModelChatMessage[] {
+  }): Promise<vscode.LanguageModelChatMessage[]> {
     if (!context) {
       return [];
     }

@@ -192,8 +201,8 @@ export class PromptHistory {
     return this.getFilteredHistory({
       connectionNames,
       history: historySinceLastDocs.reverse(),
-      databaseName,
-      collectionName,
+      namespaceIsKnown:
+        databaseName !== undefined && collectionName !== undefined,
     });
   }
 }
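Call sites now await the filtered history and pass the namespace check in precomputed, rather than handing over the raw database and collection names. A hedged example call (the surrounding variables are placeholders from the buildMessages flow shown above):

const historyMessages = await PromptHistory.getFilteredHistory({
  model, // vscode.LanguageModelChat | undefined
  tokenLimit, // tokens left after the fixed prompts were counted
  connectionNames, // e.g. from this._getConnectionNames()
  history: context?.history,
  namespaceIsKnown:
    databaseName !== undefined && collectionName !== undefined,
});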

src/participant/sampleDocuments.ts

Lines changed: 5 additions & 3 deletions

@@ -59,9 +59,11 @@ export async function getStringifiedSampleDocuments({

   const stringifiedDocuments = toJSString(additionToPrompt);

-  // TODO: model.countTokens will sometimes return undefined - at least in tests. We should investigate why.
-  promptInputTokens =
-    (await model.countTokens(prompt + stringifiedDocuments)) || 0;
+  // Re-evaluate promptInputTokens with less documents if necessary.
+  if (promptInputTokens > model.maxInputTokens) {
+    promptInputTokens =
+      (await model.countTokens(prompt + stringifiedDocuments)) || 0;
+  }

   // Add sample documents to the prompt only when it fits in the context window.
   if (promptInputTokens <= model.maxInputTokens) {
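The guard changes when the re-count happens rather than what is counted: previously model.countTokens ran unconditionally after the sample documents were truncated, while now it runs only when the earlier estimate overflowed the window. As a worked example with assumed numbers, if maxInputTokens is 8192 and the initial estimate came to 7500 tokens, the re-count is skipped and the samples are appended directly; if it came to 9000, the prompt is re-counted against the reduced stringifiedDocuments and the samples are included only if the new total fits.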
