Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 167 additions & 42 deletions src/api/providers/__tests__/openai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,43 @@ const mockCreate = vitest.fn()

vitest.mock("openai", () => {
const mockConstructor = vitest.fn()
return {
__esModule: true,
default: mockConstructor.mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
const mockImplementation = () => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
}

return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
delta: {},
index: 0,
},
],
Expand All @@ -34,38 +58,17 @@ vitest.mock("openai", () => {
total_tokens: 15,
},
}
}

return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}
}),
},
},
}
}),
},
})),
},
})

return {
__esModule: true,
default: mockConstructor.mockImplementation(mockImplementation),
AzureOpenAI: mockConstructor.mockImplementation(mockImplementation),
}
})

Expand Down Expand Up @@ -105,6 +108,50 @@ describe("OpenAiHandler", () => {
expect(handlerWithCustomUrl).toBeInstanceOf(OpenAiHandler)
})

it("should normalize base URL to prevent /v1 duplication", () => {
// Test URL that already ends with /v1
const urlWithV1 = "https://custom.openai.com/v1"
const handler1 = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: urlWithV1,
})
expect(handler1).toBeInstanceOf(OpenAiHandler)

// Test URL without /v1 (should add it)
const urlWithoutV1 = "https://custom.openai.com"
const handler2 = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: urlWithoutV1,
})
expect(handler2).toBeInstanceOf(OpenAiHandler)

// Test URL with trailing slash (should add /v1)
const urlWithTrailingSlash = "https://custom.openai.com/"
const handler3 = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: urlWithTrailingSlash,
})
expect(handler3).toBeInstanceOf(OpenAiHandler)
})

it("should not modify Azure endpoints", () => {
// Test Azure OpenAI endpoint
const azureUrl = "https://myinstance.openai.azure.com/openai/deployments/mymodel"
const azureHandler = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: azureUrl,
})
expect(azureHandler).toBeInstanceOf(OpenAiHandler)

// Test Azure AI Inference Service endpoint
const azureAiUrl = "https://myinstance.services.ai.azure.com"
const azureAiHandler = new OpenAiHandler({
...mockOptions,
openAiBaseUrl: azureAiUrl,
})
expect(azureAiHandler).toBeInstanceOf(OpenAiHandler)
})

