Skip to content

Commit a808279

Browse files
authored
first version of token counting API (microsoft#210177)
microsoft#206265
1 parent 004a2c4 commit a808279

File tree

6 files changed

+66
-1
lines changed

6 files changed

+66
-1
lines changed

src/vs/workbench/api/browser/mainThreadLanguageModels.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
6060
} finally {
6161
this._pendingProgress.delete(requestId);
6262
}
63-
}
63+
},
64+
provideTokenCount: (str, token) => {
65+
return this._proxy.$provideTokenLength(handle, str, token);
66+
},
6467
}));
6568
if (metadata.auth) {
6669
dipsosables.add(this._registerAuthenticationProvider(metadata.extension, metadata.auth));
@@ -119,6 +122,11 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
119122
return task;
120123
}
121124

125+
126+
/**
 * Forwards a token-count request to the chat provider service and hands
 * its promise back unchanged.
 *
 * @param provider Identifier passed through to `computeTokenLength`.
 * @param value The text whose tokens should be counted.
 * @param token Cancellation token passed through to the service.
 * @returns The promise produced by `_chatProviderService.computeTokenLength`.
 */
$countTokens(provider: string, value: string, token: CancellationToken): Promise<number> {
	const service = this._chatProviderService;
	return service.computeTokenLength(provider, value, token);
}
129+
122130
private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
123131
// This needs to be done in both MainThread & ExtHost ChatProvider
124132
const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;

src/vs/workbench/api/common/extHost.protocol.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,13 +1199,16 @@ export interface MainThreadLanguageModelsShape extends IDisposable {
11991199

12001200
$prepareChatAccess(extension: ExtensionIdentifier, providerId: string, justification?: string): Promise<ILanguageModelChatMetadata | undefined>;
12011201
$fetchResponse(extension: ExtensionIdentifier, provider: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise<any>;
1202+
1203+
$countTokens(provider: string, value: string, token: CancellationToken): Promise<number>;
12021204
}
12031205

12041206
/**
 * Calls exposed by the extension-host side of the language model feature;
 * `ExtHostLanguageModels` implements this shape.
 */
export interface ExtHostLanguageModelsShape {
	$updateLanguageModels(data: { added?: ILanguageModelChatMetadata[]; removed?: string[] }): void;

	$updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void;

	$provideLanguageModelResponse(
		handle: number,
		requestId: number,
		from: ExtensionIdentifier,
		messages: IChatMessage[],
		options: { [name: string]: any },
		token: CancellationToken
	): Promise<any>;

	$handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise<void>;

	/**
	 * Resolves the token count of `value` using the provider registered
	 * under `handle`.
	 */
	$provideTokenLength(
		handle: number,
		value: string,
		token: CancellationToken
	): Promise<number>;
}
12101213

12111214
export interface IExtensionChatAgentMetadata extends Dto<IChatAgentMetadata> {

src/vs/workbench/api/common/extHostLanguageModels.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { createDecorator } from 'vs/platform/instantiation/common/instantiation'
2121
import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService';
2222
import { IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication';
2323
import { ILogService } from 'vs/platform/log/common/log';
24+
import { Iterable } from 'vs/base/common/iterator';
2425

2526
export interface IExtHostLanguageModels extends ExtHostLanguageModels { }
2627

@@ -180,6 +181,18 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
180181
return data.provider.provideLanguageModelResponse2(messages.map(typeConvert.LanguageModelMessage.to), options, ExtensionIdentifier.toKey(from), progress, token);
181182
}
182183

184+
185+
//#region --- token counting
186+
187+
/**
 * Computes the token count of `value` with the provider registered under
 * `handle`.
 *
 * @param handle Key used to look up the entry in `_languageModels`.
 * @param value Text to count tokens for.
 * @param token Cancellation token handed to the provider.
 * @returns A promise of the provider's count, or of `0` when no entry
 * exists for `handle`.
 */
$provideTokenLength(handle: number, value: string, token: CancellationToken): Promise<number> {
	const entry = this._languageModels.get(handle);
	// Promise.resolve normalizes both cases: the literal fallback and
	// whatever (possibly thenable) value the provider hands back.
	const count = entry ? entry.provider.provideTokenCount(value, token) : 0;
	return Promise.resolve(count);
}
194+
195+
183196
//#region --- making request
184197

185198
$updateLanguageModels(data: { added?: ILanguageModelChatMetadata[] | undefined; removed?: string[] | undefined }): void {
@@ -378,6 +391,23 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
378391
return undefined;
379392
}
380393
return list.has(data.extension);
394+
},
395+
/**
 * Computes the token length of `value` for the language model
 * `languageModelId`.
 *
 * A provider registered in this extension host is called directly;
 * otherwise the request goes through `$countTokens` on the proxy.
 *
 * @param languageModelId Id looked up in `_allLanguageModelData`.
 * @param value Text to count tokens for.
 * @param token Optional cancellation token; `CancellationToken.None` is
 * used when omitted.
 * @throws `LanguageModelError.NotFound` when the id is unknown.
 */
async computeTokenLength(languageModelId: string, value: string, token?: vscode.CancellationToken): Promise<number> {

	const cancellation = token ?? CancellationToken.None;

	const metadata = that._allLanguageModelData.get(languageModelId);
	if (!metadata) {
		throw LanguageModelError.NotFound(`Language model '${languageModelId}' is unknown.`);
	}

	const inProcess = Iterable.find(that._languageModels.values(), entry => entry.languageModelId === languageModelId);

	// a locally registered provider lets us stay inside the EH
	return inProcess
		? inProcess.provider.provideTokenCount(value, cancellation)
		: that._proxy.$countTokens(metadata.identifier, value, cancellation);
}
382412
};
383413
}

