diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 3fa7094d87..b3e1330ef1 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -184,6 +184,7 @@ const openAiSchema = baseProviderSettingsSchema.extend({ const ollamaSchema = baseProviderSettingsSchema.extend({ ollamaModelId: z.string().optional(), ollamaBaseUrl: z.string().optional(), + ollamaContextWindow: z.number().optional(), }) const vsCodeLmSchema = baseProviderSettingsSchema.extend({ diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index f8792937db..2657cdc27d 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -100,6 +100,103 @@ describe("NativeOllamaHandler", () => { expect(results.some((r) => r.type === "reasoning")).toBe(true) expect(results.some((r) => r.type === "text")).toBe(true) }) + + describe("context window configuration", () => { + it("should use custom context window when ollamaContextWindow is provided", async () => { + const customContextWindow = 48000 + const optionsWithCustomContext: ApiHandlerOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434", + ollamaContextWindow: customContextWindow, + } + handler = new NativeOllamaHandler(optionsWithCustomContext) + + // Mock the chat response + mockChat.mockImplementation(async function* () { + yield { + message: { content: "Test response" }, + eval_count: 10, + prompt_eval_count: 5, + } + }) + + // Create a message to trigger the chat call + const generator = handler.createMessage("System prompt", [{ role: "user", content: "Test message" }]) + + // Consume the generator + const results = [] + for await (const chunk of generator) { + results.push(chunk) + } + + // Verify that chat was called with the custom context window + expect(mockChat).toHaveBeenCalledWith( + 
expect.objectContaining({ + options: expect.objectContaining({ + num_ctx: customContextWindow, + }), + }), + ) + }) + + it("should use model's default context window when ollamaContextWindow is not provided", async () => { + // Mock the chat response + mockChat.mockImplementation(async function* () { + yield { + message: { content: "Test response" }, + eval_count: 10, + prompt_eval_count: 5, + } + }) + + // Create a message to trigger the chat call + const generator = handler.createMessage("System prompt", [{ role: "user", content: "Test message" }]) + + // Consume the generator + const results = [] + for await (const chunk of generator) { + results.push(chunk) + } + + // Verify that chat was called with the model's default context window (4096) + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + options: expect.objectContaining({ + num_ctx: 4096, + }), + }), + ) + }) + + it("should use custom context window in completePrompt method", async () => { + const customContextWindow = 48000 + const optionsWithCustomContext: ApiHandlerOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434", + ollamaContextWindow: customContextWindow, + } + handler = new NativeOllamaHandler(optionsWithCustomContext) + + // Mock the chat response + mockChat.mockResolvedValue({ + message: { content: "Test response" }, + }) + + // Call completePrompt + await handler.completePrompt("Test prompt") + + // Verify that chat was called with the custom context window + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + options: expect.objectContaining({ + num_ctx: customContextWindow, + }), + }), + ) + }) + }) }) describe("completePrompt", () => { @@ -115,6 +212,7 @@ describe("NativeOllamaHandler", () => { messages: [{ role: "user", content: "Tell me a joke" }], stream: false, options: { + num_ctx: 4096, temperature: 0, }, }) diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 
8ab4ebe2e1..f895fb7641 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -181,7 +181,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio messages: ollamaMessages, stream: true, options: { - num_ctx: modelInfo.contextWindow, + num_ctx: this.options.ollamaContextWindow || modelInfo.contextWindow, temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0), }, }) @@ -262,7 +262,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio async completePrompt(prompt: string): Promise<string> { try { const client = this.ensureClient() - const { id: modelId } = await this.fetchModel() + const { id: modelId, info: modelInfo } = await this.fetchModel() const useR1Format = modelId.toLowerCase().includes("deepseek-r1") const response = await client.chat({ @@ -270,6 +270,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio messages: [{ role: "user", content: prompt }], stream: false, options: { + num_ctx: this.options.ollamaContextWindow || modelInfo.contextWindow, temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0), }, })