Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import { describe, it, expect, vi, beforeEach, afterEach, MockedClass, MockedFunction } from "vitest"
import { OpenAI } from "openai"
import { OpenAICompatibleEmbedder } from "../openai-compatible"

// Mock the OpenAI SDK so no real network client is ever constructed.
vi.mock("openai")

// Mock TelemetryService so captureEvent calls become no-ops the tests can ignore.
vi.mock("@roo-code/telemetry", () => ({
	TelemetryService: {
		instance: {
			captureEvent: vi.fn(),
		},
	},
}))

// Mock i18n: resolve the retry/failure message keys used by the embedder into
// deterministic English strings so message assertions (e.g. "Rate limit hit")
// are stable. NOTE: vi.mock factories are hoisted by vitest above the imports.
vi.mock("../../../../i18n", () => ({
	t: (key: string, params?: Record<string, any>) => {
		const translations: Record<string, string> = {
			"embeddings:rateLimitRetry": `Rate limit hit, retrying in ${params?.delayMs}ms (attempt ${params?.attempt}/${params?.maxRetries})`,
			"embeddings:failedMaxAttempts": `Failed to create embeddings after ${params?.attempts} attempts`,
			"embeddings:failedWithStatus": `Failed to create embeddings after ${params?.attempts} attempts: HTTP ${params?.statusCode} - ${params?.errorMessage}`,
			"embeddings:failedWithError": `Failed to create embeddings after ${params?.attempts} attempts: ${params?.errorMessage}`,
		}
		// Fall back to the raw key for any message id not listed above.
		return translations[key] || key
	},
}))

// Typed handle to the mocked OpenAI constructor.
const MockedOpenAI = OpenAI as MockedClass<typeof OpenAI>

describe("OpenAICompatibleEmbedder - Global Rate Limiting", () => {
let mockOpenAIInstance: any
let mockEmbeddingsCreate: MockedFunction<any>

const testBaseUrl = "https://api.openai.com/v1"
const testApiKey = "test-api-key"
const testModelId = "text-embedding-3-small"

beforeEach(() => {
	vi.clearAllMocks()
	// Fake timers let each test step through retry/rate-limit delays instantly.
	vi.useFakeTimers()
	// Silence the expected retry warnings/errors so test output stays clean.
	vi.spyOn(console, "warn").mockImplementation(() => {})
	vi.spyOn(console, "error").mockImplementation(() => {})

	// Every `new OpenAI()` returns this stub, giving the tests full control
	// over embeddings.create responses.
	mockEmbeddingsCreate = vi.fn()
	mockOpenAIInstance = {
		embeddings: {
			create: mockEmbeddingsCreate,
		},
	}
	MockedOpenAI.mockImplementation(() => mockOpenAIInstance)

	// Reset the embedder's *static* rate-limit state so tests don't leak into
	// each other. Access the static directly on the class rather than
	// constructing a throwaway instance just to reach it via `.constructor`.
	// The existing mutex instance is preserved — presumably the implementation
	// holds a reference to it (TODO confirm against the embedder source).
	const embedderClass = OpenAICompatibleEmbedder as any
	embedderClass.globalRateLimitState = {
		isRateLimited: false,
		rateLimitResetTime: 0,
		consecutiveRateLimitErrors: 0,
		lastRateLimitError: 0,
		mutex: embedderClass.globalRateLimitState.mutex,
	}
})

afterEach(() => {
	// Restore real timers and all spied/mocked implementations (console spies
	// included) so later suites observe unmodified globals.
	vi.useRealTimers()
	vi.restoreAllMocks()
})

it("should apply global rate limiting across multiple batch requests", async () => {
const embedder1 = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)
const embedder2 = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)

// First batch hits rate limit
const rateLimitError = new Error("Rate limit exceeded") as any
rateLimitError.status = 429

mockEmbeddingsCreate
.mockRejectedValueOnce(rateLimitError) // First attempt fails
.mockResolvedValue({
data: [{ embedding: "base64encodeddata" }],
usage: { prompt_tokens: 10, total_tokens: 15 },
})

// Start first batch request
const batch1Promise = embedder1.createEmbeddings(["test1"])

// Advance time slightly to let the first request fail and set global rate limit
await vi.advanceTimersByTimeAsync(100)

// Start second batch request while global rate limit is active
const batch2Promise = embedder2.createEmbeddings(["test2"])

// Check that global rate limit was set
const state = (embedder1 as any).constructor.globalRateLimitState
expect(state.isRateLimited).toBe(true)
expect(state.consecutiveRateLimitErrors).toBe(1)

// Advance time to complete rate limit delay (5 seconds base delay)
await vi.advanceTimersByTimeAsync(5000)

// Both requests should complete
const [result1, result2] = await Promise.all([batch1Promise, batch2Promise])

expect(result1.embeddings).toHaveLength(1)
expect(result2.embeddings).toHaveLength(1)

// The second embedder should have waited for the global rate limit
// No logging expected - we've removed it to prevent log flooding
})