src/vs/workbench/contrib/chat/common/languageModels.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ export interface ILanguageModelChatMetadata {
4040
/**
 * A chat language model as registered with
 * `ILanguageModelsService.registerLanguageModelChat`.
 */
export interface ILanguageModelChat {
	metadata: ILanguageModelChatMetadata;

	provideChatResponse(
		messages: IChatMessage[],
		from: ExtensionIdentifier,
		options: { [name: string]: any },
		progress: IProgress<IChatResponseFragment>,
		token: CancellationToken
	): Promise<any>;

	/**
	 * Counts the tokens of `str`; `LanguageModelsService.computeTokenLength`
	 * delegates to this.
	 */
	provideTokenCount(
		str: string,
		token: CancellationToken
	): Promise<number>;
}
4445

4546
export const ILanguageModelsService = createDecorator<ILanguageModelsService>('ILanguageModelsService');
@@ -57,6 +58,8 @@ export interface ILanguageModelsService {
5758
registerLanguageModelChat(identifier: string, provider: ILanguageModelChat): IDisposable;
5859

5960
makeLanguageModelChatRequest(identifier: string, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, progress: IProgress<IChatResponseFragment>, token: CancellationToken): Promise<any>;
61+
62+
computeTokenLength(identifier: string, message: string, token: CancellationToken): Promise<number>;
6063
}
6164

6265
export class LanguageModelsService implements ILanguageModelsService {
@@ -100,4 +103,12 @@ export class LanguageModelsService implements ILanguageModelsService {
100103
}
101104
return provider.provideChatResponse(messages, from, options, progress, token);
102105
}
106+
107+
/**
 * Computes the token length of `message` using the provider registered
 * under `identifier`.
 *
 * The method is `async` so that the "not registered" failure surfaces as a
 * rejected promise rather than a synchronous throw from a function that
 * is declared to return a promise.
 *
 * @param identifier Key of the provider in `_providers`.
 * @param message Text to count tokens for.
 * @param token Cancellation token handed to the provider.
 * @returns The count reported by the provider's `provideTokenCount`.
 * @throws Rejects with an `Error` when no provider is registered for
 * `identifier`.
 */
async computeTokenLength(identifier: string, message: string, token: CancellationToken): Promise<number> {
	const provider = this._providers.get(identifier);
	if (!provider) {
		throw new Error(`Chat response provider with identifier ${identifier} is not registered.`);
	}
	return provider.provideTokenCount(message, token);
}
103114
}

src/vscode-dts/vscode.proposed.chatProvider.d.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ declare module 'vscode' {
1717
*/
1818
export interface ChatResponseProvider {
	provideLanguageModelResponse2(
		messages: LanguageModelChatMessage[],
		options: { [name: string]: any },
		extensionId: string,
		progress: Progress<ChatResponseFragment>,
		token: CancellationToken
	): Thenable<any>;

	/**
	 * Provide the token count for the given text.
	 *
	 * @param text The text to count tokens for.
	 * @param token A cancellation token.
	 */
	provideTokenCount(
		text: string,
		token: CancellationToken
	): Thenable<number>;
}
2123

2224
export interface ChatResponseProviderMetadata {

src/vscode-dts/vscode.proposed.languageModels.d.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,17 @@ declare module 'vscode' {
232232
// TODO@API SYNC or ASYNC?
233233
// TODO@API future
234234
// retrieveQuota(languageModelId: string): { remaining: number; resets: Date };
235+
236+
// TODO@API SHOULD THIS BE in vscode.lm?
237+
// TODO@API should this check for access/permissions?
238+
/**
239+
*
240+
* Compute the token length for the given text
241+
* @param languageModelId The identifier of the language model to compute the token length for.
242+
* @param text The text whose token length is computed.
243+
* @param token An optional cancellation token.
244+
*/
245+
computeTokenLength(languageModelId: string, text: string, token?: CancellationToken): Thenable<number>;
235246
}
236247

237248
export interface ExtensionContext {

0 commit comments

Comments
 (0)