Skip to content

Commit 2149430

Browse files
committed
feat(code-index): make taskType optional for Gemini embeddings depending on model support
1 parent 14117b7 commit 2149430

File tree

4 files changed

+93
-62
lines changed

4 files changed

+93
-62
lines changed

src/services/code-index/config-manager.ts

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { ContextProxy } from "../../core/config/ContextProxy"
33
import { EmbedderProvider } from "./interfaces/manager"
44
import { CodeIndexConfig, PreviousConfigSnapshot } from "./interfaces/config"
55
import { SEARCH_MIN_SCORE } from "./constants"
6-
import { getDefaultModelId, getModelDimension } from "../../shared/embeddingModels"
6+
import { getDefaultModelId, getModelDimension, EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels"
77

88
// Define a type for the raw config state from globalState
99
interface RawCodebaseIndexConfigState {
@@ -52,7 +52,7 @@ export class CodeIndexConfigManager {
5252
codebaseIndexEmbedderProvider: "openai",
5353
codebaseIndexEmbedderBaseUrl: "",
5454
codebaseIndexEmbedderModelId: "",
55-
geminiEmbeddingTaskType: "CODE_RETRIEVAL_QUERY",
55+
geminiEmbeddingTaskType: undefined,
5656
geminiEmbeddingDimension: undefined,
5757
}) as RawCodebaseIndexConfigState // Cast to our defined raw state type
5858

@@ -112,7 +112,7 @@ export class CodeIndexConfigManager {
112112

113113
this.geminiOptions = {
114114
geminiApiKey,
115-
geminiEmbeddingTaskType: geminiEmbeddingTaskType || "CODE_RETRIEVAL_QUERY",
115+
geminiEmbeddingTaskType: geminiEmbeddingTaskType,
116116
apiModelId: this.modelId,
117117
geminiEmbeddingDimension,
118118
rateLimitSeconds,
@@ -205,10 +205,19 @@ export class CodeIndexConfigManager {
205205
if (this.embedderProvider === "gemini") {
206206
// Gemini requires an API key and Qdrant URL, plus a task type when the selected model supports one
207207
const geminiApiKey = this.geminiOptions?.geminiApiKey
208+
const modelId = this.modelId || getDefaultModelId("gemini")
209+
const qdrantUrl = this.qdrantUrl
210+
211+
// Check if the model supports taskType
212+
const geminiProfiles = EMBEDDING_MODEL_PROFILES.gemini || {}
213+
const modelProfile = geminiProfiles[modelId]
214+
const supportsTaskType = modelProfile?.supportsTaskType || false
215+
216+
// Only require taskType if the model supports it
208217
const geminiEmbeddingTaskType = this.geminiOptions?.geminiEmbeddingTaskType
218+
const taskTypeValid = !supportsTaskType || (supportsTaskType && !!geminiEmbeddingTaskType)
209219

210-
const qdrantUrl = this.qdrantUrl
211-
const isConfigured = !!(geminiApiKey && geminiEmbeddingTaskType && qdrantUrl)
220+
const isConfigured = !!(geminiApiKey && taskTypeValid && qdrantUrl)
212221
return isConfigured
213222
}
214223
return false // Should not happen if embedderProvider is always set correctly

src/services/code-index/embedders/gemini.ts

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import { RetryHandler } from "../../../utils/retry-handler"
1111
*/
1212
export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder {
1313
private readonly defaultModelId: string
14-
private readonly defaultTaskType: string
14+
private readonly defaultTaskType?: string
1515
private readonly rateLimiter: SlidingWindowRateLimiter
1616
private readonly retryHandler: RetryHandler
1717

@@ -22,7 +22,7 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
2222
constructor(options: ApiHandlerOptions) {
2323
super(options)
2424
this.defaultModelId = options.apiModelId || "gemini-embedding-exp-03-07"
25-
this.defaultTaskType = options.geminiEmbeddingTaskType || "CODE_RETRIEVAL_QUERY"
25+
this.defaultTaskType = options.geminiEmbeddingTaskType
2626

2727
// Calculate rate limit parameters based on rateLimitSeconds or default
2828
const rateLimitSeconds = options.rateLimitSeconds || GEMINI_RATE_LIMIT_DELAY_MS / 1000
@@ -49,7 +49,7 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
4949
async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
5050
try {
5151
const modelId = model || this.defaultModelId
52-
const result = await this.embedWithTokenLimit(texts, modelId, this.defaultTaskType)
52+
const result = await this.embedWithTokenLimit(texts, modelId, this.defaultTaskType || "")
5353
return {
5454
embeddings: result.embeddings,
5555
}
@@ -72,7 +72,7 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
7272
private async _processAndAggregateBatch(
7373
batch: string[],
7474
model: string,
75-
taskType: string,
75+
taskType: string = "",
7676
allEmbeddings: number[][],
7777
aggregatedUsage: { promptTokens: number; totalTokens: number },
7878
isFinalBatch: boolean = false,
@@ -104,7 +104,7 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
104104
private async embedWithTokenLimit(
105105
texts: string[],
106106
model: string,
107-
taskType: string,
107+
taskType: string = "",
108108
): Promise<{
109109
embeddings: number[][]
110110
usage: { promptTokens: number; totalTokens: number }
@@ -161,20 +161,29 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
161161
*
162162
* @param batchTexts Array of texts to embed
163163
* @param modelId Model identifier to use for the API call
164-
* @param taskType The task type for the embedding
164+
* @param taskType The task type for the embedding (only used if the model supports it)
165165
* @returns Promise resolving to embeddings and usage statistics
166166
*/
167167
private async _callGeminiEmbeddingApi(
168168
batchTexts: string[],
169169
modelId: string,
170170
taskType: string,
171171
): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
172+
// Check if the model supports taskType
173+
const geminiProfiles = EMBEDDING_MODEL_PROFILES.gemini || {}
174+
const modelProfile = geminiProfiles[modelId]
175+
const supportsTaskType = modelProfile?.supportsTaskType || false
176+
177+
// Only include taskType in the config if the model supports it
178+
const config: { taskType?: string } = {}
179+
if (supportsTaskType) {
180+
config.taskType = taskType
181+
}
182+
172183
const response = await this.client.models.embedContent({
173184
model: modelId,
174185
contents: batchTexts,
175-
config: {
176-
taskType,
177-
},
186+
config,
178187
})
179188

180189
if (!response.embeddings) {
@@ -204,7 +213,7 @@ export class CodeIndexGeminiEmbedder extends GeminiHandler implements IEmbedder
204213
private async _embedBatch(
205214
batchTexts: string[],
206215
model: string,
207-
taskType: string,
216+
taskType: string = "",
208217
): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
209218
const modelId = model || this.defaultModelId
210219

src/shared/embeddingModels.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ export interface EmbeddingModelProfile {
1414
* Optional maximum input tokens for the model.
1515
*/
1616
maxInputTokens?: number
17+
/**
18+
* Whether the model supports the taskType parameter.
19+
* Only some Gemini models support this parameter.
20+
*/
21+
supportsTaskType?: boolean
1722
// Add other model-specific properties if needed, e.g., context window size
1823
}
1924

@@ -43,9 +48,14 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = {
4348
"text-embedding-ada-002": { dimension: 1536 },
4449
},
4550
gemini: {
46-
"gemini-embedding-exp-03-07": { dimension: 3072, supportDimensions: [3072, 1536, 768], maxInputTokens: 8192 },
47-
"models/text-embedding-004": { dimension: 768, maxInputTokens: 2048 },
48-
"models/embedding-001": { dimension: 768, maxInputTokens: 2048 },
51+
"gemini-embedding-exp-03-07": {
52+
dimension: 3072,
53+
supportDimensions: [3072, 1536, 768],
54+
maxInputTokens: 8192,
55+
supportsTaskType: true,
56+
},
57+
"models/text-embedding-004": { dimension: 768, maxInputTokens: 2048, supportsTaskType: false },
58+
"models/embedding-001": { dimension: 768, maxInputTokens: 2048, supportsTaskType: false },
4959
},
5060
}
5161

webview-ui/src/components/settings/CodeIndexSettings.tsx

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ export const CodeIndexSettings: React.FC<CodeIndexSettingsProps> = ({
160160
gemini: baseSchema.extend({
161161
codebaseIndexEmbedderProvider: z.literal("gemini"),
162162
geminiApiKey: z.string().min(1, "Gemini API key is required"),
163-
geminiEmbeddingTaskType: z.string().min(1, "Gemini Task Type is required"),
163+
geminiEmbeddingTaskType: z.string().optional(),
164164
geminiEmbeddingDimension: z
165165
.number()
166166
.int()
@@ -460,51 +460,54 @@ export const CodeIndexSettings: React.FC<CodeIndexSettingsProps> = ({
460460
style={{ width: "100%" }}></VSCodeTextField>
461461
</div>
462462
</div>
463-
<div className="flex flex-col gap-3">
464-
<div className="flex items-center gap-4 font-bold">
465-
<div>{t("settings:codeIndex.embeddingTaskType")}</div>
466-
</div>
467-
<div>
468-
<div className="flex items-center gap-2">
469-
<Select
470-
value={
471-
codebaseIndexConfig?.geminiEmbeddingTaskType || "CODE_RETRIEVAL_QUERY"
472-
}
473-
onValueChange={(value) =>
474-
setCachedStateField("codebaseIndexConfig", {
475-
...codebaseIndexConfig,
476-
geminiEmbeddingTaskType: value,
477-
})
478-
}>
479-
<SelectTrigger className="w-full">
480-
<SelectValue
481-
placeholder={t("settings:codeIndex.selectTaskTypePlaceholder")}
482-
/>
483-
</SelectTrigger>
484-
<SelectContent>
485-
<SelectItem value="CODE_RETRIEVAL_QUERY">
486-
{t("settings:codeIndex.selectTaskType.codeRetrievalQuery")}
487-
</SelectItem>
488-
<SelectItem value="RETRIEVAL_DOCUMENT">
489-
{t("settings:codeIndex.selectTaskType.retrievalDocument")}
490-
</SelectItem>
491-
<SelectItem value="RETRIEVAL_QUERY">
492-
{t("settings:codeIndex.selectTaskType.retrievalQuery")}
493-
</SelectItem>
494-
<SelectItem value="SEMANTIC_SIMILARITY">
495-
{t("settings:codeIndex.selectTaskType.semanticSimilarity")}
496-
</SelectItem>
497-
<SelectItem value="CLASSIFICATION">
498-
{t("settings:codeIndex.selectTaskType.classification")}
499-
</SelectItem>
500-
<SelectItem value="CLUSTERING">
501-
{t("settings:codeIndex.selectTaskType.clustering")}
502-
</SelectItem>
503-
</SelectContent>
504-
</Select>
463+
{geminiModelProfileForDim?.supportsTaskType && (
464+
<div className="flex flex-col gap-3">
465+
<div className="flex items-center gap-4 font-bold">
466+
<div>{t("settings:codeIndex.embeddingTaskType")}</div>
467+
</div>
468+
<div>
469+
<div className="flex items-center gap-2">
470+
<Select
471+
value={
472+
codebaseIndexConfig?.geminiEmbeddingTaskType ||
473+
"CODE_RETRIEVAL_QUERY"
474+
}
475+
onValueChange={(value) =>
476+
setCachedStateField("codebaseIndexConfig", {
477+
...codebaseIndexConfig,
478+
geminiEmbeddingTaskType: value,
479+
})
480+
}>
481+
<SelectTrigger className="w-full">
482+
<SelectValue
483+
placeholder={t("settings:codeIndex.selectTaskTypePlaceholder")}
484+
/>
485+
</SelectTrigger>
486+
<SelectContent>
487+
<SelectItem value="CODE_RETRIEVAL_QUERY">
488+
{t("settings:codeIndex.selectTaskType.codeRetrievalQuery")}
489+
</SelectItem>
490+
<SelectItem value="RETRIEVAL_DOCUMENT">
491+
{t("settings:codeIndex.selectTaskType.retrievalDocument")}
492+
</SelectItem>
493+
<SelectItem value="RETRIEVAL_QUERY">
494+
{t("settings:codeIndex.selectTaskType.retrievalQuery")}
495+
</SelectItem>
496+
<SelectItem value="SEMANTIC_SIMILARITY">
497+
{t("settings:codeIndex.selectTaskType.semanticSimilarity")}
498+
</SelectItem>
499+
<SelectItem value="CLASSIFICATION">
500+
{t("settings:codeIndex.selectTaskType.classification")}
501+
</SelectItem>
502+
<SelectItem value="CLUSTERING">
503+
{t("settings:codeIndex.selectTaskType.clustering")}
504+
</SelectItem>
505+
</SelectContent>
506+
</Select>
507+
</div>
505508
</div>
506509
</div>
507-
</div>
510+
)}
508511
{currentProvider === "gemini" &&
509512
geminiModelProfileForDim &&
510513
geminiSupportedDims &&

0 commit comments

Comments
 (0)