diff --git a/src/i18n/locales/en/embeddings.json b/src/i18n/locales/en/embeddings.json
index fc902cadc165..210f6219fb3b 100644
--- a/src/i18n/locales/en/embeddings.json
+++ b/src/i18n/locales/en/embeddings.json
@@ -17,6 +17,12 @@
 		"modelNotEmbeddingCapable": "Ollama model is not embedding capable: {{modelId}}",
 		"hostNotFound": "Ollama host not found: {{baseUrl}}"
 	},
+	"bedrock": {
+		"invalidResponseFormat": "Invalid response format from AWS Bedrock",
+		"invalidCredentials": "Invalid AWS credentials. Please check your AWS configuration.",
+		"accessDenied": "Access denied to AWS Bedrock service. Please check your IAM permissions.",
+		"modelNotFound": "Model {{model}} not found in AWS Bedrock"
+	},
 	"scanner": {
 		"unknownErrorProcessingFile": "Unknown error processing file {{filePath}}",
 		"unknownErrorDeletingPoints": "Unknown error deleting points for {{filePath}}",
@@ -48,6 +54,7 @@
 		"geminiConfigMissing": "Gemini configuration missing for embedder creation",
 		"mistralConfigMissing": "Mistral configuration missing for embedder creation",
 		"vercelAiGatewayConfigMissing": "Vercel AI Gateway configuration missing for embedder creation",
+		"bedrockConfigMissing": "AWS Bedrock configuration missing for embedder creation",
 		"invalidEmbedderType": "Invalid embedder type configured: {{embedderProvider}}",
 		"vectorDimensionNotDeterminedOpenAiCompatible": "Could not determine vector dimension for model '{{modelId}}' with provider '{{provider}}'. Please ensure the 'Embedding Dimension' is correctly set in the OpenAI-Compatible provider settings.",
 		"vectorDimensionNotDetermined": "Could not determine vector dimension for model '{{modelId}}' with provider '{{provider}}'. Check model profiles or configuration.",
diff --git a/src/services/code-index/config-manager.ts b/src/services/code-index/config-manager.ts
index 2c0e8bb5c9e4..a1fb2f4a0e45 100644
--- a/src/services/code-index/config-manager.ts
+++ b/src/services/code-index/config-manager.ts
@@ -20,6 +20,7 @@ export class CodeIndexConfigManager {
 	private geminiOptions?: { apiKey: string }
 	private mistralOptions?: { apiKey: string }
 	private vercelAiGatewayOptions?: { apiKey: string }
+	private bedrockOptions?: { region: string; profile?: string }
 	private qdrantUrl?: string = "http://localhost:6333"
 	private qdrantApiKey?: string
 	private searchMinScore?: number
@@ -49,8 +50,13 @@ export class CodeIndexConfigManager {
 			codebaseIndexEmbedderProvider: "openai",
 			codebaseIndexEmbedderBaseUrl: "",
 			codebaseIndexEmbedderModelId: "",
+			codebaseIndexEmbedderModelDimension: undefined,
 			codebaseIndexSearchMinScore: undefined,
 			codebaseIndexSearchMaxResults: undefined,
+			codebaseIndexOpenAiCompatibleBaseUrl: "",
+			codebaseIndexOpenAiCompatibleModelDimension: undefined,
+			codebaseIndexBedrockRegion: "us-east-1",
+			codebaseIndexBedrockProfile: "",
 		}
 
 		const {
@@ -66,11 +72,13 @@ export class CodeIndexConfigManager {
 		const openAiKey = this.contextProxy?.getSecret("codeIndexOpenAiKey") ?? ""
 		const qdrantApiKey = this.contextProxy?.getSecret("codeIndexQdrantApiKey") ?? ""
 		// Fix: Read OpenAI Compatible settings from the correct location within codebaseIndexConfig
-		const openAiCompatibleBaseUrl = codebaseIndexConfig.codebaseIndexOpenAiCompatibleBaseUrl ?? ""
+		const openAiCompatibleBaseUrl = (codebaseIndexConfig as any).codebaseIndexOpenAiCompatibleBaseUrl ?? ""
 		const openAiCompatibleApiKey = this.contextProxy?.getSecret("codebaseIndexOpenAiCompatibleApiKey") ?? ""
 		const geminiApiKey = this.contextProxy?.getSecret("codebaseIndexGeminiApiKey") ?? ""
"" const mistralApiKey = this.contextProxy?.getSecret("codebaseIndexMistralApiKey") ?? "" const vercelAiGatewayApiKey = this.contextProxy?.getSecret("codebaseIndexVercelAiGatewayApiKey") ?? "" + const bedrockRegion = (codebaseIndexConfig as any).codebaseIndexBedrockRegion ?? "us-east-1" + const bedrockProfile = (codebaseIndexConfig as any).codebaseIndexBedrockProfile ?? "" // Update instance variables with configuration this.codebaseIndexEnabled = codebaseIndexEnabled ?? true @@ -80,7 +88,7 @@ export class CodeIndexConfigManager { this.searchMaxResults = codebaseIndexSearchMaxResults // Validate and set model dimension - const rawDimension = codebaseIndexConfig.codebaseIndexEmbedderModelDimension + const rawDimension = (codebaseIndexConfig as any).codebaseIndexEmbedderModelDimension if (rawDimension !== undefined && rawDimension !== null) { const dimension = Number(rawDimension) if (!isNaN(dimension) && dimension > 0) { @@ -108,6 +116,8 @@ export class CodeIndexConfigManager { this.embedderProvider = "mistral" } else if (codebaseIndexEmbedderProvider === "vercel-ai-gateway") { this.embedderProvider = "vercel-ai-gateway" + } else if ((codebaseIndexEmbedderProvider as string) === "bedrock") { + this.embedderProvider = "bedrock" } else { this.embedderProvider = "openai" } @@ -129,6 +139,9 @@ export class CodeIndexConfigManager { this.geminiOptions = geminiApiKey ? { apiKey: geminiApiKey } : undefined this.mistralOptions = mistralApiKey ? { apiKey: mistralApiKey } : undefined this.vercelAiGatewayOptions = vercelAiGatewayApiKey ? { apiKey: vercelAiGatewayApiKey } : undefined + this.bedrockOptions = bedrockRegion + ? { region: bedrockRegion, profile: bedrockProfile || undefined } + : undefined } /** @@ -147,6 +160,7 @@ export class CodeIndexConfigManager { geminiOptions?: { apiKey: string } mistralOptions?: { apiKey: string } vercelAiGatewayOptions?: { apiKey: string } + bedrockOptions?: { region: string; profile?: string } qdrantUrl?: string qdrantApiKey?: string searchMinScore?: number @@ -167,6 +181,8 @@ export class CodeIndexConfigManager { geminiApiKey: this.geminiOptions?.apiKey ?? "", mistralApiKey: this.mistralOptions?.apiKey ?? "", vercelAiGatewayApiKey: this.vercelAiGatewayOptions?.apiKey ?? "", + bedrockRegion: this.bedrockOptions?.region ?? "", + bedrockProfile: this.bedrockOptions?.profile ?? "", qdrantUrl: this.qdrantUrl ?? "", qdrantApiKey: this.qdrantApiKey ?? "", } @@ -192,6 +208,7 @@ export class CodeIndexConfigManager { geminiOptions: this.geminiOptions, mistralOptions: this.mistralOptions, vercelAiGatewayOptions: this.vercelAiGatewayOptions, + bedrockOptions: this.bedrockOptions, qdrantUrl: this.qdrantUrl, qdrantApiKey: this.qdrantApiKey, searchMinScore: this.currentSearchMinScore, @@ -234,6 +251,11 @@ export class CodeIndexConfigManager { const qdrantUrl = this.qdrantUrl const isConfigured = !!(apiKey && qdrantUrl) return isConfigured + } else if (this.embedderProvider === "bedrock") { + const region = this.bedrockOptions?.region + const qdrantUrl = this.qdrantUrl + const isConfigured = !!(region && qdrantUrl) + return isConfigured } return false // Should not happen if embedderProvider is always set correctly } @@ -269,6 +291,8 @@ export class CodeIndexConfigManager { const prevGeminiApiKey = prev?.geminiApiKey ?? "" const prevMistralApiKey = prev?.mistralApiKey ?? "" const prevVercelAiGatewayApiKey = prev?.vercelAiGatewayApiKey ?? "" + const prevBedrockRegion = prev?.bedrockRegion ?? "" + const prevBedrockProfile = prev?.bedrockProfile ?? 
"" const prevQdrantUrl = prev?.qdrantUrl ?? "" const prevQdrantApiKey = prev?.qdrantApiKey ?? "" @@ -307,6 +331,8 @@ export class CodeIndexConfigManager { const currentGeminiApiKey = this.geminiOptions?.apiKey ?? "" const currentMistralApiKey = this.mistralOptions?.apiKey ?? "" const currentVercelAiGatewayApiKey = this.vercelAiGatewayOptions?.apiKey ?? "" + const currentBedrockRegion = this.bedrockOptions?.region ?? "" + const currentBedrockProfile = this.bedrockOptions?.profile ?? "" const currentQdrantUrl = this.qdrantUrl ?? "" const currentQdrantApiKey = this.qdrantApiKey ?? "" @@ -337,6 +363,10 @@ export class CodeIndexConfigManager { return true } + if (prevBedrockRegion !== currentBedrockRegion || prevBedrockProfile !== currentBedrockProfile) { + return true + } + // Check for model dimension changes (generic for all providers) if (prevModelDimension !== currentModelDimension) { return true @@ -395,6 +425,7 @@ export class CodeIndexConfigManager { geminiOptions: this.geminiOptions, mistralOptions: this.mistralOptions, vercelAiGatewayOptions: this.vercelAiGatewayOptions, + bedrockOptions: this.bedrockOptions, qdrantUrl: this.qdrantUrl, qdrantApiKey: this.qdrantApiKey, searchMinScore: this.currentSearchMinScore, diff --git a/src/services/code-index/embedders/__tests__/bedrock.spec.ts b/src/services/code-index/embedders/__tests__/bedrock.spec.ts new file mode 100644 index 000000000000..b88ce786fdd5 --- /dev/null +++ b/src/services/code-index/embedders/__tests__/bedrock.spec.ts @@ -0,0 +1,521 @@ +import type { MockedFunction } from "vitest" +import { BedrockRuntimeClient, InvokeModelCommand } from "@aws-sdk/client-bedrock-runtime" + +import { BedrockEmbedder } from "../bedrock" +import { MAX_ITEM_TOKENS, INITIAL_RETRY_DELAY_MS } from "../../constants" + +// Mock the AWS SDK +vitest.mock("@aws-sdk/client-bedrock-runtime", () => { + return { + BedrockRuntimeClient: vitest.fn().mockImplementation(() => ({ + send: vitest.fn(), + })), + InvokeModelCommand: vitest.fn().mockImplementation((input) => ({ + input, + })), + } +}) +vitest.mock("@aws-sdk/credential-providers", () => ({ + fromEnv: vitest.fn().mockReturnValue(Promise.resolve({})), + fromIni: vitest.fn().mockReturnValue(Promise.resolve({})), +})) + +// Mock TelemetryService +vitest.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureEvent: vitest.fn(), + }, + }, +})) + +// Mock i18n +vitest.mock("../../../../i18n", () => ({ + t: (key: string, params?: Record) => { + const translations: Record = { + "embeddings:authenticationFailed": + "Failed to create embeddings: Authentication failed. Please check your AWS credentials.", + "embeddings:failedWithStatus": `Failed to create embeddings after ${params?.attempts} attempts: HTTP ${params?.statusCode} - ${params?.errorMessage}`, + "embeddings:failedWithError": `Failed to create embeddings after ${params?.attempts} attempts: ${params?.errorMessage}`, + "embeddings:failedMaxAttempts": `Failed to create embeddings after ${params?.attempts} attempts`, + "embeddings:textExceedsTokenLimit": `Text at index ${params?.index} exceeds maximum token limit (${params?.itemTokens} > ${params?.maxTokens}). 
+			"embeddings:rateLimitRetry": `Rate limit hit, retrying in ${params?.delayMs}ms (attempt ${params?.attempt}/${params?.maxRetries})`,
+			"embeddings:bedrock.invalidResponseFormat": "Invalid response format from Bedrock",
+			"embeddings:bedrock.invalidCredentials": "Invalid AWS credentials",
+			"embeddings:bedrock.accessDenied": "Access denied to Bedrock service",
+			"embeddings:bedrock.modelNotFound": `Model ${params?.model} not found`,
+			"embeddings:validation.authenticationFailed": "Authentication failed",
+			"embeddings:validation.connectionFailed": "Connection failed",
+			"embeddings:validation.serviceUnavailable": "Service unavailable",
+			"embeddings:validation.configurationError": "Configuration error",
+		}
+		return translations[key] || key
+	},
+}))
+
+// Mock console methods
+const consoleMocks = {
+	error: vitest.spyOn(console, "error").mockImplementation(() => {}),
+	warn: vitest.spyOn(console, "warn").mockImplementation(() => {}),
+}
+
+describe("BedrockEmbedder", () => {
+	let embedder: BedrockEmbedder
+	let mockSend: MockedFunction<any>
+
+	beforeEach(() => {
+		vitest.clearAllMocks()
+		consoleMocks.error.mockClear()
+		consoleMocks.warn.mockClear()
+
+		mockSend = vitest.fn()
+
+		// Set up the mock implementation
+		const MockedBedrockRuntimeClient = BedrockRuntimeClient as any
+		MockedBedrockRuntimeClient.mockImplementation(() => ({
+			send: mockSend,
+		}))
+
+		embedder = new BedrockEmbedder("us-east-1", "amazon.titan-embed-text-v2:0")
+	})
+
+	afterEach(() => {
+		vitest.clearAllMocks()
+	})
+
+	describe("constructor", () => {
+		it("should initialize with provided region and model", () => {
+			expect(embedder.embedderInfo.name).toBe("bedrock")
+		})
+
+		it("should use default region if not provided", () => {
+			const defaultEmbedder = new BedrockEmbedder()
+			expect(defaultEmbedder).toBeDefined()
+		})
+
+		it("should use profile if provided", () => {
+			const profileEmbedder = new BedrockEmbedder("us-west-2", undefined, "dev-profile")
+			expect(profileEmbedder).toBeDefined()
+		})
+	})
+
+	describe("createEmbeddings", () => {
+		const testModelId = "amazon.titan-embed-text-v2:0"
+
+		it("should create embeddings for a single text with Titan model", async () => {
+			const testTexts = ["Hello world"]
+			const mockResponse = {
+				body: new TextEncoder().encode(
+					JSON.stringify({
+						embedding: [0.1, 0.2, 0.3],
+						inputTextTokenCount: 2,
+					}),
+				),
+			}
+			mockSend.mockResolvedValue(mockResponse)
+
+			const result = await embedder.createEmbeddings(testTexts)
+
+			expect(mockSend).toHaveBeenCalled()
+			const command = mockSend.mock.calls[0][0] as any
+			expect(command.input.modelId).toBe(testModelId)
+			const bodyStr =
+				typeof command.input.body === "string"
+					? command.input.body
+					: new TextDecoder().decode(command.input.body as Uint8Array)
+			expect(JSON.parse(bodyStr || "{}")).toEqual({
+				inputText: "Hello world",
+			})
+
+			expect(result).toEqual({
+				embeddings: [[0.1, 0.2, 0.3]],
+				usage: { promptTokens: 2, totalTokens: 2 },
+			})
+		})
+
+		it("should create embeddings for multiple texts", async () => {
+			const testTexts = ["Hello world", "Another text"]
+			const mockResponses = [
+				{
+					body: new TextEncoder().encode(
+						JSON.stringify({
+							embedding: [0.1, 0.2, 0.3],
+							inputTextTokenCount: 2,
+						}),
+					),
+				},
+				{
+					body: new TextEncoder().encode(
+						JSON.stringify({
+							embedding: [0.4, 0.5, 0.6],
+							inputTextTokenCount: 3,
+						}),
+					),
+				},
+			]
+
+			mockSend.mockResolvedValueOnce(mockResponses[0]).mockResolvedValueOnce(mockResponses[1])
+
+			const result = await embedder.createEmbeddings(testTexts)
+
+			expect(mockSend).toHaveBeenCalledTimes(2)
+			expect(result).toEqual({
+				embeddings: [
+					[0.1, 0.2, 0.3],
+					[0.4, 0.5, 0.6],
+				],
+				usage: { promptTokens: 5, totalTokens: 5 },
+			})
+		})
+
+		it("should handle Cohere model format", async () => {
+			const cohereEmbedder = new BedrockEmbedder("us-east-1", "cohere.embed-english-v3")
+			const testTexts = ["Hello world"]
+			const mockResponse = {
+				body: new TextEncoder().encode(
+					JSON.stringify({
+						embeddings: [[0.1, 0.2, 0.3]],
+					}),
+				),
+			}
+			mockSend.mockResolvedValue(mockResponse)
+
+			const result = await cohereEmbedder.createEmbeddings(testTexts)
+
+			const command = mockSend.mock.calls[0][0] as InvokeModelCommand
+			const bodyStr =
+				typeof command.input.body === "string"
+					? command.input.body
+					: new TextDecoder().decode(command.input.body as Uint8Array)
+			expect(JSON.parse(bodyStr || "{}")).toEqual({
+				texts: ["Hello world"],
+				input_type: "search_document",
+			})
+
+			expect(result).toEqual({
+				embeddings: [[0.1, 0.2, 0.3]],
+				usage: { promptTokens: 0, totalTokens: 0 },
+			})
+		})
+
+		it("should use custom model when provided", async () => {
+			const testTexts = ["Hello world"]
+			const customModel = "amazon.titan-embed-text-v1"
+			const mockResponse = {
+				body: new TextEncoder().encode(
+					JSON.stringify({
+						embedding: [0.1, 0.2, 0.3],
+						inputTextTokenCount: 2,
+					}),
+				),
+			}
+			mockSend.mockResolvedValue(mockResponse)
+
+			await embedder.createEmbeddings(testTexts, customModel)
+
+			const command = mockSend.mock.calls[0][0] as InvokeModelCommand
+			expect(command.input.modelId).toBe(customModel)
+		})
+
+		it("should handle missing token count data gracefully", async () => {
+			const testTexts = ["Hello world"]
+			const mockResponse = {
+				body: new TextEncoder().encode(
+					JSON.stringify({
+						embedding: [0.1, 0.2, 0.3],
+					}),
+				),
+			}
+			mockSend.mockResolvedValue(mockResponse)
+
+			const result = await embedder.createEmbeddings(testTexts)
+
+			expect(result).toEqual({
+				embeddings: [[0.1, 0.2, 0.3]],
+				usage: { promptTokens: 0, totalTokens: 0 },
+			})
+		})
+
+		/**
+		 * Test batching logic when texts exceed token limits
+		 */
+		describe("batching logic", () => {
+			it("should warn and skip texts exceeding maximum token limit", async () => {
+				// Create a text that exceeds MAX_ITEM_TOKENS (4 characters ≈ 1 token)
+				const oversizedText = "a".repeat(MAX_ITEM_TOKENS * 4 + 100)
+				const normalText = "normal text"
+				const testTexts = [normalText, oversizedText, "another normal"]
+
+				const mockResponses = [
+					{
+						body: new TextEncoder().encode(
+							JSON.stringify({
+								embedding: [0.1, 0.2, 0.3],
+								inputTextTokenCount: 3,
+							}),
+						),
+					},
+					{
+						body: new TextEncoder().encode(
+							JSON.stringify({
+								embedding: [0.4, 0.5, 0.6],
+								inputTextTokenCount: 3,
+							}),
+						),
+					},
+				]
+
+				mockSend.mockResolvedValueOnce(mockResponses[0]).mockResolvedValueOnce(mockResponses[1])
+
+				const result = await embedder.createEmbeddings(testTexts)
+
+				// Verify warning was logged
+				expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("exceeds maximum token limit"))
+
+				// Verify only normal texts were processed
+				expect(mockSend).toHaveBeenCalledTimes(2)
+				expect(result.embeddings).toHaveLength(2)
+			})
+
+			it("should handle all texts being skipped due to size", async () => {
+				const oversizedText = "a".repeat(MAX_ITEM_TOKENS * 4 + 100)
+				const testTexts = [oversizedText, oversizedText]
+
+				const result = await embedder.createEmbeddings(testTexts)
+
+				expect(console.warn).toHaveBeenCalledTimes(2)
+				expect(mockSend).not.toHaveBeenCalled()
+				expect(result).toEqual({
+					embeddings: [],
+					usage: { promptTokens: 0, totalTokens: 0 },
+				})
+			})
+		})
+
+		/**
+		 * Test retry logic for rate limiting and other errors
+		 */
+		describe("retry logic", () => {
+			beforeEach(() => {
+				vitest.useFakeTimers()
+			})
+
+			afterEach(() => {
+				vitest.useRealTimers()
+			})
+
+			it("should retry on throttling errors with exponential backoff", async () => {
+				const testTexts = ["Hello world"]
+				const throttlingError = new Error("Rate limit exceeded")
+				throttlingError.name = "ThrottlingException"
+
+				mockSend
+					.mockRejectedValueOnce(throttlingError)
+					.mockRejectedValueOnce(throttlingError)
+					.mockResolvedValueOnce({
+						body: new TextEncoder().encode(
+							JSON.stringify({
+								embedding: [0.1, 0.2, 0.3],
+								inputTextTokenCount: 2,
+							}),
+						),
+					})
+
+				const resultPromise = embedder.createEmbeddings(testTexts)
+
+				// Fast-forward through the delays
+				await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS) // First retry delay
+				await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 2) // Second retry delay
+
+				const result = await resultPromise
+
+				expect(mockSend).toHaveBeenCalledTimes(3)
+				expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Rate limit hit, retrying in"))
+				expect(result).toEqual({
+					embeddings: [[0.1, 0.2, 0.3]],
+					usage: { promptTokens: 2, totalTokens: 2 },
+				})
+			})
+
+			it("should not retry on non-throttling errors", async () => {
+				const testTexts = ["Hello world"]
+				const authError = new Error("Unauthorized")
+				authError.name = "UnrecognizedClientException"
+
+				mockSend.mockRejectedValue(authError)
+
+				await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow(
+					"Failed to create embeddings after 3 attempts: Unauthorized",
+				)
+
+				expect(mockSend).toHaveBeenCalledTimes(1)
+				expect(console.warn).not.toHaveBeenCalledWith(expect.stringContaining("Rate limit hit"))
+			})
+		})
+
+		/**
+		 * Test error handling scenarios
+		 */
+		describe("error handling", () => {
+			it("should handle API errors gracefully", async () => {
+				const testTexts = ["Hello world"]
+				const apiError = new Error("API connection failed")
+
+				mockSend.mockRejectedValue(apiError)
+
+				await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow(
+					"Failed to create embeddings after 3 attempts: API connection failed",
+				)
+
+				expect(console.error).toHaveBeenCalledWith(
+					expect.stringContaining("Bedrock embedder error"),
+					expect.any(Error),
+				)
+			})
+
+			it("should handle empty text arrays", async () => {
+				const testTexts: string[] = []
+
+				const result = await embedder.createEmbeddings(testTexts)
+
+				expect(result).toEqual({
+					embeddings: [],
+					usage: { promptTokens: 0, totalTokens: 0 },
+				})
+				expect(mockSend).not.toHaveBeenCalled()
+			})
+
+ it("should handle malformed API responses", async () => { + const testTexts = ["Hello world"] + const malformedResponse = { + body: new TextEncoder().encode("not json"), + } + + mockSend.mockResolvedValue(malformedResponse) + + await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow() + }) + + it("should handle AWS-specific errors", async () => { + const testTexts = ["Hello world"] + + // Test UnrecognizedClientException + const authError = new Error("Invalid credentials") + authError.name = "UnrecognizedClientException" + mockSend.mockRejectedValueOnce(authError) + + await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow( + "Failed to create embeddings after 3 attempts: Invalid credentials", + ) + + // Test AccessDeniedException + const accessError = new Error("Access denied") + accessError.name = "AccessDeniedException" + mockSend.mockRejectedValueOnce(accessError) + + await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow( + "Failed to create embeddings after 3 attempts: Access denied", + ) + + // Test ResourceNotFoundException + const notFoundError = new Error("Model not found") + notFoundError.name = "ResourceNotFoundException" + mockSend.mockRejectedValueOnce(notFoundError) + + await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow( + "Failed to create embeddings after 3 attempts: Model not found", + ) + }) + }) + }) + + describe("validateConfiguration", () => { + it("should validate successfully with valid configuration", async () => { + const mockResponse = { + body: new TextEncoder().encode( + JSON.stringify({ + embedding: [0.1, 0.2, 0.3], + inputTextTokenCount: 1, + }), + ), + } + mockSend.mockResolvedValue(mockResponse) + + const result = await embedder.validateConfiguration() + + expect(result.valid).toBe(true) + expect(result.error).toBeUndefined() + expect(mockSend).toHaveBeenCalled() + }) + + it("should fail validation with authentication error", async () => { + const authError = new Error("Invalid credentials") + authError.name = "UnrecognizedClientException" + mockSend.mockRejectedValue(authError) + + const result = await embedder.validateConfiguration() + + expect(result.valid).toBe(false) + expect(result.error).toBe("Invalid AWS credentials") + }) + + it("should fail validation with access denied error", async () => { + const accessError = new Error("Access denied") + accessError.name = "AccessDeniedException" + mockSend.mockRejectedValue(accessError) + + const result = await embedder.validateConfiguration() + + expect(result.valid).toBe(false) + expect(result.error).toBe("Access denied to Bedrock service") + }) + + it("should fail validation with model not found error", async () => { + const notFoundError = new Error("Model not found") + notFoundError.name = "ResourceNotFoundException" + mockSend.mockRejectedValue(notFoundError) + + const result = await embedder.validateConfiguration() + + expect(result.valid).toBe(false) + expect(result.error).toContain("not found") + }) + + it("should fail validation with invalid response", async () => { + const mockResponse = { + body: new TextEncoder().encode( + JSON.stringify({ + // Missing embedding field + inputTextTokenCount: 1, + }), + ), + } + mockSend.mockResolvedValue(mockResponse) + + const result = await embedder.validateConfiguration() + + expect(result.valid).toBe(false) + expect(result.error).toBe("Invalid response format from Bedrock") + }) + + it("should fail validation with connection error", async () => { + const connectionError = new Error("ECONNREFUSED") + 
+			mockSend.mockRejectedValue(connectionError)
+
+			const result = await embedder.validateConfiguration()
+
+			expect(result.valid).toBe(false)
+			expect(result.error).toBe("Connection failed")
+		})
+
+		it("should fail validation with generic error", async () => {
+			const genericError = new Error("Unknown error")
+			mockSend.mockRejectedValue(genericError)
+
+			const result = await embedder.validateConfiguration()
+
+			expect(result.valid).toBe(false)
+			expect(result.error).toBe("Configuration error")
+		})
+	})
+})
diff --git a/src/services/code-index/embedders/bedrock.ts b/src/services/code-index/embedders/bedrock.ts
new file mode 100644
index 000000000000..9bdaf6022648
--- /dev/null
+++ b/src/services/code-index/embedders/bedrock.ts
@@ -0,0 +1,294 @@
+import { BedrockRuntimeClient, InvokeModelCommand, InvokeModelCommandInput } from "@aws-sdk/client-bedrock-runtime"
+import { fromEnv, fromIni } from "@aws-sdk/credential-providers"
+import { IEmbedder, EmbeddingResponse, EmbedderInfo } from "../interfaces"
+import {
+	MAX_BATCH_TOKENS,
+	MAX_ITEM_TOKENS,
+	MAX_BATCH_RETRIES as MAX_RETRIES,
+	INITIAL_RETRY_DELAY_MS as INITIAL_DELAY_MS,
+} from "../constants"
+import { getDefaultModelId } from "../../../shared/embeddingModels"
+import { t } from "../../../i18n"
+import { withValidationErrorHandling, formatEmbeddingError, HttpError } from "../shared/validation-helpers"
+import { TelemetryEventName } from "@roo-code/types"
+import { TelemetryService } from "@roo-code/telemetry"
+
+/**
+ * AWS Bedrock implementation of the embedder interface with batching and rate limiting
+ */
+export class BedrockEmbedder implements IEmbedder {
+	private bedrockClient: BedrockRuntimeClient
+	private readonly defaultModelId: string
+
+	/**
+	 * Creates a new AWS Bedrock embedder
+	 * @param region AWS region for Bedrock service
+	 * @param modelId Optional model ID override
+	 * @param profile Optional AWS profile name for credentials
+	 */
+	constructor(
+		private readonly region: string = "us-east-1",
+		modelId?: string,
+		private readonly profile?: string,
+	) {
+		// Initialize the Bedrock client with appropriate credentials
+		const credentials = this.profile ? fromIni({ profile: this.profile }) : fromEnv()
+
+		this.bedrockClient = new BedrockRuntimeClient({
+			region: this.region,
+			credentials,
+		})
+
+		this.defaultModelId = modelId || getDefaultModelId("bedrock")
+	}
+
+	/**
+	 * Creates embeddings for the given texts with batching and rate limiting
+	 * @param texts Array of text strings to embed
+	 * @param model Optional model identifier
+	 * @returns Promise resolving to embedding response
+	 */
+	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
+		const modelToUse = model || this.defaultModelId
+
+		const allEmbeddings: number[][] = []
+		const usage = { promptTokens: 0, totalTokens: 0 }
+		const remainingTexts = [...texts]
+
+		while (remainingTexts.length > 0) {
+			const currentBatch: string[] = []
+			let currentBatchTokens = 0
+			const processedIndices: number[] = []
+
+			for (let i = 0; i < remainingTexts.length; i++) {
+				const text = remainingTexts[i]
+				const itemTokens = Math.ceil(text.length / 4)
+
+				if (itemTokens > MAX_ITEM_TOKENS) {
+					console.warn(
+						t("embeddings:textExceedsTokenLimit", {
+							index: i,
+							itemTokens,
+							maxTokens: MAX_ITEM_TOKENS,
+						}),
+					)
+					processedIndices.push(i)
+					continue
+				}
+
+				if (currentBatchTokens + itemTokens <= MAX_BATCH_TOKENS) {
+					currentBatch.push(text)
+					currentBatchTokens += itemTokens
+					processedIndices.push(i)
+				} else {
+					break
+				}
+			}
+
+			// Remove processed items from remainingTexts (in reverse order to maintain correct indices)
+			for (let i = processedIndices.length - 1; i >= 0; i--) {
+				remainingTexts.splice(processedIndices[i], 1)
+			}
+
+			if (currentBatch.length > 0) {
+				const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse)
+				allEmbeddings.push(...batchResult.embeddings)
+				usage.promptTokens += batchResult.usage.promptTokens
+				usage.totalTokens += batchResult.usage.totalTokens
+			}
+		}
+
+		return { embeddings: allEmbeddings, usage }
+	}
+
+	/**
+	 * Helper method to handle batch embedding with retries and exponential backoff
+	 * @param batchTexts Array of texts to embed in this batch
+	 * @param model Model identifier to use
+	 * @returns Promise resolving to embeddings and usage statistics
+	 */
+	private async _embedBatchWithRetries(
+		batchTexts: string[],
+		model: string,
+	): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
+		for (let attempts = 0; attempts < MAX_RETRIES; attempts++) {
+			try {
+				const embeddings: number[][] = []
+				let totalPromptTokens = 0
+				let totalTokens = 0
+
+				// Process each text in the batch
+				// Note: Amazon Titan models typically don't support batch embedding in a single request
+				// So we process them individually
+				for (const text of batchTexts) {
+					const embedding = await this._invokeEmbeddingModel(text, model)
+					embeddings.push(embedding.embedding)
+					totalPromptTokens += embedding.inputTextTokenCount || 0
+					totalTokens += embedding.inputTextTokenCount || 0
+				}
+
+				return {
+					embeddings,
+					usage: {
+						promptTokens: totalPromptTokens,
+						totalTokens,
+					},
+				}
+			} catch (error: any) {
+				const hasMoreAttempts = attempts < MAX_RETRIES - 1
+
+				// Check if it's a rate limit error
+				if (error.name === "ThrottlingException" && hasMoreAttempts) {
+					const delayMs = INITIAL_DELAY_MS * Math.pow(2, attempts)
+					console.warn(
+						t("embeddings:rateLimitRetry", {
+							delayMs,
+							attempt: attempts + 1,
+							maxRetries: MAX_RETRIES,
+						}),
+					)
+					await new Promise((resolve) => setTimeout(resolve, delayMs))
+					continue
+				}
+
+				// Capture telemetry before reformatting the error
+				TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
+					error: error instanceof Error ? error.message : String(error),
+					stack: error instanceof Error ? error.stack : undefined,
+					location: "BedrockEmbedder:_embedBatchWithRetries",
+					attempt: attempts + 1,
+				})
+
+				// Log the error for debugging
+				console.error(`Bedrock embedder error (attempt ${attempts + 1}/${MAX_RETRIES}):`, error)
+
+				// Format and throw the error
+				throw formatEmbeddingError(error, MAX_RETRIES)
+			}
+		}
+
+		throw new Error(t("embeddings:failedMaxAttempts", { attempts: MAX_RETRIES }))
+	}
+
+	/**
+	 * Invokes the embedding model for a single text
+	 * @param text The text to embed
+	 * @param model The model identifier to use
+	 * @returns Promise resolving to embedding and token count
+	 */
+	private async _invokeEmbeddingModel(
+		text: string,
+		model: string,
+	): Promise<{ embedding: number[]; inputTextTokenCount?: number }> {
+		let requestBody: any
+		let modelId = model
+
+		// Prepare the request body based on the model
+		if (model.startsWith("amazon.titan-embed")) {
+			requestBody = {
+				inputText: text,
+			}
+		} else if (model.startsWith("cohere.embed")) {
+			requestBody = {
+				texts: [text],
+				input_type: "search_document", // or "search_query" depending on use case
+			}
+		} else {
+			// Default to Titan format
+			requestBody = {
+				inputText: text,
+			}
+		}
+
+		const params: InvokeModelCommandInput = {
+			modelId,
+			body: JSON.stringify(requestBody),
+			contentType: "application/json",
+			accept: "application/json",
+		}
+
+		const command = new InvokeModelCommand(params)
+		const response = await this.bedrockClient.send(command)
+
+		// Parse the response
+		const responseBody = JSON.parse(new TextDecoder().decode(response.body))
+
+		// Extract embedding based on model type
+		if (model.startsWith("amazon.titan-embed")) {
+			return {
+				embedding: responseBody.embedding,
+				inputTextTokenCount: responseBody.inputTextTokenCount,
+			}
+		} else if (model.startsWith("cohere.embed")) {
+			return {
+				embedding: responseBody.embeddings[0],
+				// Cohere doesn't provide token count in response
+			}
+		} else {
+			// Default to Titan format
+			return {
+				embedding: responseBody.embedding,
+				inputTextTokenCount: responseBody.inputTextTokenCount,
+			}
+		}
+	}
+
+	/**
+	 * Validates the Bedrock embedder configuration by attempting a minimal embedding request
+	 * @returns Promise resolving to validation result with success status and optional error message
+	 */
+	async validateConfiguration(): Promise<{ valid: boolean; error?: string }> {
+		return withValidationErrorHandling(async () => {
+			try {
+				// Test with a minimal embedding request
+				const result = await this._invokeEmbeddingModel("test", this.defaultModelId)
+
+				// Check if we got a valid response
+				if (!result.embedding || result.embedding.length === 0) {
+					return {
+						valid: false,
+						error: t("embeddings:bedrock.invalidResponseFormat"),
+					}
+				}
+
+				return { valid: true }
+			} catch (error: any) {
+				// Check for specific AWS errors
+				if (error.name === "UnrecognizedClientException") {
+					return {
+						valid: false,
+						error: t("embeddings:bedrock.invalidCredentials"),
+					}
+				}
+
+				if (error.name === "AccessDeniedException") {
+					return {
+						valid: false,
+						error: t("embeddings:bedrock.accessDenied"),
+					}
+				}
+
+				if (error.name === "ResourceNotFoundException") {
+					return {
+						valid: false,
+						error: t("embeddings:bedrock.modelNotFound", { model: this.defaultModelId }),
+					}
+				}
+
+				// Capture telemetry for validation errors
+				TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
+					error: error instanceof Error ? error.message : String(error),
+					stack: error instanceof Error ? error.stack : undefined,
+					location: "BedrockEmbedder:validateConfiguration",
+				})
+				throw error
+			}
+		}, "bedrock")
+	}
+
+	get embedderInfo(): EmbedderInfo {
+		return {
+			name: "bedrock",
+		}
+	}
+}
diff --git a/src/services/code-index/interfaces/config.ts b/src/services/code-index/interfaces/config.ts
index f168e268691a..df29b9af2417 100644
--- a/src/services/code-index/interfaces/config.ts
+++ b/src/services/code-index/interfaces/config.ts
@@ -15,6 +15,7 @@ export interface CodeIndexConfig {
 	geminiOptions?: { apiKey: string }
 	mistralOptions?: { apiKey: string }
 	vercelAiGatewayOptions?: { apiKey: string }
+	bedrockOptions?: { region: string; profile?: string }
 	qdrantUrl?: string
 	qdrantApiKey?: string
 	searchMinScore?: number
@@ -37,6 +38,8 @@ export type PreviousConfigSnapshot = {
 	geminiApiKey?: string
 	mistralApiKey?: string
 	vercelAiGatewayApiKey?: string
+	bedrockRegion?: string
+	bedrockProfile?: string
 	qdrantUrl?: string
 	qdrantApiKey?: string
 }
diff --git a/src/services/code-index/interfaces/embedder.ts b/src/services/code-index/interfaces/embedder.ts
index 1fcda3aca32d..3f0a3d5d8a09 100644
--- a/src/services/code-index/interfaces/embedder.ts
+++ b/src/services/code-index/interfaces/embedder.ts
@@ -28,7 +28,14 @@ export interface EmbeddingResponse {
 	}
 }
 
