Commit 4591b1b

Implement LiteLLM provider configuration
1 parent ea93cea · commit 4591b1b


42 files changed (+1542, -261 lines)
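
For orientation, the provider configuration this commit wires up comes down to a handful of LiteLLM-specific settings. The sketch below is illustrative only: the field names and fallback values are taken from the test options and assertions further down, while the interface name is made up here; the authoritative shape is ApiHandlerOptions in the shared API types, which this excerpt does not show in full.

// Hypothetical summary of the LiteLLM settings exercised by the tests below.
// The real definition is part of ApiHandlerOptions and may contain more fields.
interface LiteLLMProviderSettings {
	litellmBaseUrl?: string // handler falls back to "http://localhost:4000" when omitted
	litellmApiKey?: string // a dummy "sk-1234" key is used when the proxy needs none
	litellmModelId?: string // falls back to litellmDefaultModelId when omitted or unknown
	modelTemperature?: number // only forwarded when the selected model supports temperature
}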
src/api/providers/__tests__/litellm.test.ts

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
// npx jest src/api/providers/__tests__/litellm.test.ts

import { Anthropic } from "@anthropic-ai/sdk" // For message types
import OpenAI from "openai"

import { LiteLLMHandler } from "../litellm"
import { ApiHandlerOptions, litellmDefaultModelId, litellmDefaultModelInfo, ModelInfo } from "../../../shared/api"
import * as modelCache from "../fetchers/modelCache"

const mockOpenAICreateCompletions = jest.fn()
jest.mock("openai", () => {
	return jest.fn(() => ({
		chat: {
			completions: {
				create: mockOpenAICreateCompletions,
			},
		},
	}))
})

jest.mock("../fetchers/modelCache", () => ({
	getModels: jest.fn(),
}))

const mockGetModels = modelCache.getModels as jest.Mock

describe("LiteLLMHandler", () => {
	const defaultMockOptions: ApiHandlerOptions = {
		litellmApiKey: "test-litellm-key",
		litellmModelId: "litellm-test-model",
		litellmBaseUrl: "http://mock-litellm-server:8000",
		modelTemperature: 0.1, // Add a default temperature for tests
	}

	const mockModelInfo: ModelInfo = {
		maxTokens: 4096,
		contextWindow: 128000,
		supportsImages: false,
		supportsPromptCache: true,
		supportsComputerUse: false,
		description: "A test LiteLLM model",
	}

	beforeEach(() => {
		jest.clearAllMocks()

		mockGetModels.mockResolvedValue({
			[defaultMockOptions.litellmModelId!]: mockModelInfo,
		})
		// Spy on supportsTemperature and default to true for most tests, can be overridden
		jest.spyOn(LiteLLMHandler.prototype as any, "supportsTemperature").mockReturnValue(true)
	})

	describe("constructor", () => {
		it("initializes with correct options and defaults", () => {
			const handler = new LiteLLMHandler(defaultMockOptions) // This will call new OpenAI()
			expect(handler).toBeInstanceOf(LiteLLMHandler)
			// Check if the mock constructor was called with the right params
			expect(OpenAI).toHaveBeenCalledWith({
				baseURL: defaultMockOptions.litellmBaseUrl,
				apiKey: defaultMockOptions.litellmApiKey,
			})
		})

		it("uses default baseURL if not provided", () => {
			new LiteLLMHandler({ litellmApiKey: "key", litellmModelId: "id" })
			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "http://localhost:4000" }))
		})

		it("uses dummy API key if not provided", () => {
			new LiteLLMHandler({ litellmBaseUrl: "url", litellmModelId: "id" })
			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: "sk-1234" }))
		})
	})

	describe("fetchModel", () => {
		it("returns correct model info when modelId is provided and found in getModels", async () => {
			const handler = new LiteLLMHandler(defaultMockOptions)
			const result = await handler.fetchModel()
			expect(mockGetModels).toHaveBeenCalledWith({
				provider: "litellm",
				apiKey: defaultMockOptions.litellmApiKey,
				baseUrl: defaultMockOptions.litellmBaseUrl,
			})
			expect(result).toEqual({ id: defaultMockOptions.litellmModelId, info: mockModelInfo })
		})

		it("returns defaultModelInfo if provided modelId is NOT found in getModels result", async () => {
			mockGetModels.mockResolvedValueOnce({ "another-model": { contextWindow: 1, supportsPromptCache: false } })
			const handler = new LiteLLMHandler(defaultMockOptions)
			const result = await handler.fetchModel()
			expect(result.id).toBe(litellmDefaultModelId)
			expect(result.info).toEqual(litellmDefaultModelInfo)
		})

		it("uses defaultModelId and its info if litellmModelId option is undefined and defaultModelId is in getModels", async () => {
			const specificDefaultModelInfo = { ...mockModelInfo, description: "Specific Default Model Info" }
			mockGetModels.mockResolvedValueOnce({ [litellmDefaultModelId]: specificDefaultModelInfo })
			const handler = new LiteLLMHandler({ ...defaultMockOptions, litellmModelId: undefined })
			const result = await handler.fetchModel()
			expect(result.id).toBe(litellmDefaultModelId)
			expect(result.info).toEqual(specificDefaultModelInfo)
		})

		it("uses defaultModelId and defaultModelInfo if litellmModelId option is undefined and defaultModelId is NOT in getModels", async () => {
			mockGetModels.mockResolvedValueOnce({ "some-other-model": mockModelInfo })
			const handler = new LiteLLMHandler({ ...defaultMockOptions, litellmModelId: undefined })
			const result = await handler.fetchModel()
			expect(result.id).toBe(litellmDefaultModelId)
			expect(result.info).toEqual(litellmDefaultModelInfo)
		})

		it("throws an error if getModels fails", async () => {
			mockGetModels.mockRejectedValueOnce(new Error("Network error"))
			const handler = new LiteLLMHandler(defaultMockOptions)
			await expect(handler.fetchModel()).rejects.toThrow("Network error")
		})
	})

	describe("createMessage", () => {
		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
		// mockCreateGlobal is no longer needed here, use mockOpenAICreateCompletions directly

		beforeEach(() => {
			// mockOpenAICreateCompletions is already cleared by jest.clearAllMocks() in the outer beforeEach
			// or mockOpenAICreateCompletions.mockClear() if we want to be very specific
		})

		it("streams text and usage chunks correctly", async () => {
			const mockStreamData = {
				async *[Symbol.asyncIterator]() {
					yield { id: "chunk1", choices: [{ delta: { content: "Response part 1" } }], usage: null }
					yield { id: "chunk2", choices: [{ delta: { content: " part 2" } }], usage: null }
					yield { id: "chunk3", choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } }
				},
			}
			mockOpenAICreateCompletions.mockReturnValue({
				withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
			})

			const handler = new LiteLLMHandler(defaultMockOptions)
			const generator = handler.createMessage(systemPrompt, messages)
			const chunks = []
			for await (const chunk of generator) {
				chunks.push(chunk)
			}

			expect(chunks).toEqual([
				{ type: "text", text: "Response part 1" },
				{ type: "text", text: " part 2" },
				{ type: "usage", inputTokens: 10, outputTokens: 5 },
			])
			expect(mockOpenAICreateCompletions).toHaveBeenCalledWith({
				model: defaultMockOptions.litellmModelId,
				max_tokens: mockModelInfo.maxTokens,
				messages: [
					{ role: "system", content: systemPrompt },
					{ role: "user", content: "Hello" },
				],
				stream: true,
				stream_options: { include_usage: true },
				temperature: defaultMockOptions.modelTemperature,
			})
		})

		it("handles temperature option if supported", async () => {
			const handler = new LiteLLMHandler({ ...defaultMockOptions, modelTemperature: 0.7 })
			const mockStreamData = { async *[Symbol.asyncIterator]() {} }
			mockOpenAICreateCompletions.mockReturnValue({
				withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
			})

			const generator = handler.createMessage(systemPrompt, messages)
			for await (const _ of generator) {
			}

			expect(mockOpenAICreateCompletions).toHaveBeenCalledWith(expect.objectContaining({ temperature: 0.7 }))
		})

		it("does not include temperature if not supported by model", async () => {
			;(LiteLLMHandler.prototype as any).supportsTemperature.mockReturnValue(false)
			const handler = new LiteLLMHandler(defaultMockOptions)
			const mockStreamData = { async *[Symbol.asyncIterator]() {} }
			mockOpenAICreateCompletions.mockReturnValue({
				withResponse: jest.fn().mockResolvedValue({ data: mockStreamData }),
			})

			const generator = handler.createMessage(systemPrompt, messages)
			for await (const _ of generator) {
			}

			const callArgs = mockOpenAICreateCompletions.mock.calls[0][0]
			expect(callArgs.temperature).toBeUndefined()
		})

		it("throws a formatted error if API call (streaming) fails", async () => {
			const apiError = new Error("LLM Provider Error")
			// Simulate the error occurring within the stream itself
			mockOpenAICreateCompletions.mockReturnValue({
				withResponse: jest.fn().mockResolvedValue({
					data: {
						async *[Symbol.asyncIterator]() {
							throw apiError
						},
					},
				}),
			})

			const handler = new LiteLLMHandler(defaultMockOptions)
			const generator = handler.createMessage(systemPrompt, messages)
			await expect(async () => {
				for await (const _ of generator) {
				}
			}).rejects.toThrow("LiteLLM streaming error: " + apiError.message)
		})
	})

	describe("completePrompt", () => {
		const prompt = "Translate 'hello' to French."
		// mockCreateGlobal is no longer needed here, use mockOpenAICreateCompletions directly

		beforeEach(() => {
			// mockOpenAICreateCompletions is already cleared by jest.clearAllMocks() in the outer beforeEach
		})

		it("returns completion successfully", async () => {
			mockOpenAICreateCompletions.mockResolvedValueOnce({ choices: [{ message: { content: "Bonjour" } }] })
			const handler = new LiteLLMHandler(defaultMockOptions)
			const result = await handler.completePrompt(prompt)

			expect(result).toBe("Bonjour")
			expect(mockOpenAICreateCompletions).toHaveBeenCalledWith({
				model: defaultMockOptions.litellmModelId,
				max_tokens: mockModelInfo.maxTokens,
				messages: [{ role: "user", content: prompt }],
				temperature: defaultMockOptions.modelTemperature,
			})
		})

		it("throws a formatted error if API call fails", async () => {
			mockOpenAICreateCompletions.mockRejectedValueOnce(new Error("Completion API Down"))
			const handler = new LiteLLMHandler(defaultMockOptions)
			await expect(handler.completePrompt(prompt)).rejects.toThrow(
				"LiteLLM completion error: Completion API Down",
			)
		})
	})
})
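
The LiteLLMHandler implementation itself (src/api/providers/litellm.ts) is not reproduced in this excerpt. As a rough guide to the interface the tests above assume, a minimal sketch could look like the following. The constructor fallbacks, the getModels arguments, the model-fallback behaviour, and the "LiteLLM completion error" prefix are taken from the assertions above; the class body, the injected getModels parameter, and the placeholder default constants are illustrative assumptions, not the committed code.

// Illustrative sketch only — not the committed implementation.
import OpenAI from "openai"

// Placeholders standing in for litellmDefaultModelId / litellmDefaultModelInfo from the shared API types.
const DEFAULT_MODEL_ID = "litellm-default-model"
const DEFAULT_MODEL_INFO = { maxTokens: 8192, contextWindow: 200_000, supportsPromptCache: false }

interface LiteLLMHandlerOptions {
	litellmApiKey?: string
	litellmBaseUrl?: string
	litellmModelId?: string
	modelTemperature?: number
}

type GetModels = (args: { provider: string; apiKey?: string; baseUrl?: string }) => Promise<Record<string, any>>

export class LiteLLMHandlerSketch {
	private client: OpenAI

	// getModels is injected here to keep the sketch self-contained; the real handler imports it
	// from ../fetchers/modelCache, which is what the tests mock.
	constructor(
		private options: LiteLLMHandlerOptions,
		private getModels: GetModels,
	) {
		// Fallbacks mirror what the constructor tests assert.
		this.client = new OpenAI({
			baseURL: options.litellmBaseUrl ?? "http://localhost:4000",
			apiKey: options.litellmApiKey ?? "sk-1234",
		})
	}

	async fetchModel(): Promise<{ id: string; info: any }> {
		const models = await this.getModels({
			provider: "litellm",
			apiKey: this.options.litellmApiKey,
			baseUrl: this.options.litellmBaseUrl,
		})
		const requestedId = this.options.litellmModelId ?? DEFAULT_MODEL_ID
		if (models[requestedId]) {
			return { id: requestedId, info: models[requestedId] }
		}
		// Unknown or missing model id: fall back to the default, as the "NOT found" tests expect.
		return { id: DEFAULT_MODEL_ID, info: models[DEFAULT_MODEL_ID] ?? DEFAULT_MODEL_INFO }
	}

	async completePrompt(prompt: string): Promise<string> {
		const { id, info } = await this.fetchModel()
		try {
			const response = await this.client.chat.completions.create({
				model: id,
				max_tokens: info.maxTokens,
				messages: [{ role: "user", content: prompt }],
				temperature: this.options.modelTemperature,
			})
			return response.choices[0]?.message?.content ?? ""
		} catch (error) {
			// The tests expect failures to be re-thrown with this prefix.
			throw new Error(`LiteLLM completion error: ${error instanceof Error ? error.message : String(error)}`)
		}
	}
}

The streaming path (createMessage) would follow the same pattern: convert the Anthropic-style message params to OpenAI chat messages, pass stream: true with stream_options: { include_usage: true }, and wrap stream failures as "LiteLLM streaming error: …", which is what the streaming tests assert.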

src/api/providers/fetchers/litellm.ts

Lines changed: 28 additions & 6 deletions
@@ -7,6 +7,7 @@ import { COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api"
  * @param apiKey The API key for the LiteLLM server
  * @param baseUrl The base URL of the LiteLLM server
  * @returns A promise that resolves to a record of model IDs to model info
+ * @throws Will throw an error if the request fails or the response is not as expected.
  */
 export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
 	try {
@@ -18,7 +19,8 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise
 			headers["Authorization"] = `Bearer ${apiKey}`
 		}

-		const response = await axios.get(`${baseUrl}/v1/model/info`, { headers })
+		// Added timeout to prevent indefinite hanging
+		const response = await axios.get(`${baseUrl}/v1/model/info`, { headers, timeout: 15000 })
 		const models: ModelRecord = {}

 		const computerModels = Array.from(COMPUTER_USE_MODELS)
@@ -32,11 +34,17 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise

 				if (!modelName || !modelInfo || !litellmModelName) continue

+				let determinedMaxTokens = modelInfo.max_tokens || modelInfo.max_output_tokens || 8192
+
+				if (modelName.includes("claude-3-7-sonnet")) {
+					// due to https://github.com/BerriAI/litellm/issues/8984 until proper extended thinking support is added
+					determinedMaxTokens = 64000
+				}
+
 				models[modelName] = {
-					maxTokens: modelInfo.max_tokens || 8192,
+					maxTokens: determinedMaxTokens,
 					contextWindow: modelInfo.max_input_tokens || 200000,
 					supportsImages: Boolean(modelInfo.supports_vision),
-					// litellm_params.model may have a prefix like openrouter/
 					supportsComputerUse: computerModels.some((computer_model) =>
 						litellmModelName.endsWith(computer_model),
 					),
@@ -48,11 +56,25 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise
 					description: `${modelName} via LiteLLM proxy`,
 				}
 			}
+		} else {
+			// If response.data.data is not in the expected format, consider it an error.
+			console.error("Error fetching LiteLLM models: Unexpected response format", response.data)
+			throw new Error("Failed to fetch LiteLLM models: Unexpected response format.")
 		}

 		return models
-	} catch (error) {
-		console.error("Error fetching LiteLLM models:", error)
-		return {}
+	} catch (error: any) {
+		console.error("Error fetching LiteLLM models:", error.message ? error.message : error)
+		if (axios.isAxiosError(error) && error.response) {
+			throw new Error(
+				`Failed to fetch LiteLLM models: ${error.response.status} ${error.response.statusText}. Check base URL and API key.`,
+			)
+		} else if (axios.isAxiosError(error) && error.request) {
+			throw new Error(
+				"Failed to fetch LiteLLM models: No response from server. Check LiteLLM server status and base URL.",
+			)
+		} else {
+			throw new Error(`Failed to fetch LiteLLM models: ${error.message || "An unknown error occurred."}`)
+		}
 	}
 }
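
One consequence of this change is that getLiteLLMModels now throws on failure instead of quietly returning an empty record, so call sites that want the old behaviour need to catch. A hypothetical caller sketch (the function name, import path, and the fallback-to-empty choice are illustrative, not part of this commit):

// Hypothetical caller — the commit's real call sites are not shown in this excerpt.
import { getLiteLLMModels } from "./fetchers/litellm" // path assumed relative to src/api/providers

async function loadLiteLLMModels(apiKey: string, baseUrl: string) {
	try {
		// Resolves to a record of model IDs to model info on success.
		return await getLiteLLMModels(apiKey, baseUrl)
	} catch (error) {
		// The fetcher now surfaces descriptive errors ("Failed to fetch LiteLLM models: …")
		// for HTTP failures, unreachable servers, and unexpected response shapes.
		console.error(error instanceof Error ? error.message : error)
		return {} // restore the pre-commit empty-record behaviour if the caller prefers it
	}
}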

0 commit comments
