
Commit 506cce1

bhavyaus and Copilot authored
Add CAPI text-3-small endpoint support for embeddings (#1037)
* Add CAPI text-3-small endpoint support for embeddings
* Update src/platform/endpoint/common/endpointProvider.ts

Co-authored-by: Copilot <[email protected]>

* Remove todo comment
* Update methods
* Remove unused options parameter from fetchResponseWithBatches call

---------

Co-authored-by: Copilot <[email protected]>
1 parent 014f797 commit 506cce1


15 files changed: +328 -15 lines changed


src/extension/context/node/resolvers/extensionApi.tsx

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ export class VSCodeAPIContextElement extends PromptElement<VSCodeAPIContextProps
 			return [];
 		}
 
-		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [this.props.query], {}, new TelemetryCorrelationId('VSCodeAPIContextElement::getSnippets'), token);
+		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [this.props.query], { endpointType: 'capi' }, new TelemetryCorrelationId('VSCodeAPIContextElement::getSnippets'), token);
 		return this.apiEmbeddingsIndex.nClosestValues(embeddingResult.values[0], 5);
 	}

src/extension/prompt/vscode-node/endpointProviderImpl.ts

Lines changed: 22 additions & 2 deletions
@@ -9,15 +9,16 @@ import { ConfigKey, IConfigurationService } from '../../../platform/configuratio
 import { AutoChatEndpoint } from '../../../platform/endpoint/common/autoChatEndpoint';
 import { IAutomodeService } from '../../../platform/endpoint/common/automodeService';
 import { ICAPIClientService } from '../../../platform/endpoint/common/capiClient';
-import { ChatEndpointFamily, IChatModelInformation, ICompletionModelInformation, IEndpointProvider } from '../../../platform/endpoint/common/endpointProvider';
+import { ChatEndpointFamily, EmbeddingsEndpointFamily, IChatModelInformation, ICompletionModelInformation, IEmbeddingModelInformation, IEndpointProvider } from '../../../platform/endpoint/common/endpointProvider';
 import { CopilotChatEndpoint } from '../../../platform/endpoint/node/copilotChatEndpoint';
+import { EmbeddingEndpoint } from '../../../platform/endpoint/node/embeddingsEndpoint';
 import { IModelMetadataFetcher, ModelMetadataFetcher } from '../../../platform/endpoint/node/modelMetadataFetcher';
 import { applyExperimentModifications, ExperimentConfig, getCustomDefaultModelExperimentConfig, ProxyExperimentEndpoint } from '../../../platform/endpoint/node/proxyExperimentEndpoint';
 import { ExtensionContributedChatEndpoint } from '../../../platform/endpoint/vscode-node/extChatEndpoint';
 import { IEnvService } from '../../../platform/env/common/envService';
 import { ILogService } from '../../../platform/log/common/logService';
 import { IFetcherService } from '../../../platform/networking/common/fetcherService';
-import { IChatEndpoint } from '../../../platform/networking/common/networking';
+import { IChatEndpoint, IEmbeddingsEndpoint } from '../../../platform/networking/common/networking';
 import { IRequestLogger } from '../../../platform/requestLogger/node/requestLogger';
 import { IExperimentationService } from '../../../platform/telemetry/common/nullExperimentationService';
 import { ITelemetryService } from '../../../platform/telemetry/common/telemetry';
@@ -30,6 +31,7 @@ export class ProductionEndpointProvider implements IEndpointProvider {
 	declare readonly _serviceBrand: undefined;
 
 	private _chatEndpoints: Map<string, IChatEndpoint> = new Map();
+	private _embeddingEndpoints: Map<string, IEmbeddingsEndpoint> = new Map();
 	private readonly _modelFetcher: IModelMetadataFetcher;
 
 	constructor(
@@ -144,6 +146,24 @@ export class ProductionEndpointProvider implements IEndpointProvider {
 		return endpoint;
 	}
 
+	async getEmbeddingsEndpoint(family?: EmbeddingsEndpointFamily): Promise<IEmbeddingsEndpoint> {
+		this._logService.trace(`Resolving embedding model`);
+		const modelMetadata = await this._modelFetcher.getEmbeddingsModel('text-embedding-3-small');
+		const model = await this.getOrCreateEmbeddingEndpointInstance(modelMetadata);
+		this._logService.trace(`Resolved embedding model`);
+		return model;
+	}
+
+	private async getOrCreateEmbeddingEndpointInstance(modelMetadata: IEmbeddingModelInformation): Promise<IEmbeddingsEndpoint> {
+		const modelId = 'text-embedding-3-small';
+		let embeddingEndpoint = this._embeddingEndpoints.get(modelId);
+		if (!embeddingEndpoint) {
+			embeddingEndpoint = this._instantiationService.createInstance(EmbeddingEndpoint, modelMetadata);
+			this._embeddingEndpoints.set(modelId, embeddingEndpoint);
+		}
+		return embeddingEndpoint;
+	}
+
 	async getAllCompletionModels(forceRefresh?: boolean): Promise<ICompletionModelInformation[]> {
 		return this._modelFetcher.getAllCompletionModels(forceRefresh ?? false);
 	}
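The provider resolves the CAPI embeddings model lazily and caches one `EmbeddingEndpoint` per model id, so repeated lookups reuse the same instance. Below is a minimal consumer sketch under that assumption; the consumer class itself is hypothetical, and only `getEmbeddingsEndpoint` and the endpoint properties it reads come from this change.

```ts
import { IEndpointProvider } from '../../../platform/endpoint/common/endpointProvider';

// Hypothetical consumer: resolves the embeddings endpoint once per model id and
// reads the limits that the batching logic in RemoteEmbeddingsComputer relies on.
class EmbeddingsEndpointInspector {
	constructor(@IEndpointProvider private readonly endpointProvider: IEndpointProvider) { }

	async describe(): Promise<string> {
		// Repeated calls return the cached EmbeddingEndpoint for 'text-embedding-3-small'.
		const endpoint = await this.endpointProvider.getEmbeddingsEndpoint('text3small');
		return `batch=${endpoint.maxBatchSize}, maxPromptTokens=${endpoint.modelMaxPromptTokens}`;
	}
}
```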

src/extension/prompt/vscode-node/settingsEditorSearchServiceImpl.ts

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ export class SettingsEditorSearchServiceImpl implements ISettingsEditorSearchSer
 
 		let embeddingResult: Embeddings;
 		try {
-			embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [query], {}, new TelemetryCorrelationId('SettingsEditorSearchServiceImpl::provideSettingsSearchResults'), token);
+			embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [query], { endpointType: 'capi' }, new TelemetryCorrelationId('SettingsEditorSearchServiceImpl::provideSettingsSearchResults'), token);
 		} catch {
 			if (token.isCancellationRequested) {
 				progress.report(canceledBundle);

src/extension/prompts/node/panel/newWorkspace/newWorkspace.tsx

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ export class NewWorkspacePrompt extends PromptElement<NewWorkspacePromptProps, N
 		}
 		else if (instruction.intent === 'Project') {
 			if (this.props.useTemplates) {
-				const result = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [instruction.question], {}, undefined);
+				const result = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [instruction.question], { endpointType: 'capi' }, undefined);
 				progress.report(new ChatResponseProgressPart(l10n.t('Searching project template index...')));
 				const similarProjects = await this.projectTemplatesIndex.nClosestValues(result.values[0], 1);
 				if (similarProjects.length > 0) {

src/extension/prompts/node/panel/vscode.tsx

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ export class VscodePrompt extends PromptElement<VscodePromptProps, VscodePromptS
 			return { settings: [], commands: [], query: userQuery };
 		}
 
-		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [userQuery], {}, undefined);
+		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [userQuery], { endpointType: 'capi' }, undefined);
 		if (token.isCancellationRequested) {
 			return { settings: [], commands: [], query: userQuery };
 		}

src/extension/test/vscode-node/endpoints.test.ts

Lines changed: 18 additions & 1 deletion
@@ -7,7 +7,7 @@ import assert from 'assert';
 import { SinonSandbox, createSandbox } from 'sinon';
 import { LanguageModelChat } from 'vscode';
 import { CHAT_MODEL } from '../../../platform/configuration/common/configurationService';
-import { IChatModelInformation, ICompletionModelInformation } from '../../../platform/endpoint/common/endpointProvider';
+import { IChatModelInformation, ICompletionModelInformation, IEmbeddingModelInformation } from '../../../platform/endpoint/common/endpointProvider';
 import { IModelMetadataFetcher } from '../../../platform/endpoint/node/modelMetadataFetcher';
 import { ITestingServicesAccessor } from '../../../platform/test/node/services';
 import { TokenizerType } from '../../../util/common/tokenizer';
@@ -43,6 +43,23 @@ class FakeModelMetadataFetcher implements IModelMetadataFetcher {
 			}
 		};
 	}
+
+	async getEmbeddingsModel(): Promise<IEmbeddingModelInformation> {
+		return {
+			id: 'text-embedding-3-small',
+			name: 'fake-name',
+			version: 'fake-version',
+			model_picker_enabled: false,
+			is_chat_default: false,
+			is_chat_fallback: false,
+			capabilities: {
+				type: 'embeddings',
+				tokenizer: TokenizerType.O200K,
+				family: 'text-embedding-3-small',
+				limits: { max_inputs: 256 }
+			}
+		};
+	}
 }
 
 suite('Endpoint Class Test', function () {
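A hedged sketch of a test that could sit in this suite, assuming `FakeModelMetadataFetcher` can be constructed without arguments (its constructor is not shown in this diff); only the metadata shape comes from the hunk above.

```ts
test('fake metadata fetcher exposes the text-embedding-3-small model', async function () {
	// Assumption: the fake requires no constructor arguments.
	const fetcher = new FakeModelMetadataFetcher();
	const model = await fetcher.getEmbeddingsModel();
	assert.strictEqual(model.id, 'text-embedding-3-small');
	assert.strictEqual(model.capabilities.family, 'text-embedding-3-small');
	assert.strictEqual(model.capabilities.limits?.max_inputs, 256);
});
```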

src/platform/embeddings/common/embeddingsComputer.ts

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ export const IEmbeddingsComputer = createServiceIdentifier<IEmbeddingsComputer>(
 
 export type ComputeEmbeddingsOptions = {
 	readonly inputType?: 'document' | 'query';
+	readonly endpointType?: 'capi' | 'github';
 };
 
 export interface IEmbeddingsComputer {
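The new `endpointType` option is how call sites opt into the CAPI path; omitting it (or passing `{}`) keeps the existing GitHub embeddings behavior, as the routing check in `RemoteEmbeddingsComputer.computeEmbeddings` below shows. A minimal call-site sketch; `computer`, `index`, `query`, and `token` are assumed bindings in the caller, not part of this change.

```ts
// Route this request through the CAPI text-embedding-3-small endpoint.
// Leaving endpointType out (or passing {}) keeps the default GitHub path.
const result = await computer.computeEmbeddings(
	EmbeddingType.text3small_512,
	[query],
	{ endpointType: 'capi' },
	new TelemetryCorrelationId('Example::computeEmbeddings'), // hypothetical call site
	token
);
// Rank stored embeddings against the query, as the updated call sites above do.
const closest = index.nClosestValues(result.values[0], 5);
```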

src/platform/embeddings/common/remoteEmbeddingsComputer.ts

Lines changed: 176 additions & 2 deletions
@@ -7,19 +7,29 @@ import { RequestType } from '@vscode/copilot-api';
 import type { CancellationToken } from 'vscode';
 import { createRequestHMAC } from '../../../util/common/crypto';
 import { CallTracker, TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';
+import { Limiter } from '../../../util/vs/base/common/async';
 import { env } from '../../../util/vs/base/common/process';
 import { generateUuid } from '../../../util/vs/base/common/uuid';
 import { IAuthenticationService } from '../../authentication/common/authentication';
 import { getGithubMetadataHeaders } from '../../chunking/common/chunkingEndpointClientImpl';
 import { ICAPIClientService } from '../../endpoint/common/capiClient';
+import { IEndpointProvider } from '../../endpoint/common/endpointProvider';
 import { IEnvService } from '../../env/common/envService';
 import { logExecTime } from '../../log/common/logExecTime';
 import { ILogService } from '../../log/common/logService';
 import { IFetcherService } from '../../networking/common/fetcherService';
-import { postRequest } from '../../networking/common/networking';
+import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
 import { ITelemetryService } from '../../telemetry/common/telemetry';
-import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, Embeddings, IEmbeddingsComputer } from './embeddingsComputer';
+import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';
 
+interface CAPIEmbeddingResults {
+	readonly type: 'success';
+	readonly embeddings: EmbeddingVector[];
+}
+interface CAPIEmbeddingError {
+	readonly type: 'failed';
+	readonly reason: string;
+}
 
 export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 
@@ -34,6 +44,7 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 		@IFetcherService private readonly _fetcherService: IFetcherService,
 		@ILogService private readonly _logService: ILogService,
 		@ITelemetryService private readonly _telemetryService: ITelemetryService,
+		@IEndpointProvider private readonly _endpointProvider: IEndpointProvider,
 	) { }
 
 	public async computeEmbeddings(
@@ -44,6 +55,12 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 		cancellationToken?: CancellationToken,
 	): Promise<Embeddings> {
 		return logExecTime(this._logService, 'RemoteEmbeddingsComputer::computeEmbeddings', async () => {
+
+			if (options?.endpointType === 'capi') {
+				const embeddings = await this.computeCAPIEmbeddings(inputs, options, cancellationToken);
+				return embeddings ?? { type: embeddingType, values: [] };
+			}
+
 			const token = (await this._authService.getAnyGitHubSession({ silent: true }))?.accessToken;
 			if (!token) {
 				throw new Error('No authentication token available');
@@ -127,4 +144,161 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 			return { type: embeddingType, values: embeddingsOut };
 		});
 	}
+
+	private async computeCAPIEmbeddings(
+		inputs: readonly string[],
+		options?: ComputeEmbeddingsOptions,
+		cancellationToken?: CancellationToken,
+	) {
+		const typeInfo = getWellKnownEmbeddingTypeInfo(EmbeddingType.text3small_512);
+		if (!typeInfo) {
+			throw new Error(`Embeddings type info not found: ${EmbeddingType.text3small_512}`);
+		}
+		const endpoint = await this._endpointProvider.getEmbeddingsEndpoint('text3small');
+		const batchSize = endpoint.maxBatchSize;
+		// Open AI seems to allow 1 less than max tokens for the model requests. So if the max tokens is 8192, we can only send 8191 tokens.
+		const maxTokens = endpoint.modelMaxPromptTokens - 1;
+		return this.fetchResponseWithBatches(typeInfo, endpoint, inputs, cancellationToken, maxTokens, batchSize);
+	}
+
+	/**
+	 * A recursive helper that drives the public `fetchResponse` function. This allows accepting a batch and supports backing off the endpoint.
+	 * @param inputs The inputs to get embeddings for
+	 * @param cancellationToken A cancellation token to allow cancelling the requests
+	 * @param batchSize The batch size to calculate
+	 * @returns The embeddings
+	 */
+	private async fetchResponseWithBatches(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined,
+		maxTokens: number,
+		batchSize: number,
+		parallelism = 1,
+	): Promise<Embeddings | undefined> {
+		// First we loop through all inputs and count their token length, if one exceeds max tokens then we fail
+		for (const input of inputs) {
+			const inputTokenLength = await endpoint.acquireTokenizer().tokenLength(input);
+			if (inputTokenLength > maxTokens) {
+				return undefined;
+			}
+		}
+
+		let embeddings: EmbeddingVector[] = [];
+		const promises: Promise<CAPIEmbeddingResults | undefined>[] = [];
+		const limiter = new Limiter<CAPIEmbeddingResults | undefined>(parallelism);
+		try {
+			for (let i = 0; i < inputs.length; i += batchSize) {
+				const currentBatch = inputs.slice(i, i + batchSize);
+				promises.push(limiter.queue(async () => {
+					if (cancellationToken?.isCancellationRequested) {
+						return;
+					}
+
+					const r = await this.rawEmbeddingsFetchWithTelemetry(type, endpoint, generateUuid(), currentBatch, cancellationToken);
+					if (r.type === 'failed') {
+						throw new Error('Embeddings request failed ' + r.reason);
+					}
+					return r;
+				}));
+			}
+
+			embeddings = (await Promise.all(promises)).flatMap(response => response?.embeddings ?? []);
+		} catch (e) {
+			return undefined;
+		} finally {
+			limiter.dispose();
+		}
+
+		if (cancellationToken?.isCancellationRequested) {
+			return undefined;
+		}
+
+		// If there are no embeddings, return undefined
+		if (embeddings.length === 0) {
+			return undefined;
+		}
+		return { type: EmbeddingType.text3small_512, values: embeddings.map((value): Embedding => ({ type: EmbeddingType.text3small_512, value })) };
+	}
+
+	private async rawEmbeddingsFetchWithTelemetry(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		requestId: string,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined
+	) {
+		const startTime = Date.now();
+		const rawRequest = await this.rawEmbeddingsFetch(type, endpoint, requestId, inputs, cancellationToken);
+		if (rawRequest.type === 'failed') {
+			this._telemetryService.sendMSFTTelemetryErrorEvent('embedding.error', {
+				type: rawRequest.type,
+				reason: rawRequest.reason
+			});
+			return rawRequest;
+		}
+
+		const tokenizer = endpoint.acquireTokenizer();
+		const tokenCounts = await Promise.all(inputs.map(input => tokenizer.tokenLength(input)));
+		const inputTokenCount = tokenCounts.reduce((acc, count) => acc + count, 0);
+		this._telemetryService.sendMSFTTelemetryEvent('embedding.success', {}, {
+			batchSize: inputs.length,
+			inputTokenCount,
+			timeToComplete: Date.now() - startTime
+		});
+		return rawRequest;
+	}
+
+	/**
+	 * The function which actually makes the request to the API and handles failures.
+	 * This is separated out from fetchResponse as fetchResponse does some manipulation to the input and handles errors differently
+	 */
+	public async rawEmbeddingsFetch(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		requestId: string,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined
+	): Promise<CAPIEmbeddingResults | CAPIEmbeddingError> {
+		try {
+			const token = await this._authService.getCopilotToken();
+
+			const body = { input: inputs, model: type.model, dimensions: type.dimensions };
+			endpoint.interceptBody?.(body);
+			const response = await postRequest(
+				this._fetcherService,
+				this._telemetryService,
+				this._capiClientService,
+				endpoint,
+				token.token,
+				await createRequestHMAC(env.HMAC_SECRET),
+				'copilot-panel',
+				requestId,
+				body,
+				undefined,
+				cancellationToken
+			);
+			const jsonResponse = response.status === 200 ? await response.json() : await response.text();

			type EmbeddingResponse = {
				object: string;
				index: number;
				embedding: number[];
			};
			if (response.status === 200 && jsonResponse.data) {
				return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
			} else {
				return { type: 'failed', reason: jsonResponse.error };
			}
		} catch (e) {
			let errorMessage = (e as Error)?.message ?? 'Unknown error';
			// Timeouts = JSON parse errors because the response is incomplete
			if (errorMessage.match(/Unexpected.*JSON/i)) {
				errorMessage = 'timeout';
			}
			return { type: 'failed', reason: errorMessage };

		}
	}
 }
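`fetchResponseWithBatches` slices the inputs into `maxBatchSize` chunks and pushes each chunk through a `Limiter` with a default parallelism of 1, failing the whole computation if any single input exceeds the token budget or any batch request fails. A standalone sketch of that slicing pattern, independent of the repo's `Limiter` utility; all names here are illustrative.

```ts
// Illustrative batching sketch: slice inputs into fixed-size batches and fetch
// them sequentially (parallelism = 1), mirroring fetchResponseWithBatches above.
async function embedInBatches(
	inputs: readonly string[],
	batchSize: number,
	fetchBatch: (batch: readonly string[]) => Promise<number[][]>
): Promise<number[][] | undefined> {
	const out: number[][] = [];
	for (let i = 0; i < inputs.length; i += batchSize) {
		const batch = inputs.slice(i, i + batchSize);
		try {
			// Each batch becomes one POST to the embeddings endpoint; in the real
			// code a failed batch makes the whole computation return undefined.
			out.push(...await fetchBatch(batch));
		} catch {
			return undefined;
		}
	}
	return out.length > 0 ? out : undefined;
}
```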

src/platform/embeddings/common/vscodeIndex.ts

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ abstract class RelatedInformationProviderEmbeddingsIndex<V extends { key: string
 			return [];
 		}
 		const startOfEmbeddingRequest = Date.now();
-		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [query], {}, new TelemetryCorrelationId('RelatedInformationProviderEmbeddingsIndex::provideRelatedInformation'), token);
+		const embeddingResult = await this.embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [query], { endpointType: 'capi' }, new TelemetryCorrelationId('RelatedInformationProviderEmbeddingsIndex::provideRelatedInformation'), token);
 		this._logService.debug(`Related Information: Remote similarly request took ${Date.now() - startOfEmbeddingRequest}ms`);
 		if (token.isCancellationRequested) {
 			// return an array of 0s the same length as comparisons