it("should set default headers correctly", () => {
// Check that the OpenAI constructor was called with correct parameters
expect(vi.mocked(OpenAI)).toHaveBeenCalledWith({
Expand Down Expand Up @@ -831,6 +878,84 @@ describe("getOpenAiModels", () => {
expect(result).toEqual(["model-1", "model-2"])
})

it("should normalize URLs to prevent /v1 duplication", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

// URL already ending with /v1 should not get another /v1
const result = await getOpenAiModels("https://custom.api.com/v1", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://custom.api.com/v1/models", expect.any(Object))
expect(result).toEqual(["model-1"])
})

it("should add /v1 to URLs that don't have it", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

// URL without /v1 should get /v1 added
const result = await getOpenAiModels("https://custom.api.com", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://custom.api.com/v1/models", expect.any(Object))
expect(result).toEqual(["model-1"])
})

it("should handle URLs with trailing slash correctly", async () => {
const mockResponse = {
data: {
data: [{ id: "model-1" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

// URL with trailing slash should get /v1 added correctly
const result = await getOpenAiModels("https://custom.api.com/", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://custom.api.com/v1/models", expect.any(Object))
expect(result).toEqual(["model-1"])
})

it("should not modify Azure endpoints", async () => {
const mockResponse = {
data: {
data: [{ id: "azure-model" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

// Azure endpoint should not be modified
const result = await getOpenAiModels("https://myinstance.openai.azure.com/openai/deployments", "test-key")

expect(axios.get).toHaveBeenCalledWith(
"https://myinstance.openai.azure.com/openai/deployments/models",
expect.any(Object),
)
expect(result).toEqual(["azure-model"])
})

it("should not modify Azure AI Inference Service endpoints", async () => {
const mockResponse = {
data: {
data: [{ id: "azure-ai-model" }],
},
}
vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)

// Azure AI Inference Service endpoint should not be modified
const result = await getOpenAiModels("https://myinstance.services.ai.azure.com", "test-key")

expect(axios.get).toHaveBeenCalledWith("https://myinstance.services.ai.azure.com/models", expect.any(Object))
expect(result).toEqual(["azure-ai-model"])
})

it("should handle baseUrl with leading spaces", async () => {
const mockResponse = {
data: {
Expand Down
67 changes: 60 additions & 7 deletions src/api/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
super()
this.options = options

const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
const baseURL = this.normalizeBaseUrl(this.options.openAiBaseUrl ?? "https://api.openai.com/v1")
const apiKey = this.options.openAiApiKey ?? "not-provided"
const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
const isAzureAiInference = this._isAzureAiInference(baseURL)
const urlHost = this._getUrlHost(baseURL)
const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

const headers = {
Expand Down Expand Up @@ -423,6 +423,41 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
return urlHost.endsWith(".services.ai.azure.com")
}

/**
* Normalizes the base URL to ensure it ends with /v1 for OpenAI-compatible APIs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The URL normalization logic (adding '/v1') is correctly implemented, but note that similar logic is repeated in both the normalizeBaseUrl method (lines 427–459) and in getOpenAiModels (lines 487–509). Consider extracting this logic into a shared utility function to reduce duplication and ease future maintenance.

This comment was generated because it violated a code review rule: irule_tTqpIuNs8DV0QFGj.

* but doesn't duplicate /v1 if it's already present
*/
private normalizeBaseUrl(baseUrl: string): string {
// Trim whitespace
let normalizedUrl = baseUrl.trim()

// For standard OpenAI API, keep as-is
if (normalizedUrl === "https://api.openai.com/v1") {
return normalizedUrl
}

// For Azure endpoints, don't modify
const urlHost = this._getUrlHost(normalizedUrl)
if (urlHost.endsWith(".azure.com") || urlHost.endsWith(".services.ai.azure.com")) {
return normalizedUrl
}

// For other OpenAI-compatible APIs, ensure /v1 is present but not duplicated
// The OpenAI SDK expects the base URL to include /v1 for compatibility
if (!normalizedUrl.endsWith("/v1")) {
// Remove trailing slash if present
if (normalizedUrl.endsWith("/")) {
normalizedUrl = normalizedUrl.slice(0, -1)
}
// Only add /v1 if it's not already there
if (!normalizedUrl.endsWith("/v1")) {
normalizedUrl = `${normalizedUrl}/v1`
}
}

return normalizedUrl
}

/**
* Adds max_completion_tokens to the request body if needed based on provider configuration
* Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation
Expand All @@ -449,13 +484,31 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
return []
}

// Trim whitespace from baseUrl to handle cases where users accidentally include spaces
const trimmedBaseUrl = baseUrl.trim()
// Normalize the base URL using the same logic as the OpenAiHandler
let normalizedUrl = baseUrl.trim()

if (!URL.canParse(trimmedBaseUrl)) {
if (!URL.canParse(normalizedUrl)) {
return []
}

// For Azure endpoints, don't modify
const urlHost = new URL(normalizedUrl).host
const isAzure = urlHost.endsWith(".azure.com") || urlHost.endsWith(".services.ai.azure.com")

if (!isAzure && normalizedUrl !== "https://api.openai.com/v1") {
// For other OpenAI-compatible APIs, ensure /v1 is present but not duplicated
if (!normalizedUrl.endsWith("/v1")) {
// Remove trailing slash if present
if (normalizedUrl.endsWith("/")) {
normalizedUrl = normalizedUrl.slice(0, -1)
}
// Only add /v1 if it's not already there
if (!normalizedUrl.endsWith("/v1")) {
normalizedUrl = `${normalizedUrl}/v1`
}
}
}

const config: Record<string, any> = {}
const headers: Record<string, string> = {
...DEFAULT_HEADERS,
Expand All @@ -470,7 +523,7 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
config["headers"] = headers
}

const response = await axios.get(`${trimmedBaseUrl}/models`, config)
const response = await axios.get(`${normalizedUrl}/models`, config)
const modelsArray = response.data?.data?.map((model: any) => model.id) || []
return [...new Set<string>(modelsArray)]
} catch (error) {
Expand Down
Loading