diff --git a/packages/types/src/codebase-index.ts b/packages/types/src/codebase-index.ts
index 89d5b168d7..5102af8444 100644
--- a/packages/types/src/codebase-index.ts
+++ b/packages/types/src/codebase-index.ts
@@ -21,7 +21,9 @@ export const CODEBASE_INDEX_DEFAULTS = {
 export const codebaseIndexConfigSchema = z.object({
 	codebaseIndexEnabled: z.boolean().optional(),
 	codebaseIndexQdrantUrl: z.string().optional(),
-	codebaseIndexEmbedderProvider: z.enum(["openai", "ollama", "openai-compatible", "gemini", "mistral"]).optional(),
+	codebaseIndexEmbedderProvider: z
+		.enum(["openai", "ollama", "openai-compatible", "gemini", "mistral", "vertex"])
+		.optional(),
 	codebaseIndexEmbedderBaseUrl: z.string().optional(),
 	codebaseIndexEmbedderModelId: z.string().optional(),
 	codebaseIndexEmbedderModelDimension: z.number().optional(),
@@ -34,6 +36,9 @@ export const codebaseIndexConfigSchema = z.object({
 	// OpenAI Compatible specific fields
 	codebaseIndexOpenAiCompatibleBaseUrl: z.string().optional(),
 	codebaseIndexOpenAiCompatibleModelDimension: z.number().optional(),
+	// Vertex AI specific fields
+	codebaseIndexVertexProjectId: z.string().optional(),
+	codebaseIndexVertexLocation: z.string().optional(),
 })
 
 export type CodebaseIndexConfig = z.infer<typeof codebaseIndexConfigSchema>
@@ -48,6 +53,7 @@ export const codebaseIndexModelsSchema = z.object({
 	"openai-compatible": z.record(z.string(), z.object({ dimension: z.number() })).optional(),
 	gemini: z.record(z.string(), z.object({ dimension: z.number() })).optional(),
 	mistral: z.record(z.string(), z.object({ dimension: z.number() })).optional(),
+	vertex: z.record(z.string(), z.object({ dimension: z.number() })).optional(),
 })
 
 export type CodebaseIndexModels = z.infer<typeof codebaseIndexModelsSchema>
@@ -64,6 +70,9 @@ export const codebaseIndexProviderSchema = z.object({
 	codebaseIndexOpenAiCompatibleModelDimension: z.number().optional(),
 	codebaseIndexGeminiApiKey: z.string().optional(),
 	codebaseIndexMistralApiKey: z.string().optional(),
+	codebaseIndexVertexApiKey: z.string().optional(),
+	codebaseIndexVertexJsonCredentials: z.string().optional(),
+	codebaseIndexVertexKeyFile: z.string().optional(),
 })
 
 export type CodebaseIndexProvider = z.infer<typeof codebaseIndexProviderSchema>
diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts
index d5e76eccea..96dfbb88d8 100644
--- a/packages/types/src/global-settings.ts
+++ b/packages/types/src/global-settings.ts
@@ -184,6 +184,9 @@ export const SECRET_STATE_KEYS = [
 	"codebaseIndexOpenAiCompatibleApiKey",
 	"codebaseIndexGeminiApiKey",
 	"codebaseIndexMistralApiKey",
+	"codebaseIndexVertexApiKey",
+	"codebaseIndexVertexJsonCredentials",
+	"codebaseIndexVertexKeyFile",
 	"huggingFaceApiKey",
 ] as const satisfies readonly (keyof ProviderSettings)[]
 export type SecretState = Pick<ProviderSettings, (typeof SECRET_STATE_KEYS)[number]>
diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts
index 6bcb85e337..8e562d1ffe 100644
--- a/src/core/webview/ClineProvider.ts
+++ b/src/core/webview/ClineProvider.ts
@@ -1554,6 +1554,8 @@ export class ClineProvider
 				codebaseIndexOpenAiCompatibleBaseUrl: codebaseIndexConfig?.codebaseIndexOpenAiCompatibleBaseUrl,
 				codebaseIndexSearchMaxResults: codebaseIndexConfig?.codebaseIndexSearchMaxResults,
 				codebaseIndexSearchMinScore: codebaseIndexConfig?.codebaseIndexSearchMinScore,
+				codebaseIndexVertexProjectId: codebaseIndexConfig?.codebaseIndexVertexProjectId,
+				codebaseIndexVertexLocation: codebaseIndexConfig?.codebaseIndexVertexLocation,
 			},
 			mdmCompliant: this.checkMdmCompliance(),
 			profileThresholds: profileThresholds ??
{}, @@ -1726,6 +1728,8 @@ export class ClineProvider stateValues.codebaseIndexConfig?.codebaseIndexOpenAiCompatibleBaseUrl, codebaseIndexSearchMaxResults: stateValues.codebaseIndexConfig?.codebaseIndexSearchMaxResults, codebaseIndexSearchMinScore: stateValues.codebaseIndexConfig?.codebaseIndexSearchMinScore, + codebaseIndexVertexProjectId: stateValues.codebaseIndexConfig?.codebaseIndexVertexProjectId, + codebaseIndexVertexLocation: stateValues.codebaseIndexConfig?.codebaseIndexVertexLocation, }, profileThresholds: stateValues.profileThresholds ?? {}, // Add diagnostic message settings diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index da73c56920..947c1c5218 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -1998,6 +1998,8 @@ export const webviewMessageHandler = async ( codebaseIndexOpenAiCompatibleBaseUrl: settings.codebaseIndexOpenAiCompatibleBaseUrl, codebaseIndexSearchMaxResults: settings.codebaseIndexSearchMaxResults, codebaseIndexSearchMinScore: settings.codebaseIndexSearchMinScore, + codebaseIndexVertexProjectId: settings.codebaseIndexVertexProjectId, + codebaseIndexVertexLocation: settings.codebaseIndexVertexLocation, } // Save global state first @@ -2028,6 +2030,24 @@ export const webviewMessageHandler = async ( settings.codebaseIndexMistralApiKey, ) } + if (settings.codebaseIndexVertexApiKey !== undefined) { + await provider.contextProxy.storeSecret( + "codebaseIndexVertexApiKey", + settings.codebaseIndexVertexApiKey, + ) + } + if (settings.codebaseIndexVertexJsonCredentials !== undefined) { + await provider.contextProxy.storeSecret( + "codebaseIndexVertexJsonCredentials", + settings.codebaseIndexVertexJsonCredentials, + ) + } + if (settings.codebaseIndexVertexKeyFile !== undefined) { + await provider.contextProxy.storeSecret( + "codebaseIndexVertexKeyFile", + settings.codebaseIndexVertexKeyFile, + ) + } // Send success response first - settings are saved regardless of validation await provider.postMessageToWebview({ @@ -2149,6 +2169,11 @@ export const webviewMessageHandler = async ( )) const hasGeminiApiKey = !!(await provider.context.secrets.get("codebaseIndexGeminiApiKey")) const hasMistralApiKey = !!(await provider.context.secrets.get("codebaseIndexMistralApiKey")) + const hasVertexApiKey = !!(await provider.context.secrets.get("codebaseIndexVertexApiKey")) + const hasVertexJsonCredentials = !!(await provider.context.secrets.get( + "codebaseIndexVertexJsonCredentials", + )) + const hasVertexKeyFile = !!(await provider.context.secrets.get("codebaseIndexVertexKeyFile")) provider.postMessageToWebview({ type: "codeIndexSecretStatus", @@ -2158,6 +2183,9 @@ export const webviewMessageHandler = async ( hasOpenAiCompatibleApiKey, hasGeminiApiKey, hasMistralApiKey, + hasVertexApiKey, + hasVertexJsonCredentials, + hasVertexKeyFile, }, }) break diff --git a/src/i18n/locales/en/embeddings.json b/src/i18n/locales/en/embeddings.json index 66465d8c35..c2b31e4b9b 100644 --- a/src/i18n/locales/en/embeddings.json +++ b/src/i18n/locales/en/embeddings.json @@ -47,6 +47,7 @@ "openAiCompatibleConfigMissing": "OpenAI Compatible configuration missing for embedder creation", "geminiConfigMissing": "Gemini configuration missing for embedder creation", "mistralConfigMissing": "Mistral configuration missing for embedder creation", + "vertexConfigMissing": "Vertex AI configuration missing for embedder creation", "invalidEmbedderType": "Invalid embedder type configured: {{embedderProvider}}", 
"vectorDimensionNotDeterminedOpenAiCompatible": "Could not determine vector dimension for model '{{modelId}}' with provider '{{provider}}'. Please ensure the 'Embedding Dimension' is correctly set in the OpenAI-Compatible provider settings.", "vectorDimensionNotDetermined": "Could not determine vector dimension for model '{{modelId}}' with provider '{{provider}}'. Check model profiles or configuration.", diff --git a/src/services/code-index/constants/index.ts b/src/services/code-index/constants/index.ts index 6f0e0fe7e6..1356b6b2f6 100644 --- a/src/services/code-index/constants/index.ts +++ b/src/services/code-index/constants/index.ts @@ -29,3 +29,6 @@ export const BATCH_PROCESSING_CONCURRENCY = 10 /**Gemini Embedder */ export const GEMINI_MAX_ITEM_TOKENS = 2048 + +/**Vertex AI Embedder */ +export const VERTEX_MAX_ITEM_TOKENS = 2048 diff --git a/src/services/code-index/embedders/__tests__/vertex.spec.ts b/src/services/code-index/embedders/__tests__/vertex.spec.ts new file mode 100644 index 0000000000..aa91ebeac4 --- /dev/null +++ b/src/services/code-index/embedders/__tests__/vertex.spec.ts @@ -0,0 +1,362 @@ +import { vitest, describe, it, expect, beforeEach } from "vitest" +import { VertexEmbedder } from "../vertex" +import { GoogleGenAI } from "@google/genai" + +// Mock the @google/genai library +vitest.mock("@google/genai") + +// Mock TelemetryService +vitest.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureEvent: vitest.fn(), + }, + }, +})) + +// Mock i18n +vitest.mock("../../../../i18n", () => ({ + t: (key: string, params?: Record) => { + const translations: Record = { + "validation.apiKeyRequired": "API key is required", + "embeddings:validation.authenticationFailed": "Authentication failed", + "embeddings:validation.connectionFailed": "Connection failed", + "embeddings:validation.modelNotAvailable": "Model not available", + "embeddings:validation.unexpectedError": "Unexpected error", + "embeddings:validation.vertexAuthRequired": "At least one authentication method is required for Vertex AI", + "embeddings:validation.noEmbeddingsReturned": "No embeddings returned", + "embeddings:validation.configurationError": "Configuration error", + } + return translations[key] || key + }, +})) + +// Mock safeJsonParse +vitest.mock("../../../shared/safeJsonParse", () => ({ + safeJsonParse: (json: string, defaultValue: any) => { + try { + return JSON.parse(json) + } catch { + return defaultValue + } + }, +})) + +describe("VertexEmbedder", () => { + let embedder: VertexEmbedder + let mockClient: any + let mockModel: any + let mockEmbedContent: any + + beforeEach(() => { + vitest.clearAllMocks() + + // Setup mock for embedContent + mockEmbedContent = vitest.fn() + mockModel = { + embedContent: mockEmbedContent, + } + mockClient = { + models: { + embedContent: mockEmbedContent, + }, + } + + // Mock GoogleGenAI constructor + ;(GoogleGenAI as any).mockImplementation(() => mockClient) + }) + + describe("constructor", () => { + it("should create an instance with API key authentication", () => { + // Act + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + + // Assert + expect(GoogleGenAI).toHaveBeenCalledWith({ apiKey: "test-api-key" }) + expect(embedder.embedderInfo.name).toBe("vertex") + }) + + it("should create an instance with JSON credentials authentication", () => { + // Act + embedder = new VertexEmbedder({ + jsonCredentials: '{"type": "service_account"}', + projectId: "test-project", + location: 
"us-central1", + }) + + // Assert + expect(GoogleGenAI).toHaveBeenCalledWith({ + vertexai: true, + project: "test-project", + location: "us-central1", + googleAuthOptions: { + credentials: { type: "service_account" }, + }, + }) + expect(embedder.embedderInfo.name).toBe("vertex") + }) + + it("should create an instance with key file authentication", () => { + // Act + embedder = new VertexEmbedder({ + keyFile: "/path/to/keyfile.json", + projectId: "test-project", + location: "us-central1", + }) + + // Assert + expect(GoogleGenAI).toHaveBeenCalledWith({ + vertexai: true, + project: "test-project", + location: "us-central1", + googleAuthOptions: { keyFile: "/path/to/keyfile.json" }, + }) + expect(embedder.embedderInfo.name).toBe("vertex") + }) + + it("should create an instance with application default credentials", () => { + // Act + embedder = new VertexEmbedder({ + apiKey: "", // Empty string to trigger ADC path + projectId: "test-project", + location: "us-central1", + }) + + // Assert + expect(GoogleGenAI).toHaveBeenCalledWith({ + vertexai: true, + project: "test-project", + location: "us-central1", + }) + expect(embedder.embedderInfo.name).toBe("vertex") + }) + + it("should use default model when not specified", () => { + // Act + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + + // Assert + expect(embedder["modelId"]).toBe("text-embedding-004") + }) + + it("should use specified model", () => { + // Act + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + modelId: "text-multilingual-embedding-002", + projectId: "test-project", + location: "us-central1", + }) + + // Assert + expect(embedder["modelId"]).toBe("text-multilingual-embedding-002") + }) + }) + + describe("embedderInfo", () => { + it("should return correct embedder info", () => { + // Arrange + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + + // Act + const info = embedder.embedderInfo + + // Assert + expect(info).toEqual({ + name: "vertex", + }) + }) + }) + + describe("createEmbeddings", () => { + beforeEach(() => { + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + }) + + it("should create embeddings for single text", async () => { + // Arrange + const texts = ["test text"] + const mockResponse = { + embeddings: [{ values: [0.1, 0.2, 0.3] }], + } + mockEmbedContent.mockResolvedValue(mockResponse) + + // Act + const result = await embedder.createEmbeddings(texts) + + // Assert + expect(mockEmbedContent).toHaveBeenCalledWith({ + model: "text-embedding-004", + contents: [{ parts: [{ text: "test text" }] }], + }) + expect(result).toEqual({ + embeddings: [[0.1, 0.2, 0.3]], + }) + }) + + it("should create embeddings for multiple texts in batches", async () => { + // Arrange + const texts = ["text1", "text2", "text3"] + const mockResponse = { + embeddings: [{ values: [0.1, 0.2] }, { values: [0.3, 0.4] }, { values: [0.5, 0.6] }], + } + mockEmbedContent.mockResolvedValue(mockResponse) + + // Act + const result = await embedder.createEmbeddings(texts) + + // Assert + expect(mockEmbedContent).toHaveBeenCalledTimes(1) + expect(mockEmbedContent).toHaveBeenCalledWith({ + model: "text-embedding-004", + contents: [ + { parts: [{ text: "text1" }] }, + { parts: [{ text: "text2" }] }, + { parts: [{ text: "text3" }] }, + ], + }) + expect(result).toEqual({ + embeddings: [ + [0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6], + ], + }) + 
}) + + it("should use custom model when provided", async () => { + // Arrange + const texts = ["test text"] + const mockResponse = { + embeddings: [{ values: [0.1, 0.2] }], + } + mockEmbedContent.mockResolvedValue(mockResponse) + + // Act + await embedder.createEmbeddings(texts, "text-multilingual-embedding-002") + + // Assert + expect(mockEmbedContent).toHaveBeenCalledWith({ + model: "text-multilingual-embedding-002", + contents: [{ parts: [{ text: "test text" }] }], + }) + }) + + it("should handle empty text array", async () => { + // Act + const result = await embedder.createEmbeddings([]) + + // Assert + expect(mockEmbedContent).not.toHaveBeenCalled() + expect(result).toEqual({ embeddings: [] }) + }) + + it("should handle API errors", async () => { + // Arrange + const texts = ["test text"] + const error = new Error("API Error") + mockEmbedContent.mockRejectedValue(error) + + // Act & Assert + await expect(embedder.createEmbeddings(texts)).rejects.toThrow("API Error") + }) + }) + + describe("validateConfiguration", () => { + beforeEach(() => { + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + }) + + it("should validate configuration successfully", async () => { + // Arrange + const mockResponse = { + embeddings: [{ values: [0.1, 0.2] }], + } + mockEmbedContent.mockResolvedValue(mockResponse) + + // Act + const result = await embedder.validateConfiguration() + + // Assert + expect(mockEmbedContent).toHaveBeenCalledWith({ + model: "text-embedding-004", + contents: [{ parts: [{ text: "test" }] }], + }) + expect(result).toEqual({ valid: true }) + }) + + it("should handle unexpected errors", async () => { + // Arrange + const error = new Error("Something went wrong") + mockEmbedContent.mockRejectedValue(error) + + // Act + const result = await embedder.validateConfiguration() + + // Assert + expect(result).toEqual({ + valid: false, + error: "Something went wrong", + }) + }) + }) + + describe("createBatches", () => { + beforeEach(() => { + embedder = new VertexEmbedder({ + apiKey: "test-api-key", + projectId: "test-project", + location: "us-central1", + }) + }) + + it("should create batches respecting token limits", () => { + // Arrange + const texts = [ + "short text", + "another short text", + "a".repeat(5000), // Long text + "more text", + ] + + // Act + const batches = embedder["createBatches"](texts) + + // Assert + expect(batches.length).toBe(1) + expect(batches[0].length).toBe(4) // All texts in one batch (under 100 limit) + }) + + it("should handle all oversized texts", () => { + // Arrange + const texts = ["a".repeat(10000), "b".repeat(10000)] + + // Act + const batches = embedder["createBatches"](texts) + + // Assert + expect(batches.length).toBe(1) + expect(batches[0].length).toBe(2) // Both texts in one batch + }) + }) +}) diff --git a/src/services/code-index/embedders/vertex.ts b/src/services/code-index/embedders/vertex.ts new file mode 100644 index 0000000000..11e2e22cde --- /dev/null +++ b/src/services/code-index/embedders/vertex.ts @@ -0,0 +1,185 @@ +import { GoogleGenAI } from "@google/genai" +import type { JWTInput } from "google-auth-library" +import { IEmbedder, EmbeddingResponse, EmbedderInfo } from "../interfaces/embedder" +import { VERTEX_MAX_ITEM_TOKENS } from "../constants" +import { t } from "../../../i18n" +import { TelemetryEventName } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" +import { safeJsonParse } from "../../../shared/safeJsonParse" + +/** + * Vertex 
AI embedder implementation using the @google/genai library
+ * with support for multiple authentication methods.
+ *
+ * Supported models:
+ * - text-embedding-004 (dimension: 768)
+ * - text-multilingual-embedding-002 (dimension: 768)
+ * - textembedding-gecko@003 (dimension: 768)
+ * - textembedding-gecko-multilingual@001 (dimension: 768)
+ */
+export class VertexEmbedder implements IEmbedder {
+	private readonly client: GoogleGenAI
+	private static readonly DEFAULT_MODEL = "text-embedding-004"
+	private readonly modelId: string
+	private readonly maxItemTokens: number
+
+	/**
+	 * Creates a new Vertex AI embedder
+	 * @param options Configuration options including authentication methods
+	 */
+	constructor(options: {
+		apiKey?: string
+		jsonCredentials?: string
+		keyFile?: string
+		projectId: string
+		location: string
+		modelId?: string
+	}) {
+		const { apiKey, jsonCredentials, keyFile, projectId, location, modelId } = options
+
+		// Validate required fields
+		if (!projectId) {
+			throw new Error("Project ID is required for Vertex AI")
+		}
+		if (!location) {
+			throw new Error("Location is required for Vertex AI")
+		}
+
+		// Use provided model or default
+		this.modelId = modelId || VertexEmbedder.DEFAULT_MODEL
+		this.maxItemTokens = VERTEX_MAX_ITEM_TOKENS
+
+		// Create the GoogleGenAI client with appropriate auth
+		if (jsonCredentials) {
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				project: projectId,
+				location,
+				googleAuthOptions: {
+					credentials: safeJsonParse<JWTInput>(jsonCredentials, undefined),
+				},
+			})
+		} else if (keyFile) {
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				project: projectId,
+				location,
+				googleAuthOptions: { keyFile },
+			})
+		} else if (apiKey && apiKey.trim() !== "") {
+			// For API key auth, we use the regular Gemini API endpoint
+			this.client = new GoogleGenAI({ apiKey })
+		} else {
+			// Default to application default credentials
+			this.client = new GoogleGenAI({
+				vertexai: true,
+				project: projectId,
+				location,
+			})
+		}
+	}
+
+	/**
+	 * Creates embeddings for the given texts using Vertex AI's embedding API
+	 * @param texts Array of text strings to embed
+	 * @param model Optional model identifier (uses constructor model if not provided)
+	 * @returns Promise resolving to embedding response
+	 */
+	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
+		try {
+			const modelToUse = model || this.modelId
+
+			// Batch texts if they exceed token limits
+			const batches = this.createBatches(texts)
+			const allEmbeddings: number[][] = []
+
+			for (const batch of batches) {
+				const result = await this.client.models.embedContent({
+					model: modelToUse,
+					contents: batch.map((text) => ({ parts: [{ text }] })),
+				})
+
+				if (!result.embeddings || result.embeddings.length === 0) {
+					throw new Error(t("embeddings:validation.noEmbeddingsReturned"))
+				}
+
+				// Filter out any embeddings without values
+				const validEmbeddings = result.embeddings
+					.filter((e) => e.values !== undefined)
+					.map((e) => e.values as number[])
+
+				allEmbeddings.push(...validEmbeddings)
+			}
+
+			return {
+				embeddings: allEmbeddings,
+			}
+		} catch (error) {
+			TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
+				error: error instanceof Error ? error.message : String(error),
+				stack: error instanceof Error ?
error.stack : undefined, + location: "VertexEmbedder:createEmbeddings", + }) + throw error + } + } + + /** + * Creates batches of texts that respect token limits + */ + private createBatches(texts: string[]): string[][] { + // Simple batching - in production, you'd want to estimate tokens + const batchSize = 100 // Vertex AI typically supports up to 100 texts per batch + const batches: string[][] = [] + + for (let i = 0; i < texts.length; i += batchSize) { + batches.push(texts.slice(i, i + batchSize)) + } + + return batches + } + + /** + * Validates the Vertex AI embedder configuration + * @returns Promise resolving to validation result with success status and optional error message + */ + async validateConfiguration(): Promise<{ valid: boolean; error?: string }> { + try { + // Test with a simple embedding request + const testText = "test" + const result = await this.client.models.embedContent({ + model: this.modelId, + contents: [{ parts: [{ text: testText }] }], + }) + + if (!result.embeddings || result.embeddings.length === 0) { + return { + valid: false, + error: t("embeddings:validation.noEmbeddingsReturned"), + } + } + + return { valid: true } + } catch (error) { + TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + location: "VertexEmbedder:validateConfiguration", + }) + + return { + valid: false, + error: error instanceof Error ? error.message : t("embeddings:validation.configurationError"), + } + } + } + + /** + * Returns information about this embedder + */ + get embedderInfo(): EmbedderInfo { + return { + name: "vertex", + } + } +} diff --git a/src/services/code-index/interfaces/config.ts b/src/services/code-index/interfaces/config.ts index 9098a60091..69dff5699a 100644 --- a/src/services/code-index/interfaces/config.ts +++ b/src/services/code-index/interfaces/config.ts @@ -14,6 +14,13 @@ export interface CodeIndexConfig { openAiCompatibleOptions?: { baseUrl: string; apiKey: string } geminiOptions?: { apiKey: string } mistralOptions?: { apiKey: string } + vertexOptions?: { + apiKey?: string + jsonCredentials?: string + keyFile?: string + projectId?: string + location?: string + } qdrantUrl?: string qdrantApiKey?: string searchMinScore?: number @@ -35,6 +42,11 @@ export type PreviousConfigSnapshot = { openAiCompatibleApiKey?: string geminiApiKey?: string mistralApiKey?: string + vertexApiKey?: string + vertexJsonCredentials?: string + vertexKeyFile?: string + vertexProjectId?: string + vertexLocation?: string qdrantUrl?: string qdrantApiKey?: string } diff --git a/src/services/code-index/interfaces/embedder.ts b/src/services/code-index/interfaces/embedder.ts index c5653ea2b7..3ea92a3c77 100644 --- a/src/services/code-index/interfaces/embedder.ts +++ b/src/services/code-index/interfaces/embedder.ts @@ -28,7 +28,7 @@ export interface EmbeddingResponse { } } -export type AvailableEmbedders = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" +export type AvailableEmbedders = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vertex" export interface EmbedderInfo { name: AvailableEmbedders diff --git a/src/services/code-index/service-factory.ts b/src/services/code-index/service-factory.ts index 68b0f5c0bc..b9bdb74873 100644 --- a/src/services/code-index/service-factory.ts +++ b/src/services/code-index/service-factory.ts @@ -4,6 +4,7 @@ import { CodeIndexOllamaEmbedder } from "./embedders/ollama" import { 
OpenAICompatibleEmbedder } from "./embedders/openai-compatible" import { GeminiEmbedder } from "./embedders/gemini" import { MistralEmbedder } from "./embedders/mistral" +import { VertexEmbedder } from "./embedders/vertex" import { EmbedderProvider, getDefaultModelId, getModelDimension } from "../../shared/embeddingModels" import { QdrantVectorStore } from "./vector-store/qdrant-client" import { codeParser, DirectoryScanner, FileWatcher } from "./processors" @@ -70,6 +71,30 @@ export class CodeIndexServiceFactory { throw new Error(t("embeddings:serviceFactory.mistralConfigMissing")) } return new MistralEmbedder(config.mistralOptions.apiKey, config.modelId) + } else if (provider === "vertex") { + const vertexOptions = config.vertexOptions + if (!vertexOptions) { + throw new Error(t("embeddings:serviceFactory.vertexConfigMissing")) + } + + // Validate that at least one auth method is provided + if (!vertexOptions.apiKey && !vertexOptions.jsonCredentials && !vertexOptions.keyFile) { + throw new Error(t("embeddings:serviceFactory.vertexAuthRequired")) + } + + // Validate required fields for Vertex AI + if (!vertexOptions.projectId || !vertexOptions.location) { + throw new Error(t("embeddings:serviceFactory.vertexProjectLocationRequired")) + } + + return new VertexEmbedder({ + apiKey: vertexOptions.apiKey, + jsonCredentials: vertexOptions.jsonCredentials, + keyFile: vertexOptions.keyFile, + projectId: vertexOptions.projectId, + location: vertexOptions.location, + modelId: config.modelId, + }) } throw new Error( diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 1304e4c7d5..540e802037 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -249,13 +249,15 @@ export interface WebviewMessage { // Global state settings codebaseIndexEnabled: boolean codebaseIndexQdrantUrl: string - codebaseIndexEmbedderProvider: "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" + codebaseIndexEmbedderProvider: "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vertex" codebaseIndexEmbedderBaseUrl?: string codebaseIndexEmbedderModelId: string codebaseIndexEmbedderModelDimension?: number // Generic dimension for all providers codebaseIndexOpenAiCompatibleBaseUrl?: string codebaseIndexSearchMaxResults?: number codebaseIndexSearchMinScore?: number + codebaseIndexVertexProjectId?: string + codebaseIndexVertexLocation?: string // Secret settings codeIndexOpenAiKey?: string @@ -263,6 +265,9 @@ export interface WebviewMessage { codebaseIndexOpenAiCompatibleApiKey?: string codebaseIndexGeminiApiKey?: string codebaseIndexMistralApiKey?: string + codebaseIndexVertexApiKey?: string + codebaseIndexVertexJsonCredentials?: string + codebaseIndexVertexKeyFile?: string } } diff --git a/src/shared/embeddingModels.ts b/src/shared/embeddingModels.ts index a3cd61e659..b14ab68981 100644 --- a/src/shared/embeddingModels.ts +++ b/src/shared/embeddingModels.ts @@ -2,7 +2,7 @@ * Defines profiles for different embedding models, including their dimensions. 
*/ -export type EmbedderProvider = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" // Add other providers as needed +export type EmbedderProvider = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vertex" // Add other providers as needed export interface EmbeddingModelProfile { dimension: number @@ -53,6 +53,12 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = { mistral: { "codestral-embed-2505": { dimension: 1536, scoreThreshold: 0.4 }, }, + vertex: { + "text-embedding-004": { dimension: 768, scoreThreshold: 0.4 }, + "text-multilingual-embedding-002": { dimension: 768, scoreThreshold: 0.4 }, + "textembedding-gecko@003": { dimension: 768, scoreThreshold: 0.4 }, + "textembedding-gecko-multilingual@001": { dimension: 768, scoreThreshold: 0.4 }, + }, } /** @@ -143,6 +149,9 @@ export function getDefaultModelId(provider: EmbedderProvider): string { case "mistral": return "codestral-embed-2505" + case "vertex": + return "text-embedding-004" + default: // Fallback for unknown providers console.warn(`Unknown provider for default model ID: ${provider}. Falling back to OpenAI default.`) diff --git a/webview-ui/src/components/chat/CodeIndexPopover.tsx b/webview-ui/src/components/chat/CodeIndexPopover.tsx index d7683e8c7e..6c26afc448 100644 --- a/webview-ui/src/components/chat/CodeIndexPopover.tsx +++ b/webview-ui/src/components/chat/CodeIndexPopover.tsx @@ -69,6 +69,10 @@ interface LocalCodeIndexSettings { codebaseIndexOpenAiCompatibleApiKey?: string codebaseIndexGeminiApiKey?: string codebaseIndexMistralApiKey?: string + codebaseIndexVertexJsonCredentials?: string + codebaseIndexVertexKeyFile?: string + codebaseIndexVertexProjectId?: string + codebaseIndexVertexLocation?: string } // Validation schema for codebase index settings @@ -135,6 +139,27 @@ const createValidationSchema = (provider: EmbedderProvider, t: any) => { .min(1, t("settings:codeIndex.validation.modelSelectionRequired")), }) + case "vertex": + return baseSchema + .extend({ + // At least one auth method is required + codebaseIndexVertexJsonCredentials: z.string().optional(), + codebaseIndexVertexKeyFile: z.string().optional(), + codebaseIndexVertexProjectId: z + .string() + .min(1, t("settings:codeIndex.validation.vertexProjectIdRequired")), + codebaseIndexVertexLocation: z + .string() + .min(1, t("settings:codeIndex.validation.vertexLocationRequired")), + codebaseIndexEmbedderModelId: z + .string() + .min(1, t("settings:codeIndex.validation.modelSelectionRequired")), + }) + .refine((data) => data.codebaseIndexVertexJsonCredentials || data.codebaseIndexVertexKeyFile, { + message: t("settings:codeIndex.validation.vertexAuthRequired"), + path: ["codebaseIndexVertexJsonCredentials"], + }) + default: return baseSchema } @@ -179,6 +204,10 @@ export const CodeIndexPopover: React.FC = ({ codebaseIndexOpenAiCompatibleApiKey: "", codebaseIndexGeminiApiKey: "", codebaseIndexMistralApiKey: "", + codebaseIndexVertexJsonCredentials: "", + codebaseIndexVertexKeyFile: "", + codebaseIndexVertexProjectId: "", + codebaseIndexVertexLocation: "us-central1", }) // Initial settings state - stores the settings when popover opens @@ -213,6 +242,10 @@ export const CodeIndexPopover: React.FC = ({ codebaseIndexOpenAiCompatibleApiKey: "", codebaseIndexGeminiApiKey: "", codebaseIndexMistralApiKey: "", + codebaseIndexVertexJsonCredentials: "", + codebaseIndexVertexKeyFile: "", + codebaseIndexVertexProjectId: codebaseIndexConfig.codebaseIndexVertexProjectId || "", + codebaseIndexVertexLocation: 
codebaseIndexConfig.codebaseIndexVertexLocation || "us-central1", } setInitialSettings(settings) setCurrentSettings(settings) @@ -307,6 +340,17 @@ export const CodeIndexPopover: React.FC = ({ if (!prev.codebaseIndexMistralApiKey || prev.codebaseIndexMistralApiKey === SECRET_PLACEHOLDER) { updated.codebaseIndexMistralApiKey = secretStatus.hasMistralApiKey ? SECRET_PLACEHOLDER : "" } + if ( + !prev.codebaseIndexVertexJsonCredentials || + prev.codebaseIndexVertexJsonCredentials === SECRET_PLACEHOLDER + ) { + updated.codebaseIndexVertexJsonCredentials = secretStatus.hasVertexJsonCredentials + ? SECRET_PLACEHOLDER + : "" + } + if (!prev.codebaseIndexVertexKeyFile || prev.codebaseIndexVertexKeyFile === SECRET_PLACEHOLDER) { + updated.codebaseIndexVertexKeyFile = secretStatus.hasVertexKeyFile ? SECRET_PLACEHOLDER : "" + } return updated } @@ -379,7 +423,9 @@ export const CodeIndexPopover: React.FC = ({ key === "codeIndexOpenAiKey" || key === "codebaseIndexOpenAiCompatibleApiKey" || key === "codebaseIndexGeminiApiKey" || - key === "codebaseIndexMistralApiKey" + key === "codebaseIndexMistralApiKey" || + key === "codebaseIndexVertexJsonCredentials" || + key === "codebaseIndexVertexKeyFile" ) { dataToValidate[key] = "placeholder-valid" } @@ -624,6 +670,9 @@ export const CodeIndexPopover: React.FC = ({ {t("settings:codeIndex.mistralProvider")} + + {t("settings:codeIndex.vertexProvider")} + @@ -1016,6 +1065,161 @@ export const CodeIndexPopover: React.FC = ({ )} + {currentSettings.codebaseIndexEmbedderProvider === "vertex" && ( + <> +
+
{t("settings:providers.googleCloudSetup.title")}
+
+ + {t("settings:providers.googleCloudSetup.step1")} + +
+
+ + {t("settings:providers.googleCloudSetup.step2")} + +
+
+ + {t("settings:providers.googleCloudSetup.step3")} + +
+
+ +
+ + + updateSetting( + "codebaseIndexVertexJsonCredentials", + e.target.value, + ) + } + placeholder={t("settings:placeholders.credentialsJson")} + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexVertexJsonCredentials, + })} + /> + {formErrors.codebaseIndexVertexJsonCredentials && ( +

+ {formErrors.codebaseIndexVertexJsonCredentials} +

+ )} +
+ +
+ + + updateSetting("codebaseIndexVertexKeyFile", e.target.value) + } + placeholder={t("settings:placeholders.keyFilePath")} + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexVertexKeyFile, + })} + /> + {formErrors.codebaseIndexVertexKeyFile && ( +

+ {formErrors.codebaseIndexVertexKeyFile} +

+ )} +
+ +
+ + + updateSetting("codebaseIndexVertexProjectId", e.target.value) + } + placeholder={t("settings:placeholders.projectId")} + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexVertexProjectId, + })} + /> + {formErrors.codebaseIndexVertexProjectId && ( +

+ {formErrors.codebaseIndexVertexProjectId} +

+ )} +
+ +
+ + + updateSetting("codebaseIndexVertexLocation", e.target.value) + } + placeholder={t("settings:codeIndex.vertexLocationPlaceholder")} + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexVertexLocation, + })} + /> + {formErrors.codebaseIndexVertexLocation && ( +

+ {formErrors.codebaseIndexVertexLocation} +

+ )} +
+ +
+ + + updateSetting("codebaseIndexEmbedderModelId", e.target.value) + } + className={cn("w-full", { + "border-red-500": formErrors.codebaseIndexEmbedderModelId, + })}> + + {t("settings:codeIndex.selectModel")} + + {getAvailableModels().map((modelId) => { + const model = + codebaseIndexModels?.[ + currentSettings.codebaseIndexEmbedderProvider + ]?.[modelId] + return ( + + {modelId}{" "} + {model + ? t("settings:codeIndex.modelDimensions", { + dimension: model.dimension, + }) + : ""} + + ) + })} + + {formErrors.codebaseIndexEmbedderModelId && ( +

+ {formErrors.codebaseIndexEmbedderModelId} +

+ )} +
+ + )} + {/* Qdrant Settings */}
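
Usage note (not part of the patch): a minimal sketch of how the new embedder is driven end to end, assuming the application-default-credentials path and placeholder project/location values; the import path is illustrative, and only methods introduced in src/services/code-index/embedders/vertex.ts are called.

```ts
import { VertexEmbedder } from "./src/services/code-index/embedders/vertex" // illustrative path

async function smokeTestVertexEmbedder(): Promise<void> {
	// No apiKey / jsonCredentials / keyFile: the constructor falls back to
	// application default credentials with vertexai: true plus project + location.
	const embedder = new VertexEmbedder({
		projectId: "my-gcp-project", // placeholder
		location: "us-central1", // same default the CodeIndexPopover pre-fills
		modelId: "text-embedding-004", // 768-dimension default from EMBEDDING_MODEL_PROFILES
	})

	// validateConfiguration() issues a single test embedding and reports { valid, error? }.
	const check = await embedder.validateConfiguration()
	if (!check.valid) {
		throw new Error(`Vertex embedder misconfigured: ${check.error}`)
	}

	// createEmbeddings() batches up to 100 texts per request and returns number[][] vectors.
	const { embeddings } = await embedder.createEmbeddings(["function add(a, b) { return a + b }"])
	console.log(embedder.embedderInfo.name, embeddings[0].length) // "vertex", 768
}

smokeTestVertexEmbedder().catch(console.error)
```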
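Settings-shape note (also not part of the patch): a sketch of the extended codebaseIndexConfigSchema accepting the new provider value and the Vertex project/location fields; it assumes the schema is exported from @roo-code/types and uses placeholder values. The three Vertex secrets (API key, JSON credentials, key file) are stored via SECRET_STATE_KEYS rather than this schema.

```ts
import { codebaseIndexConfigSchema } from "@roo-code/types" // assumed export

// All fields are optional in the schema; "vertex" is now a valid embedder provider,
// and the project/location pair added in this change rides alongside it.
const parsed = codebaseIndexConfigSchema.parse({
	codebaseIndexEnabled: true,
	codebaseIndexEmbedderProvider: "vertex",
	codebaseIndexEmbedderModelId: "text-embedding-004",
	codebaseIndexVertexProjectId: "my-gcp-project", // placeholder
	codebaseIndexVertexLocation: "us-central1", // placeholder
})

console.log(parsed.codebaseIndexEmbedderProvider) // "vertex"
```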