diff --git a/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts b/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts index ff757b86c7..6eaa147a39 100644 --- a/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts +++ b/src/services/code-index/embedders/__tests__/openai-compatible.spec.ts @@ -367,7 +367,7 @@ describe("OpenAICompatibleEmbedder", () => { vitest.useRealTimers() }) - it("should retry on rate limit errors with exponential backoff", async () => { + it("should retry on rate limit errors with exponential backoff and jitter", async () => { const testTexts = ["Hello world"] const rateLimitError = { status: 429, message: "Rate limit exceeded" } @@ -385,9 +385,9 @@ describe("OpenAICompatibleEmbedder", () => { const resultPromise = embedder.createEmbeddings(testTexts) - // Fast-forward through the delays - await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS) // First retry delay - await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 2) // Second retry delay + // Fast-forward through the delays (with max jitter) + await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 1.2) // First retry delay with max jitter + await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 2 * 1.2) // Second retry delay with max jitter const result = await resultPromise @@ -399,6 +399,45 @@ describe("OpenAICompatibleEmbedder", () => { }) }) + it("should retry on other transient errors (500, 502, 503, 504)", async () => { + const testTexts = ["Hello world"] + const transientErrors = [ + { status: 500, message: "Internal Server Error" }, + { status: 502, message: "Bad Gateway" }, + { status: 503, message: "Service Unavailable" }, + { status: 504, message: "Gateway Timeout" }, + ] + + for (const error of transientErrors) { + vitest.clearAllMocks() + + // Create base64 encoded embedding for successful response + const testEmbedding = new Float32Array([0.25, 0.5, 0.75]) + const base64String = Buffer.from(testEmbedding.buffer).toString("base64") + + mockEmbeddingsCreate.mockRejectedValueOnce(error).mockResolvedValueOnce({ + data: [{ embedding: base64String }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) + + const resultPromise = embedder.createEmbeddings(testTexts) + + // Fast-forward through the delay with max jitter + await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 1.2) + + const result = await resultPromise + + expect(mockEmbeddingsCreate).toHaveBeenCalledTimes(2) + expect(console.warn).toHaveBeenCalledWith( + expect.stringContaining(`Error ${error.status} hit, retrying in`), + ) + expect(result).toEqual({ + embeddings: [[0.25, 0.5, 0.75]], + usage: { promptTokens: 10, totalTokens: 15 }, + }) + } + }) + it("should not retry on non-rate-limit errors", async () => { const testTexts = ["Hello world"] const authError = new Error("Unauthorized") @@ -416,16 +455,26 @@ describe("OpenAICompatibleEmbedder", () => { it("should throw error immediately on non-retryable errors", async () => { const testTexts = ["Hello world"] - const serverError = new Error("Internal server error") - ;(serverError as any).status = 500 + const nonRetryableErrors = [ + { status: 400, message: "Bad Request" }, + { status: 403, message: "Forbidden" }, + { status: 404, message: "Not Found" }, + ] - mockEmbeddingsCreate.mockRejectedValue(serverError) + for (const error of nonRetryableErrors) { + vitest.clearAllMocks() + const testError = new Error(error.message) + ;(testError as any).status = error.status - await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow( - "Failed to create embeddings after 3 attempts: HTTP 500 - Internal server error", - ) + mockEmbeddingsCreate.mockRejectedValue(testError) - expect(mockEmbeddingsCreate).toHaveBeenCalledTimes(1) + await expect(embedder.createEmbeddings(testTexts)).rejects.toThrow( + `Failed to create embeddings after 3 attempts: HTTP ${error.status} - ${error.message}`, + ) + + expect(mockEmbeddingsCreate).toHaveBeenCalledTimes(1) + expect(console.warn).not.toHaveBeenCalledWith(expect.stringContaining("hit, retrying in")) + } }) }) @@ -775,7 +824,7 @@ describe("OpenAICompatibleEmbedder", () => { await expect(embedder.createEmbeddings(["test"])).rejects.toThrow(expectedMessage) }) - it("should handle rate limiting with retries", async () => { + it("should handle rate limiting with retries and jitter", async () => { vitest.useFakeTimers() const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) const base64String = createBase64Embedding([0.1, 0.2, 0.3]) @@ -791,7 +840,10 @@ describe("OpenAICompatibleEmbedder", () => { ) const resultPromise = embedder.createEmbeddings(["test"]) - await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 3) + // Account for max jitter (20%) + await vitest.advanceTimersByTimeAsync( + INITIAL_RETRY_DELAY_MS * 1.2 + INITIAL_RETRY_DELAY_MS * 2 * 1.2, + ) const result = await resultPromise expect(global.fetch).toHaveBeenCalledTimes(3) @@ -800,6 +852,30 @@ describe("OpenAICompatibleEmbedder", () => { vitest.useRealTimers() }) + it("should handle other transient errors with retries", async () => { + vitest.useFakeTimers() + const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) + const base64String = createBase64Embedding([0.1, 0.2, 0.3]) + + ;(global.fetch as MockedFunction) + .mockResolvedValueOnce(createMockResponse({}, 503, false) as any) + .mockResolvedValueOnce( + createMockResponse({ + data: [{ embedding: base64String }], + usage: { prompt_tokens: 10, total_tokens: 15 }, + }) as any, + ) + + const resultPromise = embedder.createEmbeddings(["test"]) + await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 1.2) + const result = await resultPromise + + expect(global.fetch).toHaveBeenCalledTimes(2) + expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Error 503 hit")) + expectEmbeddingValues(result.embeddings[0], [0.1, 0.2, 0.3]) + vitest.useRealTimers() + }) + it("should handle multiple embeddings and network errors", async () => { const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId) diff --git a/src/services/code-index/embedders/openai-compatible.ts b/src/services/code-index/embedders/openai-compatible.ts index d882e78313..f1e594fcaf 100644 --- a/src/services/code-index/embedders/openai-compatible.ts +++ b/src/services/code-index/embedders/openai-compatible.ts @@ -296,17 +296,20 @@ export class OpenAICompatibleEmbedder implements IEmbedder { const hasMoreAttempts = attempts < MAX_RETRIES - 1 - // Check if it's a rate limit error + // Check if it's a retryable error const httpError = error as HttpError - if (httpError?.status === 429 && hasMoreAttempts) { - const delayMs = INITIAL_DELAY_MS * Math.pow(2, attempts) - console.warn( - t("embeddings:rateLimitRetry", { - delayMs, - attempt: attempts + 1, - maxRetries: MAX_RETRIES, - }), - ) + const isRetryableError = this.isRetryableError(httpError) + + if (isRetryableError && hasMoreAttempts) { + // Calculate exponential backoff with jitter + const baseDelay = INITIAL_DELAY_MS * Math.pow(2, attempts) + // Add jitter: random value between 0% and 20% of base delay + const jitter = Math.random() * 0.2 * baseDelay + const delayMs = Math.floor(baseDelay + jitter) + + const errorType = + httpError?.status === 429 ? "Rate limit" : `Error ${httpError?.status || "unknown"}` + console.warn(`${errorType} hit, retrying in ${delayMs}ms (attempt ${attempts + 1}/${MAX_RETRIES})`) await new Promise((resolve) => setTimeout(resolve, delayMs)) continue } @@ -368,6 +371,27 @@ export class OpenAICompatibleEmbedder implements IEmbedder { }, "openai-compatible") } + /** + * Determines if an error is retryable based on HTTP status code + * @param error The error to check + * @returns true if the error is retryable, false otherwise + */ + private isRetryableError(error: HttpError | any): boolean { + if (!error || typeof error.status !== "number") { + return false + } + + const retryableStatuses = [ + 429, // Too Many Requests (rate limit) + 500, // Internal Server Error + 502, // Bad Gateway + 503, // Service Unavailable + 504, // Gateway Timeout + ] + + return retryableStatuses.includes(error.status) + } + /** * Returns information about this embedder */