diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 3faeb9f2191f..d514ddf02898 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -676,6 +676,9 @@ importers:
       node-ipc:
         specifier: ^12.0.0
         version: 12.0.0
+      ollama:
+        specifier: ^0.5.17
+        version: 0.5.17
       openai:
         specifier: ^5.0.0
         version: 5.5.1(ws@8.18.3)(zod@3.25.61)
@@ -7645,6 +7648,9 @@ packages:
     resolution: {integrity: sha512-gXah6aZrcUxjWg2zR2MwouP2eHlCBzdV4pygudehaKXSGW4v2AsRQUK+lwwXhii6KFZcunEnmSUoYp5CXibxtA==}
     engines: {node: '>= 0.4'}
 
+  ollama@0.5.17:
+    resolution: {integrity: sha512-q5LmPtk6GLFouS+3aURIVl+qcAOPC4+Msmx7uBb3pd+fxI55WnGjmLZ0yijI/CYy79x0QPGx3BwC3u5zv9fBvQ==}
+
   on-finished@2.4.1:
     resolution: {integrity: sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==}
     engines: {node: '>= 0.8'}
@@ -9655,6 +9661,9 @@ packages:
     resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==}
     engines: {node: '>=18'}
 
+  whatwg-fetch@3.6.20:
+    resolution: {integrity: sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==}
+
   whatwg-mimetype@4.0.0:
     resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==}
     engines: {node: '>=18'}
@@ -13546,7 +13555,7 @@ snapshots:
       sirv: 3.0.1
       tinyglobby: 0.2.14
       tinyrainbow: 2.0.0
-      vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0)
+      vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0)
 
   '@vitest/utils@3.2.4':
     dependencies:
@@ -17683,6 +17692,10 @@ snapshots:
       define-properties: 1.2.1
       es-object-atoms: 1.1.1
 
+  ollama@0.5.17:
+    dependencies:
+      whatwg-fetch: 3.6.20
+
   on-finished@2.4.1:
     dependencies:
       ee-first: 1.1.1
@@ -20155,6 +20168,8 @@ snapshots:
     dependencies:
       iconv-lite: 0.6.3
 
+  whatwg-fetch@3.6.20: {}
+
   whatwg-mimetype@4.0.0: {}
 
   whatwg-url@14.2.0:
