Skip to content

Commit 5bd5e29

Browse files
committed
Fix o1-pro on OpenRouter
1 parent 61e122d commit 5bd5e29

File tree

7 files changed

+123
-33
lines changed

7 files changed

+123
-33
lines changed

src/api/providers/__tests__/ollama.test.ts renamed to src/api/providers/__tests__/ollama.spec.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
// npx vitest run api/providers/__tests__/ollama.spec.ts
2+
3+
import { vitest } from "vitest"
14
import { Anthropic } from "@anthropic-ai/sdk"
25

36
import { OllamaHandler } from "../ollama"
47
import { ApiHandlerOptions } from "../../../shared/api"
58

6-
// Mock OpenAI client
7-
const mockCreate = jest.fn()
8-
jest.mock("openai", () => {
9+
const mockCreate = vitest.fn()
10+
11+
vitest.mock("openai", () => {
912
return {
1013
__esModule: true,
11-
default: jest.fn().mockImplementation(() => ({
14+
default: vitest.fn().mockImplementation(() => ({
1215
chat: {
1316
completions: {
1417
create: mockCreate.mockImplementation(async (options) => {

src/api/providers/__tests__/openai-native.test.ts renamed to src/api/providers/__tests__/openai-native.spec.ts

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
// npx jest src/api/providers/__tests__/openai-native.test.ts
1+
// npx vitest run api/providers/__tests__/openai-native.spec.ts
22

3+
import { vitest } from "vitest"
34
import { Anthropic } from "@anthropic-ai/sdk"
45

56
import { OpenAiNativeHandler } from "../openai-native"
67
import { ApiHandlerOptions } from "../../../shared/api"
78

89
// Mock OpenAI client
9-
const mockCreate = jest.fn()
10+
const mockCreate = vitest.fn()
1011

11-
jest.mock("openai", () => {
12+
vitest.mock("openai", () => {
1213
return {
1314
__esModule: true,
14-
default: jest.fn().mockImplementation(() => ({
15+
default: vitest.fn().mockImplementation(() => ({
1516
chat: {
1617
completions: {
1718
create: mockCreate.mockImplementation(async (options) => {
@@ -372,6 +373,75 @@ describe("OpenAiNativeHandler", () => {
372373
})
373374
})
374375

376+
describe("temperature parameter handling", () => {
377+
it("should include temperature for models that support it", async () => {
378+
// Test with gpt-4.1 which supports temperature
379+
handler = new OpenAiNativeHandler({
380+
apiModelId: "gpt-4.1",
381+
openAiNativeApiKey: "test-api-key",
382+
})
383+
384+
await handler.completePrompt("Test prompt")
385+
expect(mockCreate).toHaveBeenCalledWith({
386+
model: "gpt-4.1",
387+
messages: [{ role: "user", content: "Test prompt" }],
388+
temperature: 0,
389+
})
390+
})
391+
392+
it("should strip temperature for o1 family models", async () => {
393+
const o1Models = ["o1", "o1-preview", "o1-mini"]
394+
395+
for (const modelId of o1Models) {
396+
handler = new OpenAiNativeHandler({
397+
apiModelId: modelId,
398+
openAiNativeApiKey: "test-api-key",
399+
})
400+
401+
mockCreate.mockClear()
402+
await handler.completePrompt("Test prompt")
403+
404+
const callArgs = mockCreate.mock.calls[0][0]
405+
// Temperature should be undefined for o1 models
406+
expect(callArgs.temperature).toBeUndefined()
407+
expect(callArgs.model).toBe(modelId)
408+
}
409+
})
410+
411+
it("should strip temperature for o3-mini model", async () => {
412+
handler = new OpenAiNativeHandler({
413+
apiModelId: "o3-mini",
414+
openAiNativeApiKey: "test-api-key",
415+
})
416+
417+
await handler.completePrompt("Test prompt")
418+
419+
const callArgs = mockCreate.mock.calls[0][0]
420+
// Temperature should be undefined for o3-mini models
421+
expect(callArgs.temperature).toBeUndefined()
422+
expect(callArgs.model).toBe("o3-mini")
423+
expect(callArgs.reasoning_effort).toBe("medium")
424+
})
425+
426+
it("should strip temperature in streaming mode for unsupported models", async () => {
427+
handler = new OpenAiNativeHandler({
428+
apiModelId: "o1",
429+
openAiNativeApiKey: "test-api-key",
430+
})
431+
432+
const stream = handler.createMessage(systemPrompt, messages)
433+
// Consume the stream
434+
for await (const _chunk of stream) {
435+
// Just consume the stream
436+
}
437+
438+
const callArgs = mockCreate.mock.calls[0][0]
439+
expect(callArgs).not.toHaveProperty("temperature")
440+
expect(callArgs.model).toBe("o1")
441+
expect(callArgs.stream).toBe(true)
442+
})
443+
})
444+
375445
describe("getModel", () => {
376446
it("should return model info", () => {
377447
const modelInfo = handler.getModel()

src/api/providers/__tests__/openai-usage-tracking.test.ts renamed to src/api/providers/__tests__/openai-usage-tracking.spec.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
import { OpenAiHandler } from "../openai"
2-
import { ApiHandlerOptions } from "../../../shared/api"
1+
// npx vitest run api/providers/__tests__/openai-usage-tracking.spec.ts
2+
3+
import { vitest } from "vitest"
34
import { Anthropic } from "@anthropic-ai/sdk"
45

5-
// Mock OpenAI client with multiple chunks that contain usage data
6-
const mockCreate = jest.fn()
7-
jest.mock("openai", () => {
6+
import { ApiHandlerOptions } from "../../../shared/api"
7+
import { OpenAiHandler } from "../openai"
8+
9+
const mockCreate = vitest.fn()
10+
11+
vitest.mock("openai", () => {
812
return {
913
__esModule: true,
10-
default: jest.fn().mockImplementation(() => ({
14+
default: vitest.fn().mockImplementation(() => ({
1115
chat: {
1216
completions: {
1317
create: mockCreate.mockImplementation(async (options) => {

src/api/providers/__tests__/openai.test.ts renamed to src/api/providers/__tests__/openai.spec.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
// npx jest src/api/providers/__tests__/openai.test.ts
1+
// npx vitest run api/providers/__tests__/openai.spec.ts
22

3+
import { vitest, vi } from "vitest"
34
import { OpenAiHandler } from "../openai"
45
import { ApiHandlerOptions } from "../../../shared/api"
56
import { Anthropic } from "@anthropic-ai/sdk"
7+
import OpenAI from "openai"
68

7-
// Mock OpenAI client
8-
const mockCreate = jest.fn()
9-
jest.mock("openai", () => {
9+
const mockCreate = vitest.fn()
10+
11+
vitest.mock("openai", () => {
12+
const mockConstructor = vitest.fn()
1013
return {
1114
__esModule: true,
12-
default: jest.fn().mockImplementation(() => ({
15+
default: mockConstructor.mockImplementation(() => ({
1316
chat: {
1417
completions: {
1518
create: mockCreate.mockImplementation(async (options) => {
@@ -94,10 +97,8 @@ describe("OpenAiHandler", () => {
9497
})
9598

9699
it("should set default headers correctly", () => {
97-
// Get the mock constructor from the jest mock system
98-
const openAiMock = jest.requireMock("openai").default
99-
100-
expect(openAiMock).toHaveBeenCalledWith({
100+
// Check that the OpenAI constructor was called with correct parameters
101+
expect(vi.mocked(OpenAI)).toHaveBeenCalledWith({
101102
baseURL: expect.any(String),
102103
apiKey: expect.any(String),
103104
defaultHeaders: {

src/api/providers/openai-native.ts

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
165165

166166
const info: ModelInfo = openAiNativeModels[id]
167167

168-
const { temperature, ...params } = getModelParams({
168+
const params = getModelParams({
169169
format: "openai",
170170
modelId: id,
171171
model: info,
@@ -175,13 +175,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
175175

176176
// The o3 models are named like "o3-mini-[reasoning-effort]", which are
177177
// not valid model ids, so we need to strip the suffix.
178-
// Also note that temperature is not supported for o1 and o3-mini.
179-
return {
180-
id: id.startsWith("o3-mini") ? "o3-mini" : id,
181-
info,
182-
...params,
183-
temperature: id.startsWith("o1") || id.startsWith("o3-mini") ? undefined : temperature,
184-
}
178+
return { id: id.startsWith("o3-mini") ? "o3-mini" : id, info, ...params }
185179
}
186180

187181
async completePrompt(prompt: string): Promise<string> {

src/api/transform/model-params.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type GetModelParamsOptions<T extends "openai" | "anthropic" | "openrouter"> = {
2525

2626
type BaseModelParams = {
2727
maxTokens: number | undefined
28-
temperature: number
28+
temperature: number | undefined
2929
reasoningEffort: "low" | "medium" | "high" | undefined
3030
reasoningBudget: number | undefined
3131
}
@@ -114,12 +114,27 @@ export function getModelParams({
114114
reasoning: getAnthropicReasoning({ model, reasoningBudget, reasoningEffort, settings }),
115115
}
116116
} else if (format === "openai") {
117+
// Special case for o1 and o3-mini, which don't support temperature.
118+
// TODO: Add a `supportsTemperature` field to the model info.
119+
if (modelId.startsWith("o1") || modelId.startsWith("o3-mini")) {
120+
params.temperature = undefined
121+
}
122+
117123
return {
118124
format,
119125
...params,
120126
reasoning: getOpenAiReasoning({ model, reasoningBudget, reasoningEffort, settings }),
121127
}
122128
} else {
129+
// Special case for o1-pro, which doesn't support temperature.
130+
// Note that OpenRouter's `supported_parameters` field includes
131+
// `temperature`, which is probably a bug.
132+
// TODO: Add a `supportsTemperature` field to the model info and populate
133+
// it appropriately in the OpenRouter fetcher.
134+
if (modelId === "openai/o1-pro") {
135+
params.temperature = undefined
136+
}
137+
123138
return {
124139
format,
125140
...params,

src/vitest.config.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import { defineConfig } from "vitest/config"
22

33
export default defineConfig({
4-
test: { include: ["**/__tests__/**/*.spec.ts"] },
4+
test: {
5+
include: ["**/__tests__/**/*.spec.ts"],
6+
globals: true,
7+
},
58
})

0 commit comments

Comments (0)