Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions src/api/providers/fetchers/__tests__/litellm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import axios from "axios"
import { getLiteLLMModels } from "../litellm"
import { COMPUTER_USE_MODELS } from "../../../../shared/api"

// Mock axios so no real HTTP requests are made. jest.mock() calls are
// hoisted above the imports at runtime, so the `axios` imported above is
// already the auto-mocked module.
jest.mock("axios")
// Typed view of the mocked module so mock helpers (mockResolvedValue,
// mockReturnValue, ...) are visible to the type checker.
const mockedAxios = axios as jest.Mocked<typeof axios>

// Deliberately bogus key used by the auth-failure test below.
const DUMMY_INVALID_KEY = "invalid-key-for-testing"

describe("getLiteLLMModels", () => {
	beforeEach(() => {
		// Reset call history and mocked return values (including the shared
		// mockedAxios.isAxiosError guard) so tests cannot leak state into
		// each other.
		jest.clearAllMocks()
	})

	it("successfully fetches and formats LiteLLM models", async () => {
		// Mimics LiteLLM's /v1/model/info payload: axios wraps the body in
		// `data`, and the endpoint itself returns `{ data: [...] }`.
		const mockResponse = {
			data: {
				data: [
					{
						model_name: "claude-3-5-sonnet",
						model_info: {
							max_tokens: 4096,
							max_input_tokens: 200000,
							supports_vision: true,
							supports_prompt_caching: false,
							// Costs are per token; the expectation below shows the
							// fetcher converts them to per-million-token prices.
							input_cost_per_token: 0.000003,
							output_cost_per_token: 0.000015,
						},
						litellm_params: {
							// Underlying model id — presumably what the fetcher
							// matches against COMPUTER_USE_MODELS; confirm in litellm.ts.
							model: "anthropic/claude-3.5-sonnet",
						},
					},
					{
						model_name: "gpt-4-turbo",
						model_info: {
							max_tokens: 8192,
							max_input_tokens: 128000,
							supports_vision: false,
							supports_prompt_caching: false,
							input_cost_per_token: 0.00001,
							output_cost_per_token: 0.00003,
						},
						litellm_params: {
							model: "openai/gpt-4-turbo",
						},
					},
				],
			},
		}

		mockedAxios.get.mockResolvedValue(mockResponse)

		const result = await getLiteLLMModels("test-api-key", "http://localhost:4000")

		// The endpoint path is fixed at /v1/model/info and the API key is
		// sent as a bearer token.
		expect(mockedAxios.get).toHaveBeenCalledWith("http://localhost:4000/v1/model/info", {
			headers: {
				Authorization: "Bearer test-api-key",
				"Content-Type": "application/json",
			},
			timeout: 5000,
		})

		// Per-token costs surface as per-million-token prices
		// (0.000003 -> 3, 0.000015 -> 15).
		expect(result).toEqual({
			"claude-3-5-sonnet": {
				maxTokens: 4096,
				contextWindow: 200000,
				supportsImages: true,
				supportsComputerUse: true,
				supportsPromptCache: false,
				inputPrice: 3,
				outputPrice: 15,
				description: "claude-3-5-sonnet via LiteLLM proxy",
			},
			"gpt-4-turbo": {
				maxTokens: 8192,
				contextWindow: 128000,
				supportsImages: false,
				supportsComputerUse: false,
				supportsPromptCache: false,
				inputPrice: 10,
				outputPrice: 30,
				description: "gpt-4-turbo via LiteLLM proxy",
			},
		})
	})

	it("makes request without authorization header when no API key provided", async () => {
		const mockResponse = {
			data: {
				data: [],
			},
		}

		mockedAxios.get.mockResolvedValue(mockResponse)

		await getLiteLLMModels("", "http://localhost:4000")

		// No Authorization entry at all — not an empty "Bearer " header.
		expect(mockedAxios.get).toHaveBeenCalledWith("http://localhost:4000/v1/model/info", {
			headers: {
				"Content-Type": "application/json",
			},
			timeout: 5000,
		})
	})

	it("handles computer use models correctly", async () => {
		// Any member of the computer-use allowlist will do. NOTE(review):
		// this indexes [0] without checking the set is non-empty — if
		// COMPUTER_USE_MODELS were ever empty, the fixture would silently
		// become "anthropic/undefined".
		const computerUseModel = Array.from(COMPUTER_USE_MODELS)[0]
		const mockResponse = {
			data: {
				data: [
					{
						model_name: "test-computer-model",
						model_info: {
							max_tokens: 4096,
							max_input_tokens: 200000,
							supports_vision: true,
						},
						litellm_params: {
							model: `anthropic/${computerUseModel}`,
						},
					},
				],
			},
		}

		mockedAxios.get.mockResolvedValue(mockResponse)

		const result = await getLiteLLMModels("test-api-key", "http://localhost:4000")

		// Pricing fields are absent from the fixture, so they come back as
		// undefined rather than defaulting to 0.
		expect(result["test-computer-model"]).toEqual({
			maxTokens: 4096,
			contextWindow: 200000,
			supportsImages: true,
			supportsComputerUse: true,
			supportsPromptCache: false,
			inputPrice: undefined,
			outputPrice: undefined,
			description: "test-computer-model via LiteLLM proxy",
		})
	})

	it("throws error for unexpected response format", async () => {
		const mockResponse = {
			data: {
				// Missing 'data' field
				models: [],
			},
		}

		mockedAxios.get.mockResolvedValue(mockResponse)

		await expect(getLiteLLMModels("test-api-key", "http://localhost:4000")).rejects.toThrow(
			"Failed to fetch LiteLLM models: Unexpected response format.",
		)
	})

	it("throws detailed error for HTTP error responses", async () => {
		// Plain object shaped like an AxiosError that carries a `response`,
		// i.e. the server answered with a non-2xx status.
		const axiosError = {
			response: {
				status: 401,
				statusText: "Unauthorized",
			},
			isAxiosError: true,
		}

		// axios.isAxiosError is mocked, so force the type guard to accept
		// the hand-rolled error object above.
		mockedAxios.isAxiosError.mockReturnValue(true)
		mockedAxios.get.mockRejectedValue(axiosError)

		await expect(getLiteLLMModels(DUMMY_INVALID_KEY, "http://localhost:4000")).rejects.toThrow(
			"Failed to fetch LiteLLM models: 401 Unauthorized. Check base URL and API key.",
		)
	})

	it("throws network error for request failures", async () => {
		// `request` present but no `response`: the request went out and
		// nothing came back (network failure / unreachable host).
		const axiosError = {
			request: {},
			isAxiosError: true,
		}

		mockedAxios.isAxiosError.mockReturnValue(true)
		mockedAxios.get.mockRejectedValue(axiosError)

		await expect(getLiteLLMModels("test-api-key", "http://invalid-url")).rejects.toThrow(
			"Failed to fetch LiteLLM models: No response from server. Check LiteLLM server status and base URL.",
		)
	})

	it("throws generic error for other failures", async () => {
		const genericError = new Error("Network timeout")

		// Non-axios errors must fall through to the generic message branch.
		mockedAxios.isAxiosError.mockReturnValue(false)
		mockedAxios.get.mockRejectedValue(genericError)

		await expect(getLiteLLMModels("test-api-key", "http://localhost:4000")).rejects.toThrow(
			"Failed to fetch LiteLLM models: Network timeout",
		)
	})

	it("handles timeout parameter correctly", async () => {
		const mockResponse = { data: { data: [] } }
		mockedAxios.get.mockResolvedValue(mockResponse)

		await getLiteLLMModels("test-api-key", "http://localhost:4000")

		// Only the timeout matters here; objectContaining ignores headers.
		expect(mockedAxios.get).toHaveBeenCalledWith(
			"http://localhost:4000/v1/model/info",
			expect.objectContaining({
				timeout: 5000,
			}),
		)
	})

	it("returns empty object when data array is empty", async () => {
		const mockResponse = {
			data: {
				data: [],
			},
		}

		mockedAxios.get.mockResolvedValue(mockResponse)

		const result = await getLiteLLMModels("test-api-key", "http://localhost:4000")

		expect(result).toEqual({})
	})
})
158 changes: 158 additions & 0 deletions src/api/providers/fetchers/__tests__/modelCache.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import { getModels } from "../modelCache"
import { getLiteLLMModels } from "../litellm"
import { getOpenRouterModels } from "../openrouter"
import { getRequestyModels } from "../requesty"
import { getGlamaModels } from "../glama"
import { getUnboundModels } from "../unbound"