diff --git a/src/api/index.ts b/src/api/index.ts
index c29c230b063b..92a5c95770d6 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -13,7 +13,6 @@ import {
 	VertexHandler,
 	AnthropicVertexHandler,
 	OpenAiHandler,
-	OllamaHandler,
 	LmStudioHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
@@ -37,6 +36,7 @@ import {
 	ZAiHandler,
 	FireworksHandler,
 } from "./providers"
+import { NativeOllamaHandler } from "./providers/native-ollama"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -95,7 +95,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 		case "openai":
 			return new OpenAiHandler(options)
 		case "ollama":
-			return new OllamaHandler(options)
+			return new NativeOllamaHandler(options)
 		case "lmstudio":
 			return new LmStudioHandler(options)
 		case "gemini":
diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts
new file mode 100644
index 000000000000..f8792937dbcd
--- /dev/null
+++ b/src/api/providers/__tests__/native-ollama.spec.ts
@@ -0,0 +1,162 @@
+// npx vitest run api/providers/__tests__/native-ollama.spec.ts
+
+import { NativeOllamaHandler } from "../native-ollama"
+import { ApiHandlerOptions } from "../../../shared/api"
+
+// Mock the ollama package
+const mockChat = vitest.fn()
+vitest.mock("ollama", () => {
+	return {
+		Ollama: vitest.fn().mockImplementation(() => ({
+			chat: mockChat,
+		})),
+		Message: vitest.fn(),
+	}
+})
+
+// Mock the getOllamaModels function
+vitest.mock("../fetchers/ollama", () => ({ + getOllamaModels: vitest.fn().mockResolvedValue({ + llama2: { + contextWindow: 4096, + maxTokens: 4096, + supportsImages: false, + supportsPromptCache: false, + }, + }), +})) + +describe("NativeOllamaHandler", () => { + let handler: NativeOllamaHandler + + beforeEach(() => { + vitest.clearAllMocks() + + const options: ApiHandlerOptions = { + apiModelId: "llama2", + ollamaModelId: "llama2", + ollamaBaseUrl: "http://localhost:11434", + } + + handler = new NativeOllamaHandler(options) + }) + + describe("createMessage", () => { + it("should stream messages from Ollama", async () => { + // Mock the chat response as an async generator + mockChat.mockImplementation(async function* () { + yield { + message: { content: "Hello" }, + eval_count: undefined, + prompt_eval_count: undefined, + } + yield { + message: { content: " world" }, + eval_count: 2, + prompt_eval_count: 10, + } + }) + + const systemPrompt = "You are a helpful assistant" + const messages = [{ role: "user" as const, content: "Hi there" }] + + const stream = handler.createMessage(systemPrompt, messages) + const results = [] + + for await (const chunk of stream) { + results.push(chunk) + } + + expect(results).toHaveLength(3) + expect(results[0]).toEqual({ type: "text", text: "Hello" }) + expect(results[1]).toEqual({ type: "text", text: " world" }) + expect(results[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 2 }) + }) + + it("should handle DeepSeek R1 models with reasoning detection", async () => { + const options: ApiHandlerOptions = { + apiModelId: "deepseek-r1", + ollamaModelId: "deepseek-r1", + ollamaBaseUrl: "http://localhost:11434", + } + + handler = new NativeOllamaHandler(options) + + // Mock response with thinking tags + mockChat.mockImplementation(async function* () { + yield { message: { content: "Let me think" } } + yield { message: { content: " about this" } } + yield { message: { content: "The answer is 42" } } + }) + + const stream = handler.createMessage("System", [{ role: "user" as const, content: "Question?" 
}]) + const results = [] + + for await (const chunk of stream) { + results.push(chunk) + } + + // Should detect reasoning vs regular text + expect(results.some((r) => r.type === "reasoning")).toBe(true) + expect(results.some((r) => r.type === "text")).toBe(true) + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt without streaming", async () => { + mockChat.mockResolvedValue({ + message: { content: "This is the response" }, + }) + + const result = await handler.completePrompt("Tell me a joke") + + expect(mockChat).toHaveBeenCalledWith({ + model: "llama2", + messages: [{ role: "user", content: "Tell me a joke" }], + stream: false, + options: { + temperature: 0, + }, + }) + expect(result).toBe("This is the response") + }) + }) + + describe("error handling", () => { + it("should handle connection refused errors", async () => { + const error = new Error("ECONNREFUSED") as any + error.code = "ECONNREFUSED" + mockChat.mockRejectedValue(error) + + const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }]) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("Ollama service is not running") + }) + + it("should handle model not found errors", async () => { + const error = new Error("Not found") as any + error.status = 404 + mockChat.mockRejectedValue(error) + + const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }]) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("Model llama2 not found in Ollama") + }) + }) + + describe("getModel", () => { + it("should return the configured model", () => { + const model = handler.getModel() + expect(model.id).toBe("llama2") + expect(model.info).toBeDefined() + }) + }) +}) diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts new file mode 100644 index 000000000000..8ab4ebe2e136 --- /dev/null +++ b/src/api/providers/native-ollama.ts @@ -0,0 +1,285 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import { Message, Ollama } from "ollama" +import { ModelInfo, openAiModelInfoSaneDefaults, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types" +import { ApiStream } from "../transform/stream" +import { BaseProvider } from "./base-provider" +import type { ApiHandlerOptions } from "../../shared/api" +import { getOllamaModels } from "./fetchers/ollama" +import { XmlMatcher } from "../../utils/xml-matcher" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +function convertToOllamaMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] { + const ollamaMessages: Message[] = [] + + for (const anthropicMessage of anthropicMessages) { + if (typeof anthropicMessage.content === "string") { + ollamaMessages.push({ + role: anthropicMessage.role, + content: anthropicMessage.content, + }) + } else { + if (anthropicMessage.role === "user") { + const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolResultBlockParam[] + }>( + (acc, part) => { + if (part.type === "tool_result") { + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) + } + return acc + }, + { nonToolMessages: [], toolMessages: [] }, + ) + + // Process tool result messages FIRST since they must follow the tool use messages + 
const toolResultImages: string[] = [] + toolMessages.forEach((toolMessage) => { + // The Anthropic SDK allows tool results to be a string or an array of text and image blocks, enabling rich and structured content. In contrast, the Ollama SDK only supports tool results as a single string, so we map the Anthropic tool result parts into one concatenated string to maintain compatibility. + let content: string + + if (typeof toolMessage.content === "string") { + content = toolMessage.content + } else { + content = + toolMessage.content + ?.map((part) => { + if (part.type === "image") { + // Handle base64 images only (Anthropic SDK uses base64) + // Ollama expects raw base64 strings, not data URLs + if ("source" in part && part.source.type === "base64") { + toolResultImages.push(part.source.data) + } + return "(see following user message for image)" + } + return part.text + }) + .join("\n") ?? "" + } + ollamaMessages.push({ + role: "user", + images: toolResultImages.length > 0 ? toolResultImages : undefined, + content: content, + }) + }) + + // Process non-tool messages + if (nonToolMessages.length > 0) { + // Separate text and images for Ollama + const textContent = nonToolMessages + .filter((part) => part.type === "text") + .map((part) => part.text) + .join("\n") + + const imageData: string[] = [] + nonToolMessages.forEach((part) => { + if (part.type === "image" && "source" in part && part.source.type === "base64") { + // Ollama expects raw base64 strings, not data URLs + imageData.push(part.source.data) + } + }) + + ollamaMessages.push({ + role: "user", + content: textContent, + images: imageData.length > 0 ? imageData : undefined, + }) + } + } else if (anthropicMessage.role === "assistant") { + const { nonToolMessages } = anthropicMessage.content.reduce<{ + nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] + toolMessages: Anthropic.ToolUseBlockParam[] + }>( + (acc, part) => { + if (part.type === "tool_use") { + acc.toolMessages.push(part) + } else if (part.type === "text" || part.type === "image") { + acc.nonToolMessages.push(part) + } // assistant cannot send tool_result messages + return acc + }, + { nonToolMessages: [], toolMessages: [] }, + ) + + // Process non-tool messages + let content: string = "" + if (nonToolMessages.length > 0) { + content = nonToolMessages + .map((part) => { + if (part.type === "image") { + return "" // impossible as the assistant cannot send images + } + return part.text + }) + .join("\n") + } + + ollamaMessages.push({ + role: "assistant", + content, + }) + } + } + } + + return ollamaMessages +} + +export class NativeOllamaHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + private client: Ollama | undefined + protected models: Record = {} + + constructor(options: ApiHandlerOptions) { + super() + this.options = options + } + + private ensureClient(): Ollama { + if (!this.client) { + try { + this.client = new Ollama({ + host: this.options.ollamaBaseUrl || "http://localhost:11434", + // Note: The ollama npm package handles timeouts internally + }) + } catch (error: any) { + throw new Error(`Error creating Ollama client: ${error.message}`) + } + } + return this.client + } + + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const client = this.ensureClient() + const { id: modelId, info: modelInfo } = await this.fetchModel() + const useR1Format = 
modelId.toLowerCase().includes("deepseek-r1") + + const ollamaMessages: Message[] = [ + { role: "system", content: systemPrompt }, + ...convertToOllamaMessages(messages), + ] + + const matcher = new XmlMatcher( + "think", + (chunk) => + ({ + type: chunk.matched ? "reasoning" : "text", + text: chunk.data, + }) as const, + ) + + try { + // Create the actual API request promise + const stream = await client.chat({ + model: modelId, + messages: ollamaMessages, + stream: true, + options: { + num_ctx: modelInfo.contextWindow, + temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0), + }, + }) + + let totalInputTokens = 0 + let totalOutputTokens = 0 + + try { + for await (const chunk of stream) { + if (typeof chunk.message.content === "string") { + // Process content through matcher for reasoning detection + for (const matcherChunk of matcher.update(chunk.message.content)) { + yield matcherChunk + } + } + + // Handle token usage if available + if (chunk.eval_count !== undefined || chunk.prompt_eval_count !== undefined) { + if (chunk.prompt_eval_count) { + totalInputTokens = chunk.prompt_eval_count + } + if (chunk.eval_count) { + totalOutputTokens = chunk.eval_count + } + } + } + + // Yield any remaining content from the matcher + for (const chunk of matcher.final()) { + yield chunk + } + + // Yield usage information if available + if (totalInputTokens > 0 || totalOutputTokens > 0) { + yield { + type: "usage", + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + } + } + } catch (streamError: any) { + console.error("Error processing Ollama stream:", streamError) + throw new Error(`Ollama stream processing error: ${streamError.message || "Unknown error"}`) + } + } catch (error: any) { + // Enhance error reporting + const statusCode = error.status || error.statusCode + const errorMessage = error.message || "Unknown error" + + if (error.code === "ECONNREFUSED") { + throw new Error( + `Ollama service is not running at ${this.options.ollamaBaseUrl || "http://localhost:11434"}. Please start Ollama first.`, + ) + } else if (statusCode === 404) { + throw new Error( + `Model ${this.getModel().id} not found in Ollama. Please pull the model first with: ollama pull ${this.getModel().id}`, + ) + } + + console.error(`Ollama API error (${statusCode || "unknown"}): ${errorMessage}`) + throw error + } + } + + async fetchModel() { + this.models = await getOllamaModels(this.options.ollamaBaseUrl) + return this.getModel() + } + + override getModel(): { id: string; info: ModelInfo } { + const modelId = this.options.ollamaModelId || "" + return { + id: modelId, + info: this.models[modelId] || openAiModelInfoSaneDefaults, + } + } + + async completePrompt(prompt: string): Promise { + try { + const client = this.ensureClient() + const { id: modelId } = await this.fetchModel() + const useR1Format = modelId.toLowerCase().includes("deepseek-r1") + + const response = await client.chat({ + model: modelId, + messages: [{ role: "user", content: prompt }], + stream: false, + options: { + temperature: this.options.modelTemperature ?? (useR1Format ? 
DEEP_SEEK_DEFAULT_TEMPERATURE : 0), + }, + }) + + return response.message?.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Ollama completion error: ${error.message}`) + } + throw error + } + } +} diff --git a/src/package.json b/src/package.json index 604354a31c17..4928f450b33b 100644 --- a/src/package.json +++ b/src/package.json @@ -458,6 +458,7 @@ "monaco-vscode-textmate-theme-converter": "^0.1.7", "node-cache": "^5.1.2", "node-ipc": "^12.0.0", + "ollama": "^0.5.17", "openai": "^5.0.0", "os-name": "^6.0.0", "p-limit": "^6.2.0",