diff --git a/packages/agent/src/core/llm/provider.ts b/packages/agent/src/core/llm/provider.ts
index 379bbef..ae0651a 100644
--- a/packages/agent/src/core/llm/provider.ts
+++ b/packages/agent/src/core/llm/provider.ts
@@ -3,6 +3,7 @@
  */
 
 import { AnthropicProvider } from './providers/anthropic.js';
+import { OllamaProvider } from './providers/ollama.js';
 import { ProviderOptions, GenerateOptions, LLMResponse } from './types.js';
 
 /**
@@ -39,6 +40,7 @@ const providerFactories: Record<
   (model: string, options: ProviderOptions) => LLMProvider
 > = {
   anthropic: (model, options) => new AnthropicProvider(model, options),
+  ollama: (model, options) => new OllamaProvider(model, options),
 };
 
 /**
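Note on the registry wiring above: `createProvider` (defined later in `provider.ts`, not shown in this hunk) looks the factory up by name and forwards the model and options. A minimal sketch of that lookup, with the error branch and message as assumptions since they sit outside the diff:

```ts
import { ProviderOptions } from './types.js';

// Sketch only: the real createProvider lives in provider.ts outside this hunk.
// providerFactories is the registry shown above; the error message is assumed.
export function createProvider(
  provider: string,
  model: string,
  options: ProviderOptions = {},
): LLMProvider {
  const factory = providerFactories[provider];
  if (!factory) {
    throw new Error(`Provider '${provider}' not found`);
  }
  return factory(model, options);
}
```

With the new entry registered, `createProvider('ollama', 'llama3', {})` returns an `OllamaProvider` without any change to existing call sites.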
diff --git a/packages/agent/src/core/llm/providers/ollama.ts b/packages/agent/src/core/llm/providers/ollama.ts
new file mode 100644
index 0000000..c3a4869
--- /dev/null
+++ b/packages/agent/src/core/llm/providers/ollama.ts
@@ -0,0 +1,177 @@
+/**
+ * Ollama provider implementation
+ */
+
+import { TokenUsage } from '../../tokens.js';
+import { LLMProvider } from '../provider.js';
+import {
+  GenerateOptions,
+  LLMResponse,
+  Message,
+  ProviderOptions,
+} from '../types.js';
+
+/**
+ * Ollama-specific options
+ */
+export interface OllamaOptions extends ProviderOptions {
+  baseUrl?: string;
+}
+
+/**
+ * Ollama provider implementation
+ */
+export class OllamaProvider implements LLMProvider {
+  name: string = 'ollama';
+  provider: string = 'ollama.chat';
+  model: string;
+  private baseUrl: string;
+
+  constructor(model: string, options: OllamaOptions = {}) {
+    this.model = model;
+    this.baseUrl =
+      options.baseUrl ||
+      process.env.OLLAMA_BASE_URL ||
+      'http://localhost:11434';
+
+    // Ensure baseUrl doesn't end with a slash
+    if (this.baseUrl.endsWith('/')) {
+      this.baseUrl = this.baseUrl.slice(0, -1);
+    }
+  }
+
+  /**
+   * Generate text using the Ollama API
+   */
+  async generateText(options: GenerateOptions): Promise<LLMResponse> {
+    const {
+      messages,
+      functions,
+      temperature = 0.7,
+      maxTokens,
+      topP,
+      frequencyPenalty,
+      presencePenalty,
+    } = options;
+
+    // Format messages for the Ollama API
+    const formattedMessages = this.formatMessages(messages);
+
+    try {
+      // Prepare request options
+      const requestOptions: any = {
+        model: this.model,
+        messages: formattedMessages,
+        stream: false,
+        options: {
+          temperature: temperature,
+          // Forward sampling options when provided (Ollama supports top_p alongside top_k)
+          ...(topP !== undefined && { top_p: topP }),
+          ...(frequencyPenalty !== undefined && {
+            frequency_penalty: frequencyPenalty,
+          }),
+          ...(presencePenalty !== undefined && {
+            presence_penalty: presencePenalty,
+          }),
+        },
+      };
+
+      // Add max tokens if provided (Ollama calls this num_predict)
+      if (maxTokens !== undefined) {
+        requestOptions.options.num_predict = maxTokens;
+      }
+
+      // Add functions/tools if provided
+      if (functions && functions.length > 0) {
+        requestOptions.tools = functions.map((fn) => ({
+          name: fn.name,
+          description: fn.description,
+          parameters: fn.parameters,
+        }));
+      }
+
+      // Make the API request
+      const response = await fetch(`${this.baseUrl}/api/chat`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify(requestOptions),
+      });
+
+      if (!response.ok) {
+        const errorText = await response.text();
+        throw new Error(`Ollama API error: ${response.status} ${errorText}`);
+      }
+
+      const data = await response.json();
+
+      // Extract content and tool calls
+      const content = data.message?.content || '';
+      const toolCalls =
+        data.message?.tool_calls?.map((toolCall: any) => ({
+          id:
+            toolCall.id ||
+            `tool-${Date.now()}-${Math.random().toString(36).substring(2, 11)}`,
+          name: toolCall.name,
+          content: JSON.stringify(toolCall.args || toolCall.arguments || {}),
+        })) || [];
+
+      // Create token usage from response data
+      const tokenUsage = new TokenUsage();
+      tokenUsage.input = data.prompt_eval_count || 0;
+      tokenUsage.output = data.eval_count || 0;
+
+      return {
+        text: content,
+        toolCalls: toolCalls,
+        tokenUsage: tokenUsage,
+      };
+    } catch (error) {
+      throw new Error(`Error calling Ollama API: ${(error as Error).message}`);
+    }
+  }
+
+  /**
+   * Format messages for the Ollama API
+   */
+  private formatMessages(messages: Message[]): any[] {
+    return messages.map((msg) => {
+      if (
+        msg.role === 'user' ||
+        msg.role === 'assistant' ||
+        msg.role === 'system'
+      ) {
+        return {
+          role: msg.role,
+          content: msg.content,
+        };
+      } else if (msg.role === 'tool_result') {
+        // Ollama expects tool results as a 'tool' role
+        return {
+          role: 'tool',
+          content: msg.content,
+          tool_call_id: msg.tool_use_id,
+        };
+      } else if (msg.role === 'tool_use') {
+        // Convert tool_use into an assistant message carrying tool_calls
+        return {
+          role: 'assistant',
+          content: '',
+          tool_calls: [
+            {
+              id: msg.id,
+              name: msg.name,
+              arguments: msg.content,
+            },
+          ],
+        };
+      }
+      // Default fallback for unknown message types
+      return {
+        role: 'user',
+        content: (msg as any).content || '',
+      };
+    });
+  }
+}
diff --git a/packages/agent/src/core/toolAgent/config.test.ts b/packages/agent/src/core/toolAgent/config.test.ts
index 8c37501..eebe3eb 100644
--- a/packages/agent/src/core/toolAgent/config.test.ts
+++ b/packages/agent/src/core/toolAgent/config.test.ts
@@ -9,20 +9,27 @@ describe('getModel', () => {
     expect(model.provider).toBe('anthropic.messages');
   });
 
-  /*
-
-  it('should return the correct model for openai', () => {
-    const model = getModel('openai', 'gpt-4o-2024-05-13');
+  it('should return the correct model for ollama', () => {
+    const model = getModel('ollama', 'llama3');
     expect(model).toBeDefined();
-    expect(model.provider).toBe('openai.chat');
+    expect(model.provider).toBe('ollama.chat');
   });
 
-  it('should return the correct model for ollama', () => {
-    const model = getModel('ollama', 'llama3');
+  it('should return the correct model for ollama with custom base URL', () => {
+    const model = getModel('ollama', 'llama3', {
+      ollamaBaseUrl: 'http://custom-ollama:11434',
+    });
     expect(model).toBeDefined();
     expect(model.provider).toBe('ollama.chat');
   });
 
+  /*
+  it('should return the correct model for openai', () => {
+    const model = getModel('openai', 'gpt-4o-2024-05-13');
+    expect(model).toBeDefined();
+    expect(model.provider).toBe('openai.chat');
+  });
+
   it('should return the correct model for xai', () => {
     const model = getModel('xai', 'grok-1');
     expect(model).toBeDefined();
@@ -34,7 +41,7 @@ describe('getModel', () => {
     expect(model).toBeDefined();
     expect(model.provider).toBe('mistral.chat');
   });
-*/
+  */
 
   it('should throw an error for unknown provider', () => {
     expect(() => {
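To see the new provider end to end, here is a hedged usage sketch. It assumes an Ollama server running locally with the `llama3` model pulled, and message shapes that follow the roles handled by `formatMessages` above; run it inside an async/ESM context.

```ts
import { OllamaProvider } from './providers/ollama.js';

// A minimal sketch, assuming a local Ollama server with `llama3` pulled.
const provider = new OllamaProvider('llama3', {
  baseUrl: 'http://localhost:11434/', // trailing slash is stripped by the constructor
});

const response = await provider.generateText({
  messages: [
    { role: 'system', content: 'You are a terse assistant.' },
    { role: 'user', content: 'Say hello in one sentence.' },
  ],
  temperature: 0.7,
  maxTokens: 128, // forwarded to Ollama as num_predict
});

console.log(response.text);
// response.tokenUsage.input / .output come from prompt_eval_count / eval_count
```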
diff --git a/packages/agent/src/core/toolAgent/config.ts b/packages/agent/src/core/toolAgent/config.ts
index 29737c9..fe53a4c 100644
--- a/packages/agent/src/core/toolAgent/config.ts
+++ b/packages/agent/src/core/toolAgent/config.ts
@@ -8,22 +8,23 @@ import { ToolContext } from '../types';
 
 /**
  * Available model providers
  */
-export type ModelProvider = 'anthropic';
+export type ModelProvider = 'anthropic' | 'ollama';
 /* | 'openai'
-  | 'ollama'
   | 'xai'
   | 'mistral'*/
 
 /**
  * Get the model instance based on provider and model name
  */
-export function getModel(provider: ModelProvider, model: string): LLMProvider {
+export function getModel(
+  provider: ModelProvider,
+  model: string,
+  options?: { ollamaBaseUrl?: string },
+): LLMProvider {
   switch (provider) {
     case 'anthropic':
       return createProvider('anthropic', model);
-    /*case 'openai':
-      return createProvider('openai', model);
     case 'ollama':
       if (options?.ollamaBaseUrl) {
         return createProvider('ollama', model, {
@@ -31,6 +32,8 @@ export function getModel(provider: ModelProvider, model: string): LLMProvider {
         });
       }
       return createProvider('ollama', model);
+    /*case 'openai':
+      return createProvider('openai', model);
     case 'xai':
       return createProvider('xai', model);
     case 'mistral':
diff --git a/packages/agent/src/utils/errors.ts b/packages/agent/src/utils/errors.ts
index 5276381..b343a0b 100644
--- a/packages/agent/src/utils/errors.ts
+++ b/packages/agent/src/utils/errors.ts
@@ -21,7 +21,7 @@ export const providerConfig: Record<
     docsUrl: 'https://mycoder.ai/docs/getting-started/mistral',
   },*/
   // No API key needed for ollama as it uses a local server
-  //ollama: undefined,
+  ollama: undefined,
 };
 
 /**
diff --git a/packages/cli/src/commands/$default.ts b/packages/cli/src/commands/$default.ts
index 1bc50e2..fc0d9aa 100644
--- a/packages/cli/src/commands/$default.ts
+++ b/packages/cli/src/commands/$default.ts
@@ -136,8 +136,15 @@ export const command: CommandModule = {
         process.env[keyName] = configApiKey;
         logger.debug(`Using ${keyName} from configuration`);
       }
+    } else if (userModelProvider === 'ollama') {
+      // Ollama needs no API key; resolve its base URL from the flag or config
+      const ollamaBaseUrl = argv.ollamaBaseUrl || userConfig.ollamaBaseUrl;
+      logger.debug(`Using Ollama with base URL: ${ollamaBaseUrl}`);
+    } else {
+      // Unknown provider
+      logger.error(`Unknown provider: ${userModelProvider}`);
+      throw new Error(`Unknown provider: ${userModelProvider}`);
     }
 
-    // No API key check needed for Ollama as it uses a local server
 
     let prompt: string | undefined;
@@ -193,12 +200,14 @@
     const agentConfig = {
       ...DEFAULT_CONFIG,
       model: getModel(
-        userModelProvider as 'anthropic' /*
+        userModelProvider as 'anthropic' | 'ollama' /*
           | 'openai'
-          | 'ollama'
           | 'xai'
           | 'mistral'*/,
         userModelName,
+        {
+          ollamaBaseUrl: argv.ollamaBaseUrl || userConfig.ollamaBaseUrl,
+        },
       ),
       maxTokens: userMaxTokens,
       temperature: userTemperature,
diff --git a/packages/cli/src/options.ts b/packages/cli/src/options.ts
index 5de958d..c4e68d7 100644
--- a/packages/cli/src/options.ts
+++ b/packages/cli/src/options.ts
@@ -18,6 +18,7 @@ export type SharedOptions = {
   readonly githubMode?: boolean;
   readonly userWarning?: boolean;
   readonly upgradeCheck?: boolean;
+  readonly ollamaBaseUrl?: string;
 };
 
 export const sharedOptions = {
@@ -36,7 +37,7 @@
   provider: {
     type: 'string',
     description: 'AI model provider to use',
-    choices: ['anthropic' /*, 'openai', 'ollama', 'xai', 'mistral'*/],
+    choices: ['anthropic', 'ollama' /*, 'openai', 'xai', 'mistral'*/],
   } as const,
   model: {
     type: 'string',
@@ -120,4 +121,8 @@
     description: 'Disable version upgrade check (for automated/remote usage)',
     default: false,
   } as const,
+  ollamaBaseUrl: {
+    type: 'string',
+    description: 'Base URL for Ollama API (default: http://localhost:11434)',
+  } as const,
 };
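The CLI changes above give the base URL a clear precedence: the `--ollamaBaseUrl` flag, then the user's config file, then the provider's own fallbacks. A small sketch of that resolution, with `argv` and `userConfig` standing in for the yargs result and the loaded settings:

```ts
// Sketch of the resolution order used in $default.ts above. If both sources
// are unset, OllamaProvider itself falls back to process.env.OLLAMA_BASE_URL
// and finally http://localhost:11434.
function resolveOllamaBaseUrl(
  argv: { ollamaBaseUrl?: string },
  userConfig: { ollamaBaseUrl?: string },
): string | undefined {
  return argv.ollamaBaseUrl || userConfig.ollamaBaseUrl;
}
```

Since the default settings below also seed `ollamaBaseUrl`, the config-file value will normally be present even when the flag is omitted.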
diff --git a/packages/cli/src/settings/config.ts b/packages/cli/src/settings/config.ts
index bbe90e3..06549e8 100644
--- a/packages/cli/src/settings/config.ts
+++ b/packages/cli/src/settings/config.ts
@@ -54,6 +54,8 @@ const defaultConfig = {
   customPrompt: '',
   profile: false,
   tokenCache: true,
+  // Ollama configuration
+  ollamaBaseUrl: 'http://localhost:11434',
   // API keys (empty by default)
   ANTHROPIC_API_KEY: '',
 };
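A follow-up worth considering: the constructor's trailing-slash normalization is easy to pin down in a unit test. A hedged sketch in the style of `config.test.ts` above (vitest and the import path are assumptions, and the assertion reaches through the private `baseUrl` field with a cast, which is illustrative rather than recommended):

```ts
import { describe, expect, it } from 'vitest';

import { OllamaProvider } from '../llm/providers/ollama.js'; // assumed path

describe('OllamaProvider', () => {
  it('strips a trailing slash from the base URL', () => {
    const provider = new OllamaProvider('llama3', {
      baseUrl: 'http://custom-ollama:11434/',
    });
    // baseUrl is private; the cast is for illustration only.
    expect((provider as any).baseUrl).toBe('http://custom-ollama:11434');
  });
});
```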