// Mock NodeCache so getModels never serves a cached result: every `get`
// is a miss, forcing the code under test down to the fetchers.
jest.mock("node-cache", () => {
	return jest.fn().mockImplementation(() => ({
		get: jest.fn().mockReturnValue(undefined), // Always return cache miss
		set: jest.fn(),
		del: jest.fn(),
	}))
})

// Mock fs/promises to avoid file system operations (the cache layer
// presumably persists to disk — confirm in modelCache.ts).
jest.mock("fs/promises", () => ({
	writeFile: jest.fn().mockResolvedValue(undefined),
	readFile: jest.fn().mockResolvedValue("{}"),
	mkdir: jest.fn().mockResolvedValue(undefined),
}))

// Mock every provider fetcher so each dispatch branch of getModels can
// be verified in isolation, without network access.
jest.mock("../litellm")
jest.mock("../openrouter")
jest.mock("../requesty")
jest.mock("../glama")
jest.mock("../unbound")

// Typed views of the auto-mocked fetchers so mockResolvedValue /
// mockRejectedValue are visible to the type checker.
const mockGetLiteLLMModels = getLiteLLMModels as jest.MockedFunction<typeof getLiteLLMModels>
const mockGetOpenRouterModels = getOpenRouterModels as jest.MockedFunction<typeof getOpenRouterModels>
const mockGetRequestyModels = getRequestyModels as jest.MockedFunction<typeof getRequestyModels>
const mockGetGlamaModels = getGlamaModels as jest.MockedFunction<typeof getGlamaModels>
const mockGetUnboundModels = getUnboundModels as jest.MockedFunction<typeof getUnboundModels>

