Commit 65b7236

feat: Add support for new gemini embedding model. (#5621)
1 parent 10f3e30 commit 65b7236

File tree

6 files changed (+84 -22 lines)


src/core/webview/webviewMessageHandler.ts

Lines changed: 15 additions & 10 deletions
```diff
@@ -2065,11 +2065,14 @@ export const webviewMessageHandler = async (
 		}
 
 		case "requestIndexingStatus": {
-			const status = provider.codeIndexManager!.getCurrentStatus()
-			provider.postMessageToWebview({
-				type: "indexingStatusUpdate",
-				values: status,
-			})
+			const manager = provider.codeIndexManager
+			if (manager) {
+				const status = manager.getCurrentStatus()
+				provider.postMessageToWebview({
+					type: "indexingStatusUpdate",
+					values: status,
+				})
+			}
 			break
 		}
 		case "requestCodeIndexSecretStatus": {
@@ -2094,8 +2097,8 @@ export const webviewMessageHandler = async (
 		}
 		case "startIndexing": {
 			try {
-				const manager = provider.codeIndexManager!
-				if (manager.isFeatureEnabled && manager.isFeatureConfigured) {
+				const manager = provider.codeIndexManager
+				if (manager && manager.isFeatureEnabled && manager.isFeatureConfigured) {
 					if (!manager.isInitialized) {
 						await manager.initialize(provider.contextProxy)
 					}
@@ -2109,9 +2112,11 @@ export const webviewMessageHandler = async (
 		}
 		case "clearIndexData": {
 			try {
-				const manager = provider.codeIndexManager!
-				await manager.clearIndexData()
-				provider.postMessageToWebview({ type: "indexCleared", values: { success: true } })
+				const manager = provider.codeIndexManager
+				if (manager) {
+					await manager.clearIndexData()
+					provider.postMessageToWebview({ type: "indexCleared", values: { success: true } })
+				}
 			} catch (error) {
 				provider.log(`Error clearing index data: ${error instanceof Error ? error.message : String(error)}`)
 				provider.postMessageToWebview({
```
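This hunk shows a pattern repeated throughout the commit: every `provider.codeIndexManager!` non-null assertion is replaced by binding the possibly-undefined manager once and branching on it, so TypeScript narrows the type and a missing manager no longer throws at runtime. A minimal standalone sketch of the pattern (the `CodeIndexManager` shape here is reduced to the one method used above):

```ts
interface CodeIndexManager {
	getCurrentStatus(): string
}

function handleRequestIndexingStatus(manager: CodeIndexManager | undefined) {
	// Before: manager!.getCurrentStatus() throws a TypeError at runtime
	// whenever the manager was never created.
	// After: the guard narrows `manager` to CodeIndexManager inside the block.
	if (manager) {
		const status = manager.getCurrentStatus()
		console.log("indexingStatusUpdate", status)
	}
	// When the manager is undefined, the message is silently skipped instead of crashing.
}
```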

src/services/code-index/embedders/__tests__/gemini.spec.ts

Lines changed: 36 additions & 0 deletions
```diff
@@ -114,4 +114,40 @@ describe("GeminiEmbedder", () => {
 			await expect(embedder.validateConfiguration()).rejects.toThrow("Validation failed")
 		})
 	})
+
+	describe("createEmbeddings", () => {
+		let mockCreateEmbeddings: any
+
+		beforeEach(() => {
+			mockCreateEmbeddings = vitest.fn()
+			MockedOpenAICompatibleEmbedder.prototype.createEmbeddings = mockCreateEmbeddings
+			embedder = new GeminiEmbedder("test-api-key")
+		})
+
+		it("should use default model when none is provided", async () => {
+			// Arrange
+			const texts = ["text1", "text2"]
+			mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } })
+
+			// Act
+			await embedder.createEmbeddings(texts)
+
+			// Assert
+			expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "text-embedding-004", undefined)
+		})
+
+		it("should pass model and dimension to the OpenAICompatibleEmbedder", async () => {
+			// Arrange
+			const texts = ["text1", "text2"]
+			const model = "custom-model"
+			const options = { dimension: 1536 }
+			mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } })
+
+			// Act
+			await embedder.createEmbeddings(texts, model, options)
+
+			// Assert
+			expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, model, options)
+		})
+	})
 })
```
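For context, these tests rely on module-level mocking that sits above the excerpted hunk. Judging by the identifiers in the diff (`MockedOpenAICompatibleEmbedder`, `vitest.fn()`), the setup presumably looks roughly like the sketch below; the relative import paths are assumptions, not shown in the commit:

```ts
import { vitest } from "vitest"
import { OpenAICompatibleEmbedder } from "../openai-compatible" // assumed path
import { GeminiEmbedder } from "../gemini" // assumed path

// Auto-mock the module so prototype methods can be stubbed per test.
vitest.mock("../openai-compatible")

const MockedOpenAICompatibleEmbedder = vitest.mocked(OpenAICompatibleEmbedder)
let embedder: GeminiEmbedder
```

With something like that in place, assigning to `MockedOpenAICompatibleEmbedder.prototype.createEmbeddings` in `beforeEach` intercepts the delegation that `GeminiEmbedder` performs internally.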

src/services/code-index/embedders/gemini.ts

Lines changed: 8 additions & 3 deletions
```diff
@@ -44,10 +44,15 @@ export class GeminiEmbedder implements IEmbedder {
 	 * @param model Optional model identifier (ignored - always uses text-embedding-004)
 	 * @returns Promise resolving to embedding response
 	 */
-	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
+	async createEmbeddings(
+		texts: string[],
+		model?: string,
+		options?: { dimension?: number },
+	): Promise<EmbeddingResponse> {
 		try {
-			// Always use the fixed Gemini model, ignoring any passed model parameter
-			return await this.openAICompatibleEmbedder.createEmbeddings(texts, GeminiEmbedder.GEMINI_MODEL)
+			// Use the provided model or the fixed Gemini model
+			const modelToUse = model || GeminiEmbedder.GEMINI_MODEL
+			return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse, options)
 		} catch (error) {
 			TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
 				error: error instanceof Error ? error.message : String(error),
```
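After this change a caller can select the new experimental model and request a specific output dimension, while existing call sites that pass nothing keep the old behavior. A hypothetical call site (the API key and import path are placeholders):

```ts
import { GeminiEmbedder } from "./services/code-index/embedders/gemini" // path assumed

async function demo() {
	const embedder = new GeminiEmbedder("YOUR_GEMINI_API_KEY") // placeholder key

	// No model given: falls back to the fixed default, text-embedding-004.
	const byDefault = await embedder.createEmbeddings(["hello world"])

	// Explicit model plus a requested dimension; both are forwarded to the
	// OpenAI-compatible layer, which turns `dimension` into `dimensions`.
	const custom = await embedder.createEmbeddings(
		["hello world"],
		"gemini-embedding-exp-03-07",
		{ dimension: 768 },
	)

	console.log(byDefault.usage, custom.usage)
}
```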

src/services/code-index/embedders/openai-compatible.ts

Lines changed: 21 additions & 8 deletions
```diff
@@ -71,7 +71,11 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 	 * @param model Optional model identifier
 	 * @returns Promise resolving to embedding response
 	 */
-	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
+	async createEmbeddings(
+		texts: string[],
+		model?: string,
+		options?: { dimension?: number },
+	): Promise<EmbeddingResponse> {
 		const modelToUse = model || this.defaultModelId
 
 		// Apply model-specific query prefix if required
@@ -139,7 +143,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 		}
 
 		if (currentBatch.length > 0) {
-			const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse)
+			const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse, options)
 			allEmbeddings.push(...batchResult.embeddings)
 			usage.promptTokens += batchResult.usage.promptTokens
 			usage.totalTokens += batchResult.usage.totalTokens
@@ -181,7 +185,18 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 		url: string,
 		batchTexts: string[],
 		model: string,
+		options?: { dimension?: number },
 	): Promise<OpenAIEmbeddingResponse> {
+		const body: Record<string, any> = {
+			input: batchTexts,
+			model: model,
+			encoding_format: "base64",
+		}
+
+		if (options?.dimension) {
+			body.dimensions = options.dimension
+		}
+
 		const response = await fetch(url, {
 			method: "POST",
 			headers: {
@@ -191,11 +206,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 				"api-key": this.apiKey,
 				Authorization: `Bearer ${this.apiKey}`,
 			},
-			body: JSON.stringify({
-				input: batchTexts,
-				model: model,
-				encoding_format: "base64",
-			}),
+			body: JSON.stringify(body),
 		})
 
 		if (!response || !response.ok) {
@@ -234,6 +245,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 	private async _embedBatchWithRetries(
 		batchTexts: string[],
 		model: string,
+		options?: { dimension?: number },
 	): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
 		// Use cached value for performance
 		const isFullUrl = this.isFullUrl
@@ -244,7 +256,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 
 		if (isFullUrl) {
 			// Use direct HTTP request for full endpoint URLs
-			response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model)
+			response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model, options)
 		} else {
 			// Use OpenAI SDK for base URLs
 			response = (await this.embeddingsClient.embeddings.create({
@@ -254,6 +266,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
 				// when processing numeric arrays, which breaks compatibility with models using larger dimensions.
 				// By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
 				encoding_format: "base64",
+				...(options?.dimension && { dimensions: options.dimension }),
 			})) as OpenAIEmbeddingResponse
 		}
```

src/services/code-index/interfaces/embedder.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ export interface IEmbedder {
 	 * @param model Optional model ID to use for embeddings
 	 * @returns Promise resolving to an EmbeddingResponse
 	 */
-	createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse>
+	createEmbeddings(texts: string[], model?: string, options?: { dimension?: number }): Promise<EmbeddingResponse>
 
 	/**
 	 * Validates the embedder configuration by testing connectivity and credentials.
```
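Because the new parameter is optional, every existing `IEmbedder` implementation stays source-compatible, and implementations that cannot honor a dimension request may simply ignore it. A toy implementation of just this method, with the `EmbeddingResponse` shape inferred from the tests above (the real interface has further members, such as `validateConfiguration`, not shown in this hunk):

```ts
interface EmbeddingResponse {
	embeddings: number[][]
	usage: { promptTokens: number; totalTokens: number }
}

class ZeroEmbedder {
	async createEmbeddings(
		texts: string[],
		model?: string,
		options?: { dimension?: number },
	): Promise<EmbeddingResponse> {
		const dimension = options?.dimension ?? 768 // fall back to a default width
		// Return one zero vector of the requested width per input text.
		return {
			embeddings: texts.map(() => new Array(dimension).fill(0)),
			usage: { promptTokens: 0, totalTokens: 0 },
		}
	}
}
```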

src/shared/embeddingModels.ts

Lines changed: 3 additions & 0 deletions
```diff
@@ -48,6 +48,9 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = {
 	},
 	gemini: {
 		"text-embedding-004": { dimension: 768 },
+		// ADD: New model with a default dimension.
+		// The actual dimension will be passed from the configuration at runtime.
+		"gemini-embedding-exp-03-07": { dimension: 768 },
 	},
 }
```
