Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 149 additions & 1 deletion src/api/providers/__tests__/openai.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// npx vitest run api/providers/__tests__/openai.spec.ts

import { OpenAiHandler } from "../openai"
import { OpenAiHandler, getOpenAiModels } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { openAiModelInfoSaneDefaults } from "@roo-code/types"
import { Package } from "../../../shared/package"
import axios from "axios"

const mockCreate = vitest.fn()

Expand Down Expand Up @@ -68,6 +69,13 @@ vitest.mock("openai", () => {
}
})

// Mock axios for getOpenAiModels tests
vitest.mock("axios", () => ({
default: {
get: vitest.fn(),
},
}))

describe("OpenAiHandler", () => {
let handler: OpenAiHandler
let mockOptions: ApiHandlerOptions
Expand Down Expand Up @@ -776,3 +784,143 @@ describe("OpenAiHandler", () => {
})
})
})

describe("getOpenAiModels", () => {
beforeEach(() => {
vi.mocked(axios.get).mockClear()
})

it("should return empty array when baseUrl is not provided", async () => {
const result = await getOpenAiModels(undefined, "test-key")
expect(result).toEqual([])
expect(axios.get).not.toHaveBeenCalled()
})

it("should return empty array when baseUrl is empty string", async () => {
const result = await getOpenAiModels("", "test-key")
expect(result).toEqual([])
expect(axios.get).not.toHaveBeenCalled()
})

it("should trim whitespace from baseUrl", async () => {
const mockResponse = {
data: {
data: [{ id: "gpt-4" }, { id: "gpt-3.5-turbo" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

const result = await getOpenAiModels(" https://api.openai.com/v1 ", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://api.openai.com/v1/models", expect.any(Object))
expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"])
})

it("should handle baseUrl with trailing spaces", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }, { id: "model-2" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

const result = await getOpenAiModels("https://api.example.com/v1 ", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object))
expect(result).toEqual(["model-1", "model-2"])
})

it("should handle baseUrl with leading spaces", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

const result = await getOpenAiModels(" https://api.example.com/v1", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object))
expect(result).toEqual(["model-1"])
})

it("should return empty array for invalid URL after trimming", async () => {
const result = await getOpenAiModels(" not-a-valid-url ", "test-key")
expect(result).toEqual([])
expect(axios.get).not.toHaveBeenCalled()
})

it("should include authorization header when apiKey is provided", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

await getOpenAiModels("https://api.example.com/v1", "test-api-key")

expect(axios.get).toHaveBeenCalledWith(
"https://api.example.com/v1/models",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer test-api-key",
}),
}),
)
})

it("should include custom headers when provided", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

const customHeaders = {
"X-Custom-Header": "custom-value",
}

await getOpenAiModels("https://api.example.com/v1", "test-key", customHeaders)

expect(axios.get).toHaveBeenCalledWith(
"https://api.example.com/v1/models",
expect.objectContaining({
headers: expect.objectContaining({
"X-Custom-Header": "custom-value",
Authorization: "Bearer test-key",
}),
}),
)
})

it("should handle API errors gracefully", async () => {
vi.mocked(axios.get).mockRejectedValueOnce(new Error("Network error"))

const result = await getOpenAiModels("https://api.example.com/v1", "test-key")

expect(result).toEqual([])
})

it("should handle malformed response data", async () => {
vi.mocked(axios.get).mockResolvedValueOnce({ data: null })

const result = await getOpenAiModels("https://api.example.com/v1", "test-key")

expect(result).toEqual([])
})

it("should deduplicate model IDs", async () => {
const mockResponse = {
data: {
data: [{ id: "gpt-4" }, { id: "gpt-4" }, { id: "gpt-3.5-turbo" }, { id: "gpt-4" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

const result = await getOpenAiModels("https://api.example.com/v1", "test-key")

expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"])
})
})
7 changes: 5 additions & 2 deletions src/api/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
return []
}

if (!URL.canParse(baseUrl)) {
// Trim whitespace from baseUrl to handle cases where users accidentally include spaces
const trimmedBaseUrl = baseUrl.trim()

if (!URL.canParse(trimmedBaseUrl)) {
return []
}

Expand All @@ -434,7 +437,7 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
config["headers"] = headers
}

const response = await axios.get(`${baseUrl}/models`, config)
const response = await axios.get(`${trimmedBaseUrl}/models`, config)
const modelsArray = response.data?.data?.map((model: any) => model.id) || []
return [...new Set<string>(modelsArray)]
} catch (error) {
Expand Down