diff --git a/convex/lib/embeddings.test.ts b/convex/lib/embeddings.test.ts
new file mode 100644
index 0000000000..c36cc1d463
--- /dev/null
+++ b/convex/lib/embeddings.test.ts
@@ -0,0 +1,121 @@
+/* @vitest-environment node */
+
+import { afterAll, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import { EMBEDDING_DIMENSIONS, generateEmbedding } from './embeddings'
+
+const fetchMock = vi.fn()
+const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
+
+const originalFetch = globalThis.fetch
+const originalApiKey = process.env.OPENAI_API_KEY
+
+function jsonResponse(payload: unknown, init?: ResponseInit) {
+  return new Response(JSON.stringify(payload), {
+    status: 200,
+    headers: {
+      'content-type': 'application/json',
+    },
+    ...init,
+  })
+}
+
+beforeEach(() => {
+  fetchMock.mockReset()
+  globalThis.fetch = fetchMock as typeof fetch
+  process.env.OPENAI_API_KEY = 'test-key'
+  consoleWarnSpy.mockClear()
+})
+
+afterEach(() => {
+  globalThis.fetch = originalFetch
+
+  if (originalApiKey === undefined) {
+    delete process.env.OPENAI_API_KEY
+  } else {
+    process.env.OPENAI_API_KEY = originalApiKey
+  }
+
+  vi.useRealTimers()
+})
+
+afterAll(() => {
+  consoleWarnSpy.mockRestore()
+})
+
+describe('generateEmbedding', () => {
+  it('returns zero embedding when OPENAI_API_KEY is missing', async () => {
+    delete process.env.OPENAI_API_KEY
+    const result = await generateEmbedding('hello world')
+
+    expect(result).toHaveLength(EMBEDDING_DIMENSIONS)
+    expect(result.every((value) => value === 0)).toBe(true)
+    expect(fetchMock).not.toHaveBeenCalled()
+  })
+
+  it('retries on 429 responses and then succeeds', async () => {
+    vi.useFakeTimers()
+    fetchMock.mockResolvedValueOnce(new Response('rate limited', { status: 429 }))
+    fetchMock.mockResolvedValueOnce(jsonResponse({ data: [{ embedding: [0.25, 0.75] }] }))
+
+    const promise = generateEmbedding('retry me')
+    await vi.runAllTimersAsync()
+
+    await expect(promise).resolves.toEqual([0.25, 0.75])
+    expect(fetchMock).toHaveBeenCalledTimes(2)
+  })
+
+  it('does not retry non-retryable 4xx responses', async () => {
+    fetchMock.mockResolvedValueOnce(new Response('bad request', { status: 400 }))
+
+    await expect(generateEmbedding('bad')).rejects.toThrow('Embedding failed: bad request')
+    expect(fetchMock).toHaveBeenCalledTimes(1)
+  })
+
+  it('retries on network failures and then succeeds', async () => {
+    vi.useFakeTimers()
+    fetchMock.mockRejectedValueOnce(new TypeError('fetch failed'))
+    fetchMock.mockResolvedValueOnce(jsonResponse({ data: [{ embedding: [1, 2, 3] }] }))
+
+    const promise = generateEmbedding('network retry')
+    await vi.runAllTimersAsync()
+
+    await expect(promise).resolves.toEqual([1, 2, 3])
+    expect(fetchMock).toHaveBeenCalledTimes(2)
+  })
+
+  it('retries timeouts up to max attempts and preserves timeout error', async () => {
+    vi.useFakeTimers()
+    fetchMock.mockRejectedValue(new DOMException('aborted', 'AbortError'))
+
+    const promise = generateEmbedding('always timeout')
+    const rejection = expect(promise).rejects.toThrow(
+      'OpenAI API request timed out after 10 seconds',
+    )
+    await vi.runAllTimersAsync()
+
+    await rejection
+    expect(fetchMock).toHaveBeenCalledTimes(3)
+  })
+
+  it('gives up after max attempts on persistent 5xx responses', async () => {
+    vi.useFakeTimers()
+    // Fresh Response per call: a Response body can only be read once.
+    fetchMock.mockImplementation(async () => new Response('unavailable', { status: 503 }))
+
+    const promise = generateEmbedding('always unavailable')
+    const rejection = expect(promise).rejects.toThrow('Embedding failed (503): unavailable')
+    await vi.runAllTimersAsync()
+
+    await rejection
+    expect(fetchMock).toHaveBeenCalledTimes(3)
+  })
+
+  it('does not retry when the response body is JSON null', async () => {
+    fetchMock.mockResolvedValueOnce(jsonResponse(null))
+
+    await expect(generateEmbedding('null payload')).rejects.toThrow(
+      'Embedding missing from response',
+    )
+    expect(fetchMock).toHaveBeenCalledTimes(1)
+  })
+})
diff --git a/convex/lib/embeddings.ts b/convex/lib/embeddings.ts
index 1a5ee84270..082aca6db2 100644
--- a/convex/lib/embeddings.ts
+++ b/convex/lib/embeddings.ts
@@ -1,10 +1,154 @@
 export const EMBEDDING_MODEL = 'text-embedding-3-small'
 export const EMBEDDING_DIMENSIONS = 1536
 
+const EMBEDDING_ENDPOINT = 'https://api.openai.com/v1/embeddings'
+const REQUEST_TIMEOUT_MS = 10_000
+const MAX_ATTEMPTS = 3
+const BASE_RETRY_DELAY_MS = 1_000
+// Upper bound for a single back-off so a huge Retry-After cannot stall the caller.
+const MAX_RETRY_DELAY_MS = 30_000
+
+/** Failure worth retrying: request timeout, network error, HTTP 429 or HTTP 5xx. */
+class RetryableEmbeddingError extends Error {
+  /** Server-requested wait from `Retry-After`, when one was sent. */
+  readonly retryAfterMs: number | null
+
+  constructor(message: string, options?: { cause?: unknown; retryAfterMs?: number | null }) {
+    super(message, options)
+    this.name = 'RetryableEmbeddingError'
+    this.retryAfterMs = options?.retryAfterMs ?? null
+  }
+}
+
 function emptyEmbedding() {
   return Array.from({ length: EMBEDDING_DIMENSIONS }, () => 0)
 }
 
+/**
+ * Converts a `Retry-After` header value to milliseconds.
+ *
+ * Handles both header forms: a non-negative number of seconds, or an HTTP date
+ * (converted to the time remaining from now, never negative).
+ *
+ * @returns The delay in milliseconds, or `null` when the header is missing or unparseable.
+ */
+function parseRetryAfterMs(retryAfterHeader: string | null) {
+  if (!retryAfterHeader) return null
+
+  const seconds = Number(retryAfterHeader)
+  if (Number.isFinite(seconds) && seconds >= 0) {
+    return Math.round(seconds * 1000)
+  }
+
+  const dateMs = Date.parse(retryAfterHeader)
+  if (Number.isFinite(dateMs)) {
+    return Math.max(0, dateMs - Date.now())
+  }
+
+  return null
+}
+
+/**
+ * Picks the wait before the next attempt.
+ *
+ * Exponential back-off (`BASE_RETRY_DELAY_MS * 2 ** attempt`) is the floor; a longer
+ * server-provided delay wins, and the result is clamped to `MAX_RETRY_DELAY_MS`.
+ *
+ * @param attempt Zero-based index of the attempt that just failed.
+ */
+function getRetryDelayMs(attempt: number, retryAfterMs: number | null) {
+  const exponentialDelayMs = BASE_RETRY_DELAY_MS * 2 ** attempt
+  const requestedDelayMs = Math.max(exponentialDelayMs, retryAfterMs ?? 0)
+  return Math.min(requestedDelayMs, MAX_RETRY_DELAY_MS)
+}
+
+/**
+ * Maps errors thrown while talking to the API onto `RetryableEmbeddingError`.
+ *
+ * `AbortError` means our own timeout fired; `TypeError` is how `fetch` reports
+ * network-level failures. Anything else returns `null` and must not be retried.
+ */
+function normalizeRetryableNetworkError(error: unknown) {
+  if (!(error instanceof Error)) return null
+
+  if (error.name === 'AbortError') {
+    return new RetryableEmbeddingError(
+      `OpenAI API request timed out after ${Math.floor(REQUEST_TIMEOUT_MS / 1000)} seconds`,
+      { cause: error },
+    )
+  }
+
+  if (error instanceof TypeError) {
+    return new RetryableEmbeddingError(`Embedding request failed: ${error.message}`, { cause: error })
+  }
+
+  return null
+}
+
+/** Resolves after `ms` milliseconds. */
+function sleep(ms: number) {
+  return new Promise<void>((resolve) => {
+    setTimeout(resolve, ms)
+  })
+}
+
+/**
+ * Performs a single embeddings request, bounded by `REQUEST_TIMEOUT_MS`.
+ *
+ * @throws {RetryableEmbeddingError} On timeout, network failure, HTTP 429 or HTTP 5xx.
+ * @throws {Error} On any other HTTP error or when the response carries no embedding.
+ */
+async function requestEmbedding(text: string, apiKey: string) {
+  const controller = new AbortController()
+  const timeoutId = setTimeout(() => controller.abort(), REQUEST_TIMEOUT_MS)
+
+  try {
+    const response = await fetch(EMBEDDING_ENDPOINT, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+        Authorization: `Bearer ${apiKey}`,
+      },
+      body: JSON.stringify({
+        model: EMBEDDING_MODEL,
+        input: text,
+      }),
+      signal: controller.signal,
+    })
+
+    if (!response.ok) {
+      const message = await response.text()
+      const isRetryableStatus = response.status === 429 || response.status >= 500
+      if (!isRetryableStatus) throw new Error(`Embedding failed: ${message}`)
+
+      throw new RetryableEmbeddingError(`Embedding failed (${response.status}): ${message}`, {
+        retryAfterMs: parseRetryAfterMs(response.headers.get('retry-after')),
+      })
+    }
+
+    // `null` is valid JSON, so guard the payload itself before reading `data`.
+    const payload = (await response.json()) as { data?: Array<{ embedding: number[] }> } | null
+    const embedding = payload?.data?.[0]?.embedding
+    if (!embedding) throw new Error('Embedding missing from response')
+    return embedding
+  } catch (error) {
+    // Errors that are already classified pass through untouched.
+    throw normalizeRetryableNetworkError(error) ?? error
+  } finally {
+    // Runs before any back-off, so no abort timer outlives its request.
+    clearTimeout(timeoutId)
+  }
+}
+
+/**
+ * Generates an embedding for `text` with OpenAI's `EMBEDDING_MODEL`.
+ *
+ * Retryable failures are attempted up to `MAX_ATTEMPTS` times with back-off;
+ * without `OPENAI_API_KEY` an all-zero vector is returned instead.
+ *
+ * @throws {RetryableEmbeddingError} When every attempt failed with a retryable error.
+ * @throws {Error} On non-retryable API errors.
+ */
 export async function generateEmbedding(text: string) {
   const apiKey = process.env.OPENAI_API_KEY
   if (!apiKey) {
@@ -12,40 +156,18 @@ export async function generateEmbedding(text: string) {
     return emptyEmbedding()
   }
 
-  const controller = new AbortController()
-  const timeoutId = setTimeout(() => controller.abort(), 10000) // 10 second timeout
-
-  try {
-    const response = await fetch('https://api.openai.com/v1/embeddings', {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-        Authorization: `Bearer ${apiKey}`,
-      },
-      body: JSON.stringify({
-        model: EMBEDDING_MODEL,
-        input: text,
-      }),
-      signal: controller.signal,
-    })
-
-    if (!response.ok) {
-      const message = await response.text()
-      throw new Error(`Embedding failed: ${message}`)
-    }
-
-    const payload = (await response.json()) as {
-      data?: Array<{ embedding: number[] }>
-    }
-    const embedding = payload.data?.[0]?.embedding
-    if (!embedding) throw new Error('Embedding missing from response')
-    return embedding
-  } catch (error) {
-    if (error instanceof Error && error.name === 'AbortError') {
-      throw new Error('OpenAI API request timed out after 10 seconds', { cause: error })
-    }
-    throw error
-  } finally {
-    clearTimeout(timeoutId)
-  }
+  for (let attempt = 0; ; attempt++) {
+    try {
+      return await requestEmbedding(text, apiKey)
+    } catch (error) {
+      const isLastAttempt = attempt >= MAX_ATTEMPTS - 1
+      if (isLastAttempt || !(error instanceof RetryableEmbeddingError)) throw error
+
+      const delayMs = getRetryDelayMs(attempt, error.retryAfterMs)
+      console.warn(
+        `OpenAI embeddings retry in ${delayMs}ms (attempt ${attempt + 1}/${MAX_ATTEMPTS}): ${error.message}`,
+      )
+      await sleep(delayMs)
+    }
+  }
 }