-export type AvailableEmbedders = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vercel-ai-gateway"
+export type AvailableEmbedders =
+	| "openai"
+	| "ollama"
+	| "openai-compatible"
+	| "gemini"
+	| "mistral"
+	| "vercel-ai-gateway"
+	| "bedrock"
 
 export interface EmbedderInfo {
 	name: AvailableEmbedders
diff --git a/src/services/code-index/interfaces/manager.ts b/src/services/code-index/interfaces/manager.ts
index 527900f6d1c7..9d2493d78880 100644
--- a/src/services/code-index/interfaces/manager.ts
+++ b/src/services/code-index/interfaces/manager.ts
@@ -70,7 +70,14 @@ export interface ICodeIndexManager {
 }
 
 export type IndexingState = "Standby" | "Indexing" | "Indexed" | "Error"
-export type EmbedderProvider = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vercel-ai-gateway"
+export type EmbedderProvider =
+	| "openai"
+	| "ollama"
+	| "openai-compatible"
+	| "gemini"
+	| "mistral"
+	| "vercel-ai-gateway"
+	| "bedrock"
 
 export interface IndexProgressUpdate {
 	systemStatus: IndexingState
diff --git a/src/services/code-index/service-factory.ts b/src/services/code-index/service-factory.ts
index 6d69e1f0b6c6..54e54520f24a 100644
--- a/src/services/code-index/service-factory.ts
+++ b/src/services/code-index/service-factory.ts
@@ -5,6 +5,7 @@ import { OpenAICompatibleEmbedder } from "./embedders/openai-compatible"
 import { GeminiEmbedder } from "./embedders/gemini"
 import { MistralEmbedder } from "./embedders/mistral"
 import { VercelAiGatewayEmbedder } from "./embedders/vercel-ai-gateway"
+import { BedrockEmbedder } from "./embedders/bedrock"
 import { EmbedderProvider, getDefaultModelId, getModelDimension } from "../../shared/embeddingModels"
 import { QdrantVectorStore } from "./vector-store/qdrant-client"
 import { codeParser, DirectoryScanner, FileWatcher } from "./processors"
@@ -79,6 +80,11 @@ export class CodeIndexServiceFactory {
 				throw new Error(t("embeddings:serviceFactory.vercelAiGatewayConfigMissing"))
 			}
 			return new VercelAiGatewayEmbedder(config.vercelAiGatewayOptions.apiKey, config.modelId)
+		} else if (provider === "bedrock") {
"bedrock") { + if (!config.bedrockOptions?.region) { + throw new Error(t("embeddings:serviceFactory.bedrockConfigMissing")) + } + return new BedrockEmbedder(config.bedrockOptions.region, config.modelId, config.bedrockOptions.profile) } throw new Error( diff --git a/src/shared/embeddingModels.ts b/src/shared/embeddingModels.ts index 80c51a6b4558..063d70120f4b 100644 --- a/src/shared/embeddingModels.ts +++ b/src/shared/embeddingModels.ts @@ -2,7 +2,14 @@ * Defines profiles for different embedding models, including their dimensions. */ -export type EmbedderProvider = "openai" | "ollama" | "openai-compatible" | "gemini" | "mistral" | "vercel-ai-gateway" // Add other providers as needed +export type EmbedderProvider = + | "openai" + | "ollama" + | "openai-compatible" + | "gemini" + | "mistral" + | "vercel-ai-gateway" + | "bedrock" // Add other providers as needed export interface EmbeddingModelProfile { dimension: number @@ -70,6 +77,15 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = { "mistral/codestral-embed": { dimension: 1536, scoreThreshold: 0.4 }, "mistral/mistral-embed": { dimension: 1024, scoreThreshold: 0.4 }, }, + bedrock: { + // Amazon Titan Embed models + "amazon.titan-embed-text-v1": { dimension: 1536, scoreThreshold: 0.4 }, + "amazon.titan-embed-text-v2:0": { dimension: 1024, scoreThreshold: 0.4 }, + "amazon.titan-embed-image-v1": { dimension: 1024, scoreThreshold: 0.4 }, + // Cohere models available through Bedrock + "cohere.embed-english-v3": { dimension: 1024, scoreThreshold: 0.4 }, + "cohere.embed-multilingual-v3": { dimension: 1024, scoreThreshold: 0.4 }, + }, } /** @@ -163,6 +179,9 @@ export function getDefaultModelId(provider: EmbedderProvider): string { case "vercel-ai-gateway": return "openai/text-embedding-3-large" + case "bedrock": + return "amazon.titan-embed-text-v2:0" + default: // Fallback for unknown providers console.warn(`Unknown provider for default model ID: ${provider}. Falling back to OpenAI default.`)