it("should track consecutive rate limit errors", async () => {
const embedder = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)
const state = (embedder as any).constructor.globalRateLimitState

const rateLimitError = new Error("Rate limit exceeded") as any
rateLimitError.status = 429

// Test that consecutive errors increment when they happen quickly
// Mock multiple rate limit errors in a single request
mockEmbeddingsCreate
.mockRejectedValueOnce(rateLimitError) // First attempt
.mockRejectedValueOnce(rateLimitError) // Retry 1
.mockResolvedValueOnce({
data: [{ embedding: "base64encodeddata" }],
usage: { prompt_tokens: 10, total_tokens: 15 },
})

const promise1 = embedder.createEmbeddings(["test1"])

// Wait for first attempt to fail
await vi.advanceTimersByTimeAsync(100)
expect(state.consecutiveRateLimitErrors).toBe(1)

// Wait for first retry (500ms) to also fail
await vi.advanceTimersByTimeAsync(500)

// The state should show 2 consecutive errors now
// Note: The count might be 1 if the global rate limit kicked in before the second attempt
expect(state.consecutiveRateLimitErrors).toBeGreaterThanOrEqual(1)

// Wait for the global rate limit and successful retry
await vi.advanceTimersByTimeAsync(20000)
await promise1

// Verify the delay increases with consecutive errors
// Make another request immediately that also hits rate limit
mockEmbeddingsCreate.mockRejectedValueOnce(rateLimitError).mockResolvedValueOnce({
data: [{ embedding: "base64encodeddata" }],
usage: { prompt_tokens: 10, total_tokens: 15 },
})

// Store the current consecutive count before the next request
const previousCount = state.consecutiveRateLimitErrors

const promise2 = embedder.createEmbeddings(["test2"])
await vi.advanceTimersByTimeAsync(100)

// Should have incremented from the previous count
expect(state.consecutiveRateLimitErrors).toBeGreaterThan(previousCount)

// Complete the second request
await vi.advanceTimersByTimeAsync(20000)
await promise2
})

it("should reset consecutive error count after time passes", async () => {
const embedder = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)
const state = (embedder as any).constructor.globalRateLimitState

// Manually set state to simulate previous errors
state.consecutiveRateLimitErrors = 3
state.lastRateLimitError = Date.now() - 70000 // 70 seconds ago

const rateLimitError = new Error("Rate limit exceeded") as any
rateLimitError.status = 429

mockEmbeddingsCreate.mockRejectedValueOnce(rateLimitError).mockResolvedValueOnce({
data: [{ embedding: "base64encodeddata" }],
usage: { prompt_tokens: 10, total_tokens: 15 },
})

// Trigger the updateGlobalRateLimitState method
await (embedder as any).updateGlobalRateLimitState(rateLimitError)

// Should reset to 1 since more than 60 seconds passed
expect(state.consecutiveRateLimitErrors).toBe(1)
})

it("should not exceed maximum delay of 5 minutes", async () => {
const embedder = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)
const state = (embedder as any).constructor.globalRateLimitState

// Set state to simulate many consecutive errors
state.consecutiveRateLimitErrors = 10 // This would normally result in a very long delay

const rateLimitError = new Error("Rate limit exceeded") as any
rateLimitError.status = 429

// Trigger the updateGlobalRateLimitState method
await (embedder as any).updateGlobalRateLimitState(rateLimitError)

// Calculate the expected delay
const now = Date.now()
const delay = state.rateLimitResetTime - now

// Should be capped at 5 minutes (300000ms)
expect(delay).toBeLessThanOrEqual(300000)
expect(delay).toBeGreaterThan(0)
})
})
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ describe("OpenAICompatibleEmbedder", () => {
}

MockedOpenAI.mockImplementation(() => mockOpenAIInstance)

// Reset global rate limit state to prevent interference between tests
const tempEmbedder = new OpenAICompatibleEmbedder(testBaseUrl, testApiKey, testModelId)
;(tempEmbedder as any).constructor.globalRateLimitState = {
isRateLimited: false,
rateLimitResetTime: 0,
consecutiveRateLimitErrors: 0,
lastRateLimitError: 0,
mutex: (tempEmbedder as any).constructor.globalRateLimitState.mutex,
}
})

afterEach(() => {
Expand Down Expand Up @@ -385,9 +395,17 @@ describe("OpenAICompatibleEmbedder", () => {

const resultPromise = embedder.createEmbeddings(testTexts)

// Fast-forward through the delays
await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS) // First retry delay
await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 2) // Second retry delay
// First attempt fails immediately, triggering global rate limit (5s)
await vitest.advanceTimersByTimeAsync(100)

// Wait for global rate limit delay
await vitest.advanceTimersByTimeAsync(5000)

// Second attempt also fails, increasing delay
await vitest.advanceTimersByTimeAsync(100)

// Wait for increased global rate limit delay (10s)
await vitest.advanceTimersByTimeAsync(10000)

const result = await resultPromise

Expand Down Expand Up @@ -445,7 +463,7 @@ describe("OpenAICompatibleEmbedder", () => {

expect(console.error).toHaveBeenCalledWith(
expect.stringContaining("OpenAI Compatible embedder error"),
expect.any(Error),
apiError,
)
})

Expand All @@ -461,7 +479,7 @@ describe("OpenAICompatibleEmbedder", () => {

expect(console.error).toHaveBeenCalledWith(
expect.stringContaining("OpenAI Compatible embedder error"),
batchError,
expect.any(Error),
)
})

Expand Down Expand Up @@ -791,10 +809,23 @@ describe("OpenAICompatibleEmbedder", () => {
)

const resultPromise = embedder.createEmbeddings(["test"])
await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 3)

// First attempt fails, triggering global rate limit
await vitest.advanceTimersByTimeAsync(100)

// Wait for global rate limit (5s)
await vitest.advanceTimersByTimeAsync(5000)

// Second attempt also fails
await vitest.advanceTimersByTimeAsync(100)

// Wait for increased global rate limit (10s)
await vitest.advanceTimersByTimeAsync(10000)

const result = await resultPromise

expect(global.fetch).toHaveBeenCalledTimes(3)
// Check that rate limit warnings were logged
expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Rate limit hit"))
expectEmbeddingValues(result.embeddings[0], [0.1, 0.2, 0.3])
vitest.useRealTimers()
Expand Down
Loading