diff --git a/src/extension/extension/vscode-node/services.ts b/src/extension/extension/vscode-node/services.ts index 04fab1d78..cf22a331d 100644 --- a/src/extension/extension/vscode-node/services.ts +++ b/src/extension/extension/vscode-node/services.ts @@ -103,6 +103,7 @@ import { IWorkspaceListenerService } from '../../workspaceRecorder/common/worksp import { WorkspacListenerService } from '../../workspaceRecorder/vscode-node/workspaceListenerService'; import { registerServices as registerCommonServices } from '../vscode/services'; import { NativeEnvServiceImpl } from '../../../platform/env/vscode-node/nativeEnvServiceImpl'; +import { GithubAvailableEmbeddingTypesService, IGithubAvailableEmbeddingTypesService } from '../../../platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes'; // ########################################################################################### // ### ### @@ -195,6 +196,7 @@ export function registerServices(builder: IInstantiationServiceBuilder, extensio builder.define(IWorkspaceListenerService, new SyncDescriptor(WorkspacListenerService)); builder.define(ICodeSearchAuthenticationService, new SyncDescriptor(VsCodeCodeSearchAuthenticationService)); builder.define(ITodoListContextProvider, new SyncDescriptor(TodoListContextProvider)); + builder.define(IGithubAvailableEmbeddingTypesService, new SyncDescriptor(GithubAvailableEmbeddingTypesService)); } function setupMSFTExperimentationService(builder: IInstantiationServiceBuilder, extensionContext: ExtensionContext) { diff --git a/src/extension/test/node/services.ts b/src/extension/test/node/services.ts index 80ba93c70..0f1fc9be5 100644 --- a/src/extension/test/node/services.ts +++ b/src/extension/test/node/services.ts @@ -30,6 +30,7 @@ import { SimulationAlternativeNotebookContentService, SimulationNotebookService, import { NullTestProvider } from '../../../platform/testing/common/nullTestProvider'; import { TestLogService } from '../../../platform/testing/common/testLogService'; import { ITestProvider } from '../../../platform/testing/common/testProvider'; +import { IGithubAvailableEmbeddingTypesService, MockGithubAvailableEmbeddingTypesService } from '../../../platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes'; import { IWorkspaceChunkSearchService, NullWorkspaceChunkSearchService } from '../../../platform/workspaceChunkSearch/node/workspaceChunkSearchService'; import { DisposableStore } from '../../../util/vs/base/common/lifecycle'; import { SyncDescriptor } from '../../../util/vs/platform/instantiation/common/descriptors'; @@ -102,5 +103,6 @@ export function createExtensionUnitTestingServices(disposables: Pick { public static readonly toolName = ToolName.GithubRepo; - private readonly _availableEmbeddingTypesManager = new Lazy(() => this._instantiationService.createInstance(GithubAvailableEmbeddingTypesManager)); constructor( @IRunCommandExecutionService _commandService: IRunCommandExecutionService, @IInstantiationService private readonly _instantiationService: IInstantiationService, @IGithubCodeSearchService private readonly _githubCodeSearch: IGithubCodeSearchService, + @IGithubAvailableEmbeddingTypesService private readonly _availableEmbeddingTypesManager: GithubAvailableEmbeddingTypesService, @ITelemetryService private readonly _telemetryService: ITelemetryService, ) { } @@ -57,7 +56,7 @@ export class GithubRepoTool implements ICopilotTool { throw new Error('Invalid input. Could not parse repo'); } - const embeddingType = await this._availableEmbeddingTypesManager.value.getPreferredType(false); + const embeddingType = await this._availableEmbeddingTypesManager.getPreferredType(false); if (!embeddingType) { throw new Error('No embedding models available'); } diff --git a/src/platform/embeddings/common/embeddingsComputer.ts b/src/platform/embeddings/common/embeddingsComputer.ts index 07c448e39..ca428554c 100644 --- a/src/platform/embeddings/common/embeddingsComputer.ts +++ b/src/platform/embeddings/common/embeddingsComputer.ts @@ -89,8 +89,10 @@ export interface EmbeddingDistance { export const IEmbeddingsComputer = createServiceIdentifier('IEmbeddingsComputer'); +export type EmbeddingInputType = 'document' | 'query'; + export type ComputeEmbeddingsOptions = { - readonly inputType?: 'document' | 'query'; + readonly inputType?: EmbeddingInputType; }; export interface IEmbeddingsComputer { diff --git a/src/platform/urlChunkSearch/node/urlChunkEmbeddingsIndex.ts b/src/platform/urlChunkSearch/node/urlChunkEmbeddingsIndex.ts index 5def67076..ecaef5cde 100644 --- a/src/platform/urlChunkSearch/node/urlChunkEmbeddingsIndex.ts +++ b/src/platform/urlChunkSearch/node/urlChunkEmbeddingsIndex.ts @@ -13,8 +13,9 @@ import { URI } from '../../../util/vs/base/common/uri'; import { IAuthenticationService } from '../../authentication/common/authentication'; import { FileChunkAndScore, FileChunkWithEmbedding } from '../../chunking/common/chunk'; import { ChunkableContent, ComputeBatchInfo, EmbeddingsComputeQos, IChunkingEndpointClient } from '../../chunking/common/chunkingEndpointClient'; -import { distance, Embedding, EmbeddingType, IEmbeddingsComputer } from '../../embeddings/common/embeddingsComputer'; +import { distance, Embedding, EmbeddingInputType, EmbeddingType, IEmbeddingsComputer } from '../../embeddings/common/embeddingsComputer'; import { ILogService } from '../../log/common/logService'; +import { IGithubAvailableEmbeddingTypesService } from '../../workspaceChunkSearch/common/githubAvailableEmbeddingTypes'; /** * The maximum content length to sent to the chunking endpoint. @@ -51,6 +52,7 @@ export class UrlChunkEmbeddingsIndex extends Disposable { @ILogService private readonly _logService: ILogService, @IEmbeddingsComputer private readonly _embeddingsComputer: IEmbeddingsComputer, @IChunkingEndpointClient private readonly _chunkingEndpointClient: IChunkingEndpointClient, + @IGithubAvailableEmbeddingTypesService private readonly _availableEmbeddingTypesService: IGithubAvailableEmbeddingTypesService, ) { super(); } @@ -60,20 +62,25 @@ export class UrlChunkEmbeddingsIndex extends Disposable { query: string, token: CancellationToken, ): Promise { + const embeddingType = await raceCancellationError(this._availableEmbeddingTypesService.getPreferredType(/*silent*/ false), token); + if (!embeddingType) { + throw new Error('No embedding types available'); + } + const [queryEmbedding, fileChunksAndEmbeddings] = await raceCancellationError(Promise.all([ - this.computeEmbeddings(query, token), - this.getEmbeddingsForFiles(files.map(file => new UrlContent(file.uri, file.content)), EmbeddingsComputeQos.Batch, token) + this.computeEmbeddings(embeddingType, query, 'query', token), + this.getEmbeddingsForFiles(embeddingType, files.map(file => new UrlContent(file.uri, file.content)), EmbeddingsComputeQos.Batch, token) ]), token); return this.computeChunkScores(fileChunksAndEmbeddings, queryEmbedding); } - private async computeEmbeddings(str: string, token: CancellationToken): Promise { - const embeddings = await this._embeddingsComputer.computeEmbeddings(EmbeddingType.text3small_512, [str], {}, new TelemetryCorrelationId('UrlChunkEmbeddingsIndex::computeEmbeddings'), token); + private async computeEmbeddings(embeddingType: EmbeddingType, str: string, inputType: EmbeddingInputType, token: CancellationToken): Promise { + const embeddings = await this._embeddingsComputer.computeEmbeddings(embeddingType, [str], { inputType }, new TelemetryCorrelationId('UrlChunkEmbeddingsIndex::computeEmbeddings'), token); return embeddings.values[0]; } - private async getEmbeddingsForFiles(files: readonly UrlContent[], qos: EmbeddingsComputeQos, token: CancellationToken): Promise<(readonly FileChunkWithEmbedding[])[]> { + private async getEmbeddingsForFiles(embeddingType: EmbeddingType, files: readonly UrlContent[], qos: EmbeddingsComputeQos, token: CancellationToken): Promise<(readonly FileChunkWithEmbedding[])[]> { if (!files.length) { return []; } @@ -88,7 +95,7 @@ export class UrlChunkEmbeddingsIndex extends Disposable { } const result = await Promise.all(files.map(async file => { - const result = await this.getChunksAndEmbeddings(authToken, file, batchInfo, qos, token); + const result = await this.getChunksAndEmbeddings(authToken, embeddingType, file, batchInfo, qos, token); if (!result) { return []; } @@ -107,13 +114,13 @@ export class UrlChunkEmbeddingsIndex extends Disposable { ); } - private async getChunksAndEmbeddings(authToken: string, content: UrlContent, batchInfo: ComputeBatchInfo, qos: EmbeddingsComputeQos, token: CancellationToken): Promise { + private async getChunksAndEmbeddings(authToken: string, embeddingType: EmbeddingType, content: UrlContent, batchInfo: ComputeBatchInfo, qos: EmbeddingsComputeQos, token: CancellationToken): Promise { const existing = await raceCancellationError(this._cache.get(content), token); if (existing) { return existing; } - const chunksAndEmbeddings = await raceCancellationError(this._chunkingEndpointClient.computeChunksAndEmbeddings(authToken, EmbeddingType.text3small_512, content, batchInfo, qos, new Map(), new CallTracker('UrlChunkEmbeddingsIndex::getChunksAndEmbeddings'), token), token); + const chunksAndEmbeddings = await raceCancellationError(this._chunkingEndpointClient.computeChunksAndEmbeddings(authToken, embeddingType, content, batchInfo, qos, new Map(), new CallTracker('UrlChunkEmbeddingsIndex::getChunksAndEmbeddings'), token), token); if (chunksAndEmbeddings) { this._cache.set(content, chunksAndEmbeddings); } diff --git a/src/platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes.ts b/src/platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes.ts index 9ba99a057..d73cf1f9b 100644 --- a/src/platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes.ts +++ b/src/platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes.ts @@ -6,6 +6,7 @@ import { RequestType } from '@vscode/copilot-api'; import { createRequestHMAC } from '../../../util/common/crypto'; import { Result } from '../../../util/common/result'; +import { createServiceIdentifier } from '../../../util/common/services'; import { CallTracker } from '../../../util/common/telemetryCorrelationId'; import { env } from '../../../util/vs/base/common/process'; import { generateUuid } from '../../../util/vs/base/common/uuid'; @@ -35,7 +36,22 @@ type GetAvailableTypesError = type GetAvailableTypesResult = Result; -export class GithubAvailableEmbeddingTypesManager { +export const IGithubAvailableEmbeddingTypesService = createServiceIdentifier('IGithubAvailableEmbeddingTypesService'); + +export interface IGithubAvailableEmbeddingTypesService { + readonly _serviceBrand: undefined; + + /** + * Gets the preferred embedding type based on available types and user configuration. + * @param silent Whether to silently handle authentication errors + * @returns The preferred embedding type or undefined if none available + */ + getPreferredType(silent: boolean): Promise; +} + +export class GithubAvailableEmbeddingTypesService implements IGithubAvailableEmbeddingTypesService { + + readonly _serviceBrand: undefined; private _cached?: Promise; @@ -213,3 +229,12 @@ export class GithubAvailableEmbeddingTypesManager { return all.primary.at(0) ?? all.deprecated.at(0); } } + + +export class MockGithubAvailableEmbeddingTypesService implements IGithubAvailableEmbeddingTypesService { + declare readonly _serviceBrand: undefined; + + async getPreferredType(_silent: boolean): Promise { + return EmbeddingType.metis_1024_I16_Binary; + } +} diff --git a/src/platform/workspaceChunkSearch/node/workspaceChunkSearchService.ts b/src/platform/workspaceChunkSearch/node/workspaceChunkSearchService.ts index 145ccae34..1986fe598 100644 --- a/src/platform/workspaceChunkSearch/node/workspaceChunkSearchService.ts +++ b/src/platform/workspaceChunkSearch/node/workspaceChunkSearchService.ts @@ -34,7 +34,7 @@ import { ISimulationTestContext } from '../../simulationTestContext/common/simul import { IExperimentationService } from '../../telemetry/common/nullExperimentationService'; import { ITelemetryService } from '../../telemetry/common/telemetry'; import { getWorkspaceFileDisplayPath, IWorkspaceService } from '../../workspace/common/workspaceService'; -import { GithubAvailableEmbeddingTypesManager } from '../common/githubAvailableEmbeddingTypes'; +import { IGithubAvailableEmbeddingTypesService } from '../common/githubAvailableEmbeddingTypes'; import { IWorkspaceChunkSearchStrategy, StrategySearchResult, StrategySearchSizing, WorkspaceChunkQuery, WorkspaceChunkQueryWithEmbeddings, WorkspaceChunkSearchOptions, WorkspaceChunkSearchStrategyId, WorkspaceSearchAlert } from '../common/workspaceChunkSearch'; import { CodeSearchChunkSearch, CodeSearchRemoteIndexState } from './codeSearchChunkSearch'; import { EmbeddingsChunkSearch, LocalEmbeddingsIndexState, LocalEmbeddingsIndexStatus } from './embeddingsChunkSearch'; @@ -115,17 +115,15 @@ export class WorkspaceChunkSearchService extends Disposable implements IWorkspac readonly onDidChangeIndexState = this._onDidChangeIndexState.event; private _impl: WorkspaceChunkSearchServiceImpl | undefined; - private readonly _availableEmbeddingTypes: GithubAvailableEmbeddingTypesManager; constructor( @IInstantiationService private readonly _instantiationService: IInstantiationService, @IAuthenticationService private readonly _authenticationService: IAuthenticationService, + @IGithubAvailableEmbeddingTypesService private readonly _availableEmbeddingTypes: IGithubAvailableEmbeddingTypesService, @ILogService private readonly _logService: ILogService, ) { super(); - this._availableEmbeddingTypes = _instantiationService.createInstance(GithubAvailableEmbeddingTypesManager); - this.tryInit(true); } diff --git a/test/base/simulationContext.ts b/test/base/simulationContext.ts index b26d9d517..7ed2fd6b4 100644 --- a/test/base/simulationContext.ts +++ b/test/base/simulationContext.ts @@ -40,6 +40,7 @@ import { SimulationReviewService } from '../../src/platform/test/node/simulation import { NullTestProvider } from '../../src/platform/testing/common/nullTestProvider'; import { ITestProvider } from '../../src/platform/testing/common/testProvider'; import { ITokenizerProvider, TokenizerProvider } from '../../src/platform/tokenizer/node/tokenizer'; +import { GithubAvailableEmbeddingTypesService, IGithubAvailableEmbeddingTypesService } from '../../src/platform/workspaceChunkSearch/common/githubAvailableEmbeddingTypes'; import { IWorkspaceChunkSearchService, WorkspaceChunkSearchService } from '../../src/platform/workspaceChunkSearch/node/workspaceChunkSearchService'; import { IWorkspaceFileIndex, WorkspaceFileIndex } from '../../src/platform/workspaceChunkSearch/node/workspaceFileIndex'; import { createServiceIdentifier } from '../../src/util/common/services'; @@ -292,6 +293,7 @@ export async function createSimulationAccessor( testingServiceCollection.define(IGitExtensionService, new SyncDescriptor(NullGitExtensionService)); testingServiceCollection.define(IReleaseNotesService, new SyncDescriptor(ReleaseNotesService)); testingServiceCollection.define(IWorkspaceFileIndex, new SyncDescriptor(WorkspaceFileIndex)); + testingServiceCollection.define(IGithubAvailableEmbeddingTypesService, new SyncDescriptor(GithubAvailableEmbeddingTypesService)); if (opts.useExperimentalCodeSearchService) { testingServiceCollection.define(IWorkspaceChunkSearchService, new SyncDescriptor(SimulationCodeSearchChunkSearchService, []));