@@ -7,19 +7,29 @@ import { RequestType } from '@vscode/copilot-api';
 import type { CancellationToken } from 'vscode';
 import { createRequestHMAC } from '../../../util/common/crypto';
 import { CallTracker, TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';
+import { Limiter } from '../../../util/vs/base/common/async';
 import { env } from '../../../util/vs/base/common/process';
 import { generateUuid } from '../../../util/vs/base/common/uuid';
 import { IAuthenticationService } from '../../authentication/common/authentication';
 import { getGithubMetadataHeaders } from '../../chunking/common/chunkingEndpointClientImpl';
 import { ICAPIClientService } from '../../endpoint/common/capiClient';
+import { IEndpointProvider } from '../../endpoint/common/endpointProvider';
 import { IEnvService } from '../../env/common/envService';
 import { logExecTime } from '../../log/common/logExecTime';
 import { ILogService } from '../../log/common/logService';
 import { IFetcherService } from '../../networking/common/fetcherService';
-import { postRequest } from '../../networking/common/networking';
+import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
 import { ITelemetryService } from '../../telemetry/common/telemetry';
-import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, Embeddings, IEmbeddingsComputer } from './embeddingsComputer';
+import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';
 
+interface CAPIEmbeddingResults {
+	readonly type: 'success';
+	readonly embeddings: EmbeddingVector[];
+}
+interface CAPIEmbeddingError {
+	readonly type: 'failed';
+	readonly reason: string;
+}
 
 export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 
@@ -34,6 +44,7 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 		@IFetcherService private readonly _fetcherService: IFetcherService,
 		@ILogService private readonly _logService: ILogService,
 		@ITelemetryService private readonly _telemetryService: ITelemetryService,
+		@IEndpointProvider private readonly _endpointProvider: IEndpointProvider,
 	) { }
 
 	public async computeEmbeddings(
@@ -44,6 +55,12 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 		cancellationToken?: CancellationToken,
 	): Promise<Embeddings> {
 		return logExecTime(this._logService, 'RemoteEmbeddingsComputer::computeEmbeddings', async () => {
+
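+			// Requests that ask for the CAPI endpoint are routed through the dedicated CAPI embeddings path below.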
+			if (options?.endpointType === 'capi') {
+				const embeddings = await this.computeCAPIEmbeddings(inputs, options, cancellationToken);
+				return embeddings ?? { type: embeddingType, values: [] };
+			}
+
 			const token = (await this._authService.getAnyGitHubSession({ silent: true }))?.accessToken;
 			if (!token) {
 				throw new Error('No authentication token available');
@@ -127,4 +144,161 @@ export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
 			return { type: embeddingType, values: embeddingsOut };
 		});
 	}
+
+	private async computeCAPIEmbeddings(
+		inputs: readonly string[],
+		options?: ComputeEmbeddingsOptions,
+		cancellationToken?: CancellationToken,
+	) {
+		const typeInfo = getWellKnownEmbeddingTypeInfo(EmbeddingType.text3small_512);
+		if (!typeInfo) {
+			throw new Error(`Embeddings type info not found: ${EmbeddingType.text3small_512}`);
+		}
+		const endpoint = await this._endpointProvider.getEmbeddingsEndpoint('text3small');
+		const batchSize = endpoint.maxBatchSize;
+		// OpenAI seems to allow one token fewer than the model's max per request, so if the max is 8192 we can only send 8191 tokens.
+		const maxTokens = endpoint.modelMaxPromptTokens - 1;
+		return this.fetchResponseWithBatches(typeInfo, endpoint, inputs, cancellationToken, maxTokens, batchSize);
+	}
+
+	/**
+	 * A recursive helper that drives the public `fetchResponse` function. It accepts a batch size and supports backing off the endpoint.
+	 * @param inputs The inputs to compute embeddings for
+	 * @param cancellationToken A cancellation token that allows the requests to be cancelled
+	 * @param batchSize The number of inputs to send per request
+	 * @returns The embeddings, or undefined if a batch could not be computed
+	 */
+	private async fetchResponseWithBatches(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined,
+		maxTokens: number,
+		batchSize: number,
+		parallelism = 1,
+	): Promise<Embeddings | undefined> {
+		// First loop through all inputs and count their token lengths; if any input exceeds maxTokens we fail
+		for (const input of inputs) {
+			const inputTokenLength = await endpoint.acquireTokenizer().tokenLength(input);
+			if (inputTokenLength > maxTokens) {
+				return undefined;
+			}
+		}
+
+		let embeddings: EmbeddingVector[] = [];
+		const promises: Promise<CAPIEmbeddingResults | undefined>[] = [];
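+		// The Limiter caps how many embedding requests are in flight at once; parallelism defaults to one request at a time.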
+		const limiter = new Limiter<CAPIEmbeddingResults | undefined>(parallelism);
+		try {
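+			// Slice the inputs into batches of at most batchSize and queue one request per batch.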
+			for (let i = 0; i < inputs.length; i += batchSize) {
+				const currentBatch = inputs.slice(i, i + batchSize);
+				promises.push(limiter.queue(async () => {
+					if (cancellationToken?.isCancellationRequested) {
+						return;
+					}
+
+					const r = await this.rawEmbeddingsFetchWithTelemetry(type, endpoint, generateUuid(), currentBatch, cancellationToken);
+					if (r.type === 'failed') {
+						throw new Error('Embeddings request failed ' + r.reason);
+					}
+					return r;
+				}));
+			}
+
+			embeddings = (await Promise.all(promises)).flatMap(response => response?.embeddings ?? []);
+		} catch (e) {
+			return undefined;
+		} finally {
+			limiter.dispose();
+		}
+
+		if (cancellationToken?.isCancellationRequested) {
+			return undefined;
+		}
+
+		// If there are no embeddings, return undefined
+		if (embeddings.length === 0) {
+			return undefined;
+		}
+		return { type: EmbeddingType.text3small_512, values: embeddings.map((value): Embedding => ({ type: EmbeddingType.text3small_512, value })) };
+	}
+
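+	/**
+	 * Wraps {@link rawEmbeddingsFetch} and reports success or failure telemetry for the request.
+	 */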
+	private async rawEmbeddingsFetchWithTelemetry(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		requestId: string,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined
+	) {
+		const startTime = Date.now();
+		const rawRequest = await this.rawEmbeddingsFetch(type, endpoint, requestId, inputs, cancellationToken);
+		if (rawRequest.type === 'failed') {
+			this._telemetryService.sendMSFTTelemetryErrorEvent('embedding.error', {
+				type: rawRequest.type,
+				reason: rawRequest.reason
+			});
+			return rawRequest;
+		}
+
+		const tokenizer = endpoint.acquireTokenizer();
+		const tokenCounts = await Promise.all(inputs.map(input => tokenizer.tokenLength(input)));
+		const inputTokenCount = tokenCounts.reduce((acc, count) => acc + count, 0);
+		this._telemetryService.sendMSFTTelemetryEvent('embedding.success', {}, {
+			batchSize: inputs.length,
+			inputTokenCount,
+			timeToComplete: Date.now() - startTime
+		});
+		return rawRequest;
+	}
+
+	/**
+	 * The function that actually makes the request to the API and handles failures.
+	 * It is kept separate from fetchResponse because fetchResponse manipulates the input and handles errors differently.
+	 */
+	public async rawEmbeddingsFetch(
+		type: EmbeddingTypeInfo,
+		endpoint: IEmbeddingsEndpoint,
+		requestId: string,
+		inputs: readonly string[],
+		cancellationToken: CancellationToken | undefined
+	): Promise<CAPIEmbeddingResults | CAPIEmbeddingError> {
+		try {
+			const token = await this._authService.getCopilotToken();
+
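+			// Build the request body from the batch of inputs plus the model and embedding dimensions for this type.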
+			const body = { input: inputs, model: type.model, dimensions: type.dimensions };
+			endpoint.interceptBody?.(body);
+			const response = await postRequest(
+				this._fetcherService,
+				this._telemetryService,
+				this._capiClientService,
+				endpoint,
+				token.token,
+				await createRequestHMAC(env.HMAC_SECRET),
+				'copilot-panel',
+				requestId,
+				body,
+				undefined,
+				cancellationToken
+			);
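+			// A 200 response is parsed as JSON; anything else is read as text so the error can be reported.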
+			const jsonResponse = response.status === 200 ? await response.json() : await response.text();
+
+			type EmbeddingResponse = {
+				object: string;
+				index: number;
+				embedding: number[];
+			};
+			if (response.status === 200 && jsonResponse.data) {
+				return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
+			} else {
+				return { type: 'failed', reason: jsonResponse.error };
+			}
+		} catch (e) {
+			let errorMessage = (e as Error)?.message ?? 'Unknown error';
+			// Timeouts show up as JSON parse errors because the response body is incomplete
+			if (errorMessage.match(/Unexpected.*JSON/i)) {
+				errorMessage = 'timeout';
+			}
+			return { type: 'failed', reason: errorMessage };
+
+		}
+	}
 }