Skip to content

Commit 2dfb2c4

Browse files
committed
feat: Enhance Gemini embedder with configurable dimensions and validation for embedding models
1 parent: 761b2ea · commit: 2dfb2c4

File tree

5 files changed

+71
-12
lines changed

5 files changed

+71
-12
lines changed

src/services/code-index/embedders/__tests__/gemini.spec.ts

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,4 +190,40 @@ describe("GeminiEmbedder", () => {
190190
await expect(embedder.validateConfiguration()).rejects.toThrow("Validation failed")
191191
})
192192
})
193+
194+
describe("createEmbeddings", () => {
195+
let mockCreateEmbeddings: any
196+
197+
beforeEach(() => {
198+
mockCreateEmbeddings = vitest.fn()
199+
MockedOpenAICompatibleEmbedder.prototype.createEmbeddings = mockCreateEmbeddings
200+
embedder = new GeminiEmbedder("test-api-key")
201+
})
202+
203+
it("should use default model when none is provided", async () => {
204+
// Arrange
205+
const texts = ["text1", "text2"]
206+
mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } })
207+
208+
// Act
209+
await embedder.createEmbeddings(texts)
210+
211+
// Assert
212+
expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "text-embedding-004", undefined)
213+
})
214+
215+
it("should pass model and dimension to the OpenAICompatibleEmbedder", async () => {
216+
// Arrange
217+
const texts = ["text1", "text2"]
218+
const model = "custom-model"
219+
const options = { dimension: 1536 }
220+
mockCreateEmbeddings.mockResolvedValue({ embeddings: [], usage: { promptTokens: 0, totalTokens: 0 } })
221+
222+
// Act
223+
await embedder.createEmbeddings(texts, model, options)
224+
225+
// Assert
226+
expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, model, options)
227+
})
228+
})
193229
})

src/services/code-index/embedders/gemini.ts

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,15 @@ export class GeminiEmbedder implements IEmbedder {
4747
* @param model Optional model identifier (uses constructor model if not provided)
4848
* @returns Promise resolving to embedding response
4949
*/
50-
async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
50+
async createEmbeddings(
51+
texts: string[],
52+
model?: string,
53+
options?: { dimension?: number },
54+
): Promise<EmbeddingResponse> {
5155
try {
5256
// Use the provided model or fall back to the instance's model
5357
const modelToUse = model || this.modelId
54-
return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse)
58+
return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse, options)
5559
} catch (error) {
5660
TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
5761
error: error instanceof Error ? error.message : String(error),

src/services/code-index/embedders/openai-compatible.ts

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,11 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
7171
* @param model Optional model identifier
7272
* @returns Promise resolving to embedding response
7373
*/
74-
async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
74+
async createEmbeddings(
75+
texts: string[],
76+
model?: string,
77+
options?: { dimension?: number },
78+
): Promise<EmbeddingResponse> {
7579
const modelToUse = model || this.defaultModelId
7680

7781
// Apply model-specific query prefix if required
@@ -139,7 +143,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
139143
}
140144

141145
if (currentBatch.length > 0) {
142-
const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse)
146+
const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse, options)
143147
allEmbeddings.push(...batchResult.embeddings)
144148
usage.promptTokens += batchResult.usage.promptTokens
145149
usage.totalTokens += batchResult.usage.totalTokens
@@ -181,7 +185,18 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
181185
url: string,
182186
batchTexts: string[],
183187
model: string,
188+
options?: { dimension?: number },
184189
): Promise<OpenAIEmbeddingResponse> {
190+
const body: Record<string, any> = {
191+
input: batchTexts,
192+
model: model,
193+
encoding_format: "base64",
194+
}
195+
196+
if (options?.dimension) {
197+
body.dimensions = options.dimension
198+
}
199+
185200
const response = await fetch(url, {
186201
method: "POST",
187202
headers: {
@@ -191,11 +206,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
191206
"api-key": this.apiKey,
192207
Authorization: `Bearer ${this.apiKey}`,
193208
},
194-
body: JSON.stringify({
195-
input: batchTexts,
196-
model: model,
197-
encoding_format: "base64",
198-
}),
209+
body: JSON.stringify(body),
199210
})
200211

201212
if (!response || !response.ok) {
@@ -234,6 +245,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
234245
private async _embedBatchWithRetries(
235246
batchTexts: string[],
236247
model: string,
248+
options?: { dimension?: number },
237249
): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
238250
// Use cached value for performance
239251
const isFullUrl = this.isFullUrl
@@ -244,7 +256,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
244256

245257
if (isFullUrl) {
246258
// Use direct HTTP request for full endpoint URLs
247-
response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model)
259+
response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model, options)
248260
} else {
249261
// Use OpenAI SDK for base URLs
250262
response = (await this.embeddingsClient.embeddings.create({
@@ -254,6 +266,7 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
254266
// when processing numeric arrays, which breaks compatibility with models using larger dimensions.
255267
// By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
256268
encoding_format: "base64",
269+
...(options?.dimension && { dimensions: options.dimension }),
257270
})) as OpenAIEmbeddingResponse
258271
}
259272

src/services/code-index/interfaces/embedder.ts

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ export interface IEmbedder {
99
* @param model Optional model ID to use for embeddings
1010
* @returns Promise resolving to an EmbeddingResponse
1111
*/
12-
createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse>
12+
createEmbeddings(texts: string[], model?: string, options?: { dimension?: number }): Promise<EmbeddingResponse>
1313

1414
/**
1515
* Validates the embedder configuration by testing connectivity and credentials.

src/shared/embeddingModels.ts

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,13 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = {
5252
},
5353
gemini: {
5454
"text-embedding-004": { dimension: 768 },
55-
"gemini-embedding-001": { dimension: 3072, scoreThreshold: 0.4 },
55+
"gemini-embedding-001": {
56+
dimension: 3072, // Fallback, but defaultDimension is preferred
57+
minDimension: 128,
58+
maxDimension: 3072,
59+
defaultDimension: 3072,
60+
scoreThreshold: 0.4,
61+
},
5662
},
5763
}
5864

0 commit comments

Comments
 (0)