249 changes: 249 additions & 0 deletions src/api/providers/__tests__/litellm.test.ts
@@ -0,0 +1,249 @@
// npx jest src/api/providers/__tests__/litellm.test.ts

import { Anthropic } from "@anthropic-ai/sdk" // For message types
import OpenAI from "openai"

import { LiteLLMHandler } from "../litellm"
import { ApiHandlerOptions, litellmDefaultModelId, litellmDefaultModelInfo, ModelInfo } from "../../../shared/api"
import * as modelCache from "../fetchers/modelCache"

const mockOpenAICreateCompletions = jest.fn()
jest.mock("openai", () => {
return jest.fn(() => ({
chat: {
completions: {
create: mockOpenAICreateCompletions,
},
},
}))
})

jest.mock("../fetchers/modelCache", () => ({
getModels: jest.fn(),
}))

const mockGetModels = modelCache.getModels as jest.Mock

describe("LiteLLMHandler", () => {
const defaultMockOptions: ApiHandlerOptions = {
litellmApiKey: "test-litellm-key",
litellmModelId: "litellm-test-model",
litellmBaseUrl: "http://mock-litellm-server:8000",
modelTemperature: 0.1, // Add a default temperature for tests
}

const mockModelInfo: ModelInfo = {
maxTokens: 4096,
contextWindow: 128000,
supportsImages: false,
supportsPromptCache: true,
supportsComputerUse: false,
description: "A test LiteLLM model",
}

beforeEach(() => {
jest.clearAllMocks()

mockGetModels.mockResolvedValue({
[defaultMockOptions.litellmModelId!]: mockModelInfo,
})
// Spy on supportsTemperature and default to true for most tests, can be overridden
jest.spyOn(LiteLLMHandler.prototype as any, "supportsTemperature").mockReturnValue(true)
})

describe("constructor", () => {
it("initializes with correct options and defaults", () => {
const handler = new LiteLLMHandler(defaultMockOptions) // This will call new OpenAI()
expect(handler).toBeInstanceOf(LiteLLMHandler)
// Check if the mock constructor was called with the right params
expect(OpenAI).toHaveBeenCalledWith({
baseURL: defaultMockOptions.litellmBaseUrl,
apiKey: defaultMockOptions.litellmApiKey,
})
})

it("uses default baseURL if not provided", () => {
new LiteLLMHandler({ litellmApiKey: "key", litellmModelId: "id" })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "http://localhost:4000" }))
})

it("uses dummy API key if not provided", () => {
new LiteLLMHandler({ litellmBaseUrl: "url", litellmModelId: "id" })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: "sk-1234" }))
})
})

describe("fetchModel", () => {
it("returns correct model info when modelId is provided and found in getModels", async () => {
const handler = new LiteLLMHandler(defaultMockOptions)
const result = await handler.fetchModel()
expect(mockGetModels).toHaveBeenCalledWith({
provider: "litellm",
apiKey: defaultMockOptions.litellmApiKey,
baseUrl: defaultMockOptions.litellmBaseUrl,
})
expect(result).toEqual({ id: defaultMockOptions.litellmModelId, info: mockModelInfo })
})

it("returns defaultModelInfo if provided modelId is NOT found in getModels result", async () => {
mockGetModels.mockResolvedValueOnce({ "another-model": { contextWindow: 1, supportsPromptCache: false } })
const handler = new LiteLLMHandler(defaultMockOptions)
const result = await handler.fetchModel()
expect(result.id).toBe(litellmDefaultModelId)
expect(result.info).toEqual(litellmDefaultModelInfo)
})

it("uses defaultModelId and its info if litellmModelId option is undefined and defaultModelId is in getModels", async () => {
const specificDefaultModelInfo = { ...mockModelInfo, description: "Specific Default Model Info" }
mockGetModels.mockResolvedValueOnce({ [litellmDefaultModelId]: specificDefaultModelInfo })
const handler = new LiteLLMHandler({ ...defaultMockOptions, litellmModelId: undefined })
const result = await handler.fetchModel()
expect(result.id).toBe(litellmDefaultModelId)
expect(result.info).toEqual(specificDefaultModelInfo)
})

it("uses defaultModelId and defaultModelInfo if litellmModelId option is undefined and defaultModelId is NOT in getModels", async () => {
mockGetModels.mockResolvedValueOnce({ "some-other-model": mockModelInfo })
const handler = new LiteLLMHandler({ ...defaultMockOptions, litellmModelId: undefined })
const result = await handler.fetchModel()
expect(result.id).toBe(litellmDefaultModelId)
expect(result.info).toEqual(litellmDefaultModelInfo)
})

it("throws an error if getModels fails", async () => {
mockGetModels.mockRejectedValueOnce(new Error("Network error"))
const handler = new LiteLLMHandler(defaultMockOptions)
await expect(handler.fetchModel()).rejects.toThrow("Network error")
})
})

describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
// Uses the shared mockOpenAICreateCompletions mock defined at the top of this file

beforeEach(() => {
// mockOpenAICreateCompletions is already cleared by jest.clearAllMocks() in the outer beforeEach
// or mockOpenAICreateCompletions.mockClear() if we want to be very specific
})

it("streams text and usage chunks correctly", async () => {
const mockStreamData = {
async *[Symbol.asyncIterator]() {
yield { id: "chunk1", choices: [{ delta: { content: "Response part 1" } }], usage: null }
yield { id: "chunk2", choices: [{ delta: { content: " part 2" } }], usage: null }
yield { id: "chunk3", choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } }
},
}
mockOpenAICreateCompletions.mockReturnValue({
withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
})

const handler = new LiteLLMHandler(defaultMockOptions)
const generator = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of generator) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Response part 1" },
{ type: "text", text: " part 2" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
expect(mockOpenAICreateCompletions).toHaveBeenCalledWith({
model: defaultMockOptions.litellmModelId,
max_tokens: mockModelInfo.maxTokens,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello" },
],
stream: true,
stream_options: { include_usage: true },
temperature: defaultMockOptions.modelTemperature,
})
})

it("handles temperature option if supported", async () => {
const handler = new LiteLLMHandler({ ...defaultMockOptions, modelTemperature: 0.7 })
const mockStreamData = { async *[Symbol.asyncIterator]() {} }
mockOpenAICreateCompletions.mockReturnValue({
withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const _ of generator) {
}

expect(mockOpenAICreateCompletions).toHaveBeenCalledWith(expect.objectContaining({ temperature: 0.7 }))
})

it("does not include temperature if not supported by model", async () => {
;(LiteLLMHandler.prototype as any).supportsTemperature.mockReturnValue(false)
const handler = new LiteLLMHandler(defaultMockOptions)
const mockStreamData = { async *[Symbol.asyncIterator]() {} }
mockOpenAICreateCompletions.mockReturnValue({
withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const _ of generator) {
}

const callArgs = mockOpenAICreateCompletions.mock.calls[0][0]
expect(callArgs.temperature).toBeUndefined()
})

it("throws a formatted error if API call (streaming) fails", async () => {
const apiError = new Error("LLM Provider Error")
// Simulate the error occurring within the stream itself
mockOpenAICreateCompletions.mockReturnValue({
withResponse: jest.fn().mockResolvedValue({
data: {
async *[Symbol.asyncIterator]() {
throw apiError
},
},
}),
})

const handler = new LiteLLMHandler(defaultMockOptions)
const generator = handler.createMessage(systemPrompt, messages)
await expect(async () => {
for await (const _ of generator) {
}
}).rejects.toThrow("LiteLLM streaming error: " + apiError.message)
})
})

describe("completePrompt", () => {
const prompt = "Translate 'hello' to French."
// Uses the shared mockOpenAICreateCompletions mock defined at the top of this file

beforeEach(() => {
// mockOpenAICreateCompletions is already cleared by jest.clearAllMocks() in the outer beforeEach
})

it("returns completion successfully", async () => {
mockOpenAICreateCompletions.mockResolvedValueOnce({ choices: [{ message: { content: "Bonjour" } }] })
const handler = new LiteLLMHandler(defaultMockOptions)
const result = await handler.completePrompt(prompt)

expect(result).toBe("Bonjour")
expect(mockOpenAICreateCompletions).toHaveBeenCalledWith({
model: defaultMockOptions.litellmModelId,
max_tokens: mockModelInfo.maxTokens,
messages: [{ role: "user", content: prompt }],
temperature: defaultMockOptions.modelTemperature,
})
})

it("throws a formatted error if API call fails", async () => {
mockOpenAICreateCompletions.mockRejectedValueOnce(new Error("Completion API Down"))
const handler = new LiteLLMHandler(defaultMockOptions)
await expect(handler.completePrompt(prompt)).rejects.toThrow(
"LiteLLM completion error: Completion API Down",
)
})
})
})
34 changes: 28 additions & 6 deletions src/api/providers/fetchers/litellm.ts
@@ -7,6 +7,7 @@ import { COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api"
* @param apiKey The API key for the LiteLLM server
* @param baseUrl The base URL of the LiteLLM server
* @returns A promise that resolves to a record of model IDs to model info
* @throws Will throw an error if the request fails or the response is not as expected.
*/
export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
try {
@@ -18,7 +19,8 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise
headers["Authorization"] = `Bearer ${apiKey}`
}

const response = await axios.get(`${baseUrl}/v1/model/info`, { headers })
// Added timeout to prevent indefinite hanging
const response = await axios.get(`${baseUrl}/v1/model/info`, { headers, timeout: 15000 })
const models: ModelRecord = {}

const computerModels = Array.from(COMPUTER_USE_MODELS)
@@ -32,11 +34,17 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise

if (!modelName || !modelInfo || !litellmModelName) continue

let determinedMaxTokens = modelInfo.max_tokens || modelInfo.max_output_tokens || 8192
Collaborator comment:
The other thing is that 64k or 128k output tokens is too high of a number in most cases for Sonnet - it eats up a ton of the available context window. In other providers we have a slider to choose the max tokens value.

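One possible direction, shown only as a hedged sketch and not something this PR implements: cap the output-token budget at a fraction of the model's context window instead of accepting 64k or 128k outright. The clampMaxTokens helper and the 20% ratio below are illustrative assumptions, not values taken from this codebase.

// Illustrative helper (not part of this PR): keep maxTokens within a share of the context window.
// The 0.2 ratio is an assumed default, not something defined in this repository.
function clampMaxTokens(requested: number, contextWindow: number, ratio = 0.2): number {
	return Math.min(requested, Math.floor(contextWindow * ratio))
}

// Possible usage in the loop above (illustrative):
// determinedMaxTokens = clampMaxTokens(determinedMaxTokens, modelInfo.max_input_tokens || 200000)
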
if (modelName.includes("claude-3-7-sonnet")) {
// due to https://github.com/BerriAI/litellm/issues/8984 until proper extended thinking support is added
determinedMaxTokens = 64000
}
Collaborator comment on lines +39 to +42:
Can we include the Anthropic header instead? https://docs.litellm.ai/docs/proxy/request_headers

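A rough sketch of the reviewer's suggestion follows: send the Anthropic beta header from the client and let the LiteLLM proxy pass it along, rather than hardcoding 64000 here. This assumes the proxy is configured to forward client headers to the upstream provider (see the request_headers docs linked above) and that "output-128k-2025-02-19" is the relevant beta flag for extended output on claude-3-7-sonnet; both are assumptions, not part of this PR.

import OpenAI from "openai"

// Illustrative only: forward an Anthropic beta header through the LiteLLM proxy.
const client = new OpenAI({
	baseURL: "http://localhost:4000", // assumed LiteLLM proxy address
	apiKey: "sk-1234",
	defaultHeaders: {
		// Assumed beta flag for extended (128k) output on claude-3-7-sonnet
		"anthropic-beta": "output-128k-2025-02-19",
	},
})
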
models[modelName] = {
maxTokens: modelInfo.max_tokens || 8192,
maxTokens: determinedMaxTokens,
contextWindow: modelInfo.max_input_tokens || 200000,
supportsImages: Boolean(modelInfo.supports_vision),
// litellm_params.model may have a prefix like openrouter/
supportsComputerUse: computerModels.some((computer_model) =>
litellmModelName.endsWith(computer_model),
),
@@ -48,11 +56,25 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise
description: `${modelName} via LiteLLM proxy`,
}
}
} else {
// If response.data.data is not in the expected format, consider it an error.
console.error("Error fetching LiteLLM models: Unexpected response format", response.data)
throw new Error("Failed to fetch LiteLLM models: Unexpected response format.")
}

return models
} catch (error) {
console.error("Error fetching LiteLLM models:", error)
return {}
} catch (error: any) {
console.error("Error fetching LiteLLM models:", error.message ? error.message : error)
if (axios.isAxiosError(error) && error.response) {
throw new Error(
`Failed to fetch LiteLLM models: ${error.response.status} ${error.response.statusText}. Check base URL and API key.`,
)
} else if (axios.isAxiosError(error) && error.request) {
throw new Error(
"Failed to fetch LiteLLM models: No response from server. Check LiteLLM server status and base URL.",
)
} else {
throw new Error(`Failed to fetch LiteLLM models: ${error.message || "An unknown error occurred."}`)
}
}
}