// Placeholder credentials for providers whose API key is optional.
const DUMMY_REQUESTY_KEY = "requesty-key-for-testing"
const DUMMY_UNBOUND_KEY = "unbound-key-for-testing"

describe("getModels with new GetModelsOptions", () => {
	/**
	 * Builds a one-entry model map keyed by `id`, shaped like the records
	 * the provider fetchers resolve with.
	 */
	const modelMap = (id: string, description: string, maxTokens: number, contextWindow: number) => ({
		[id]: { maxTokens, contextWindow, supportsPromptCache: false, description },
	})

	beforeEach(() => {
		jest.clearAllMocks()
	})

	it("calls getLiteLLMModels with correct parameters", async () => {
		const fixture = modelMap("claude-3-sonnet", "Claude 3 Sonnet via LiteLLM", 4096, 200000)
		mockGetLiteLLMModels.mockResolvedValue(fixture)

		const result = await getModels({
			provider: "litellm",
			apiKey: "test-api-key",
			baseUrl: "http://localhost:4000",
		})

		// API key and base URL must be forwarded verbatim to the fetcher.
		expect(mockGetLiteLLMModels).toHaveBeenCalledWith("test-api-key", "http://localhost:4000")
		expect(result).toEqual(fixture)
	})

	it("calls getOpenRouterModels for openrouter provider", async () => {
		const fixture = modelMap("openrouter/model", "OpenRouter model", 8192, 128000)
		mockGetOpenRouterModels.mockResolvedValue(fixture)

		const result = await getModels({ provider: "openrouter" })

		expect(mockGetOpenRouterModels).toHaveBeenCalled()
		expect(result).toEqual(fixture)
	})

	it("calls getRequestyModels with optional API key", async () => {
		const fixture = modelMap("requesty/model", "Requesty model", 4096, 8192)
		mockGetRequestyModels.mockResolvedValue(fixture)

		const result = await getModels({ provider: "requesty", apiKey: DUMMY_REQUESTY_KEY })

		expect(mockGetRequestyModels).toHaveBeenCalledWith(DUMMY_REQUESTY_KEY)
		expect(result).toEqual(fixture)
	})

	it("calls getGlamaModels for glama provider", async () => {
		const fixture = modelMap("glama/model", "Glama model", 4096, 8192)
		mockGetGlamaModels.mockResolvedValue(fixture)

		const result = await getModels({ provider: "glama" })

		expect(mockGetGlamaModels).toHaveBeenCalled()
		expect(result).toEqual(fixture)
	})

	it("calls getUnboundModels with optional API key", async () => {
		const fixture = modelMap("unbound/model", "Unbound model", 4096, 8192)
		mockGetUnboundModels.mockResolvedValue(fixture)

		const result = await getModels({ provider: "unbound", apiKey: DUMMY_UNBOUND_KEY })

		expect(mockGetUnboundModels).toHaveBeenCalledWith(DUMMY_UNBOUND_KEY)
		expect(result).toEqual(fixture)
	})

	it("handles errors and re-throws them", async () => {
		// Fetcher failures must propagate unchanged to the caller.
		mockGetLiteLLMModels.mockRejectedValue(new Error("LiteLLM connection failed"))

		const pending = getModels({
			provider: "litellm",
			apiKey: "test-api-key",
			baseUrl: "http://localhost:4000",
		})

		await expect(pending).rejects.toThrow("LiteLLM connection failed")
	})

	it("validates exhaustive provider checking with unknown provider", async () => {
		// The discriminated union makes this unrepresentable at compile
		// time; the cast simulates a value smuggled past the type checker
		// so the runtime fallback can be exercised.
		const unknownProvider = "unknown" as any

		await expect(getModels({ provider: unknownProvider })).rejects.toThrow("Unknown provider: unknown")
	})
})
Loading