From 7bda811658e15b8dd41135cd9b2b90e9ea925e15 Mon Sep 17 00:00:00 2001 From: "Ben Houston (via MyCoder)" Date: Wed, 12 Mar 2025 02:01:47 +0000 Subject: [PATCH 1/2] feat(llm): add OpenAI support to LLM abstraction --- .../src/core/llm/__tests__/openai.test.ts | 221 ++++++++++++++++++ packages/agent/src/core/llm/provider.ts | 2 + .../agent/src/core/llm/providers/openai.ts | 207 ++++++++++++++++ packages/cli/src/settings/config.ts | 1 + 4 files changed, 431 insertions(+) create mode 100644 packages/agent/src/core/llm/__tests__/openai.test.ts create mode 100644 packages/agent/src/core/llm/providers/openai.ts diff --git a/packages/agent/src/core/llm/__tests__/openai.test.ts b/packages/agent/src/core/llm/__tests__/openai.test.ts new file mode 100644 index 0000000..2eaf476 --- /dev/null +++ b/packages/agent/src/core/llm/__tests__/openai.test.ts @@ -0,0 +1,221 @@ +import { describe, expect, it, vi, beforeEach } from 'vitest'; + +import { TokenUsage } from '../../tokens.js'; +import { OpenAIProvider } from '../providers/openai.js'; + +// Mock the OpenAI module +vi.mock('openai', () => { + // Create a mock function for the create method + const mockCreate = vi.fn().mockResolvedValue({ + id: 'chatcmpl-123', + object: 'chat.completion', + created: 1677858242, + model: 'gpt-4', + choices: [ + { + index: 0, + message: { + role: 'assistant', + content: 'This is a test response', + tool_calls: [ + { + id: 'tool-call-1', + type: 'function', + function: { + name: 'testFunction', + arguments: '{"arg1":"value1"}', + }, + }, + ], + }, + finish_reason: 'stop', + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }, + }); + + // Return a mocked version of the OpenAI class + return { + default: class MockOpenAI { + constructor() { + // Constructor implementation + } + + chat = { + completions: { + create: mockCreate, + }, + }; + }, + }; +}); + +describe('OpenAIProvider', () => { + let provider: OpenAIProvider; + + beforeEach(() => { + // Set environment variable for testing + process.env.OPENAI_API_KEY = 'test-api-key'; + provider = new OpenAIProvider('gpt-4'); + }); + + it('should initialize with correct properties', () => { + expect(provider.name).toBe('openai'); + expect(provider.provider).toBe('openai.chat'); + expect(provider.model).toBe('gpt-4'); + }); + + it('should throw error if API key is missing', () => { + // Clear environment variable + const originalKey = process.env.OPENAI_API_KEY; + delete process.env.OPENAI_API_KEY; + + expect(() => new OpenAIProvider('gpt-4')).toThrow( + 'OpenAI API key is required', + ); + + // Restore environment variable + process.env.OPENAI_API_KEY = originalKey; + }); + + it('should generate text and handle tool calls', async () => { + const response = await provider.generateText({ + messages: [ + { role: 'system', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'Hello, can you help me?' 
},
+      ],
+      functions: [
+        {
+          name: 'testFunction',
+          description: 'A test function',
+          parameters: {
+            type: 'object',
+            properties: {
+              arg1: { type: 'string' },
+            },
+          },
+        },
+      ],
+    });
+
+    expect(response.text).toBe('This is a test response');
+    expect(response.toolCalls).toHaveLength(1);
+
+    const toolCall = response.toolCalls[0];
+    expect(toolCall).toBeDefined();
+    expect(toolCall?.name).toBe('testFunction');
+    expect(toolCall?.id).toBe('tool-call-1');
+    expect(toolCall?.content).toBe('{"arg1":"value1"}');
+
+    // Check token usage
+    expect(response.tokenUsage).toBeInstanceOf(TokenUsage);
+    expect(response.tokenUsage.input).toBe(10);
+    expect(response.tokenUsage.output).toBe(20);
+  });
+
+  it('should format messages correctly', async () => {
+    await provider.generateText({
+      messages: [
+        { role: 'system', content: 'You are a helpful assistant.' },
+        { role: 'user', content: 'Hello' },
+        { role: 'assistant', content: 'Hi there' },
+        {
+          role: 'tool_use',
+          id: 'tool-1',
+          name: 'testTool',
+          content: '{"param":"value"}',
+        },
+        {
+          role: 'tool_result',
+          tool_use_id: 'tool-1',
+          content: '{"result":"success"}',
+          is_error: false,
+        },
+      ],
+    });
+
+    // Get the mock instance
+    const client = provider['client'];
+    const mockOpenAI = client?.chat?.completions
+      ?.create as unknown as ReturnType<typeof vi.fn>;
+
+    // Check that messages were formatted correctly
+    expect(mockOpenAI).toHaveBeenCalled();
+
+    // Get the second call arguments (from this test)
+    const calledWith = mockOpenAI.mock.calls[1]?.[0] || {};
+
+    expect(calledWith.messages).toHaveLength(5);
+
+    // We need to check each message individually to avoid TypeScript errors
+    const systemMessage = calledWith.messages[0];
+    if (
+      systemMessage &&
+      typeof systemMessage === 'object' &&
+      'role' in systemMessage
+    ) {
+      expect(systemMessage.role).toBe('system');
+      expect(systemMessage.content).toBe('You are a helpful assistant.');
+    }
+
+    const userMessage = calledWith.messages[1];
+    if (
+      userMessage &&
+      typeof userMessage === 'object' &&
+      'role' in userMessage
+    ) {
+      expect(userMessage.role).toBe('user');
+      expect(userMessage.content).toBe('Hello');
+    }
+
+    const assistantMessage = calledWith.messages[2];
+    if (
+      assistantMessage &&
+      typeof assistantMessage === 'object' &&
+      'role' in assistantMessage
+    ) {
+      expect(assistantMessage.role).toBe('assistant');
+      expect(assistantMessage.content).toBe('Hi there');
+    }
+
+    // Check tool_use formatting
+    const toolUseMessage = calledWith.messages[3];
+    if (
+      toolUseMessage &&
+      typeof toolUseMessage === 'object' &&
+      'role' in toolUseMessage
+    ) {
+      expect(toolUseMessage.role).toBe('assistant');
+      expect(toolUseMessage.content).toBe(null);
+
+      if (
+        'tool_calls' in toolUseMessage &&
+        Array.isArray(toolUseMessage.tool_calls)
+      ) {
+        expect(toolUseMessage.tool_calls.length).toBe(1);
+        const toolCall = toolUseMessage.tool_calls[0];
+        if (toolCall && 'function' in toolCall) {
+          expect(toolCall.function.name).toBe('testTool');
+        }
+      }
+    }
+
+    // Check tool_result formatting
+    const toolResultMessage = calledWith.messages[4];
+    if (
+      toolResultMessage &&
+      typeof toolResultMessage === 'object' &&
+      'role' in toolResultMessage
+    ) {
+      expect(toolResultMessage.role).toBe('tool');
+      expect(toolResultMessage.content).toBe('{"result":"success"}');
+      if ('tool_call_id' in toolResultMessage) {
+        expect(toolResultMessage.tool_call_id).toBe('tool-1');
+      }
+    }
+  });
+});
diff --git a/packages/agent/src/core/llm/provider.ts b/packages/agent/src/core/llm/provider.ts
index 379bbef..365bd94 100644
--- 
a/packages/agent/src/core/llm/provider.ts
+++ b/packages/agent/src/core/llm/provider.ts
@@ -3,6 +3,7 @@
  */
 
 import { AnthropicProvider } from './providers/anthropic.js';
+import { OpenAIProvider } from './providers/openai.js';
 import { ProviderOptions, GenerateOptions, LLMResponse } from './types.js';
 
 /**
@@ -39,6 +40,7 @@ const providerFactories: Record<
   (model: string, options: ProviderOptions) => LLMProvider
 > = {
   anthropic: (model, options) => new AnthropicProvider(model, options),
+  openai: (model, options) => new OpenAIProvider(model, options),
 };
 
 /**
diff --git a/packages/agent/src/core/llm/providers/openai.ts b/packages/agent/src/core/llm/providers/openai.ts
new file mode 100644
index 0000000..676f8a8
--- /dev/null
+++ b/packages/agent/src/core/llm/providers/openai.ts
@@ -0,0 +1,207 @@
+/**
+ * OpenAI provider implementation
+ */
+import OpenAI from 'openai';
+
+import { TokenUsage } from '../../tokens.js';
+import { ToolCall } from '../../types.js';
+import { LLMProvider } from '../provider.js';
+import {
+  GenerateOptions,
+  LLMResponse,
+  Message,
+  ProviderOptions,
+  FunctionDefinition,
+} from '../types.js';
+
+import type {
+  ChatCompletionMessageParam,
+  ChatCompletionTool,
+} from 'openai/resources/chat';
+
+/**
+ * OpenAI-specific options
+ */
+export interface OpenAIOptions extends ProviderOptions {
+  apiKey?: string;
+  baseUrl?: string;
+  organization?: string;
+}
+
+/**
+ * OpenAI provider implementation
+ */
+export class OpenAIProvider implements LLMProvider {
+  name: string = 'openai';
+  provider: string = 'openai.chat';
+  model: string;
+  private client: OpenAI;
+  private apiKey: string;
+  private baseUrl?: string;
+  private organization?: string;
+
+  constructor(model: string, options: OpenAIOptions = {}) {
+    this.model = model;
+    this.apiKey = options.apiKey || process.env.OPENAI_API_KEY || '';
+    this.baseUrl = options.baseUrl;
+    this.organization = options.organization || process.env.OPENAI_ORGANIZATION;
+
+    if (!this.apiKey) {
+      throw new Error('OpenAI API key is required');
+    }
+
+    // Initialize OpenAI client
+    this.client = new OpenAI({
+      apiKey: this.apiKey,
+      ...(this.baseUrl && { baseURL: this.baseUrl }),
+      ...(this.organization && { organization: this.organization }),
+    });
+  }
+
+  /**
+   * Generate text using OpenAI API
+   */
+  async generateText(options: GenerateOptions): Promise<LLMResponse> {
+    const {
+      messages,
+      functions,
+      temperature = 0.7,
+      maxTokens,
+      stopSequences,
+      topP,
+      presencePenalty,
+      frequencyPenalty,
+      responseFormat,
+    } = options;
+
+    // Format messages for OpenAI
+    const formattedMessages = this.formatMessages(messages);
+
+    // Format functions for OpenAI
+    const tools = functions ? this.formatFunctions(functions) : undefined;
+
+    try {
+      const requestOptions = {
+        model: this.model,
+        messages: formattedMessages,
+        temperature,
+        max_tokens: maxTokens,
+        stop: stopSequences,
+        top_p: topP,
+        presence_penalty: presencePenalty,
+        frequency_penalty: frequencyPenalty,
+        tools: tools,
+        response_format:
+          responseFormat === 'json_object'
+            ? 
{ type: 'json_object' as const } + : undefined, + }; + + const response = + await this.client.chat.completions.create(requestOptions); + + // Extract content and tool calls + const message = response.choices[0]?.message; + const content = message?.content || ''; + + // Handle tool calls if present + const toolCalls: ToolCall[] = []; + if (message?.tool_calls) { + for (const tool of message.tool_calls) { + if (tool.type === 'function') { + toolCalls.push({ + id: tool.id, + name: tool.function.name, + content: tool.function.arguments, + }); + } + } + } + + // Create token usage + const tokenUsage = new TokenUsage(); + tokenUsage.input = response.usage?.prompt_tokens || 0; + tokenUsage.output = response.usage?.completion_tokens || 0; + + return { + text: content, + toolCalls, + tokenUsage, + }; + } catch (error) { + throw new Error(`Error calling OpenAI API: ${(error as Error).message}`); + } + } + + /** + * Format messages for OpenAI API + */ + private formatMessages(messages: Message[]): ChatCompletionMessageParam[] { + return messages.map((msg): ChatCompletionMessageParam => { + // Use switch for better type narrowing + switch (msg.role) { + case 'user': + return { + role: 'user', + content: msg.content, + }; + case 'system': + return { + role: 'system', + content: msg.content, + }; + case 'assistant': + return { + role: 'assistant', + content: msg.content, + }; + case 'tool_use': + // OpenAI doesn't have a direct equivalent to tool_use, + // so we'll include it as a function call in an assistant message + return { + role: 'assistant', + content: null, + tool_calls: [ + { + id: msg.id, + type: 'function' as const, + function: { + name: msg.name, + arguments: msg.content, + }, + }, + ], + }; + case 'tool_result': + // Tool results in OpenAI are represented as tool messages + return { + role: 'tool', + content: msg.content, + tool_call_id: msg.tool_use_id, + }; + default: + // For any other role, default to user message + return { + role: 'user', + content: 'Unknown message type', + }; + } + }); + } + + /** + * Format functions for OpenAI API + */ + private formatFunctions( + functions: FunctionDefinition[], + ): ChatCompletionTool[] { + return functions.map((fn) => ({ + type: 'function' as const, + function: { + name: fn.name, + description: fn.description, + parameters: fn.parameters, + }, + })); + } +} diff --git a/packages/cli/src/settings/config.ts b/packages/cli/src/settings/config.ts index bbe90e3..4113515 100644 --- a/packages/cli/src/settings/config.ts +++ b/packages/cli/src/settings/config.ts @@ -56,6 +56,7 @@ const defaultConfig = { tokenCache: true, // API keys (empty by default) ANTHROPIC_API_KEY: '', + OPENAI_API_KEY: '', }; export type Config = typeof defaultConfig; From 30b0807d4f3ecdd24f53b7ee4160645a4ed10444 Mon Sep 17 00:00:00 2001 From: Ben Houston Date: Wed, 12 Mar 2025 07:30:14 -0400 Subject: [PATCH 2/2] fix(openai): add OpenAI dependency to agent package and enable provider in config --- packages/agent/package.json | 1 + packages/agent/src/core/toolAgent/config.ts | 6 ++--- pnpm-lock.yaml | 30 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/packages/agent/package.json b/packages/agent/package.json index 3101acf..6511def 100644 --- a/packages/agent/package.json +++ b/packages/agent/package.json @@ -51,6 +51,7 @@ "chalk": "^5.4.1", "dotenv": "^16", "jsdom": "^26.0.0", + "openai": "^4.87.3", "playwright": "^1.50.1", "uuid": "^11", "zod": "^3.24.2", diff --git a/packages/agent/src/core/toolAgent/config.ts 
b/packages/agent/src/core/toolAgent/config.ts index 29737c9..7351ccb 100644 --- a/packages/agent/src/core/toolAgent/config.ts +++ b/packages/agent/src/core/toolAgent/config.ts @@ -8,7 +8,7 @@ import { ToolContext } from '../types'; /** * Available model providers */ -export type ModelProvider = 'anthropic'; +export type ModelProvider = 'anthropic' | 'openai'; /* | 'openai' | 'ollama' @@ -22,9 +22,9 @@ export function getModel(provider: ModelProvider, model: string): LLMProvider { switch (provider) { case 'anthropic': return createProvider('anthropic', model); - /*case 'openai': + case 'openai': return createProvider('openai', model); - case 'ollama': + /*case 'ollama': if (options?.ollamaBaseUrl) { return createProvider('ollama', model, { baseUrl: options.ollamaBaseUrl, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c9704ca..4fc736a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -108,6 +108,9 @@ importers: jsdom: specifier: ^26.0.0 version: 26.0.0 + openai: + specifier: ^4.87.3 + version: 4.87.3(encoding@0.1.13)(ws@8.18.1)(zod@3.24.2) playwright: specifier: ^1.50.1 version: 1.51.0 @@ -3117,6 +3120,18 @@ packages: resolution: {integrity: sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ==} engines: {node: '>=18'} + openai@4.87.3: + resolution: {integrity: sha512-d2D54fzMuBYTxMW8wcNmhT1rYKcTfMJ8t+4KjH2KtvYenygITiGBgHoIrzHwnDQWW+C5oCA+ikIR2jgPCFqcKQ==} + hasBin: true + peerDependencies: + ws: ^8.18.0 + zod: ^3.23.8 + peerDependenciesMeta: + ws: + optional: true + zod: + optional: true + optionator@0.9.4: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} engines: {node: '>= 0.8.0'} @@ -7456,6 +7471,21 @@ snapshots: dependencies: mimic-function: 5.0.1 + openai@4.87.3(encoding@0.1.13)(ws@8.18.1)(zod@3.24.2): + dependencies: + '@types/node': 18.19.80 + '@types/node-fetch': 2.6.12 + abort-controller: 3.0.0 + agentkeepalive: 4.6.0 + form-data-encoder: 1.7.2 + formdata-node: 4.4.1 + node-fetch: 2.7.0(encoding@0.1.13) + optionalDependencies: + ws: 8.18.1 + zod: 3.24.2 + transitivePeerDependencies: + - encoding + optionator@0.9.4: dependencies: deep-is: 0.1.4
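
Note (illustrative, not part of the patch): end to end, the new provider is reached through the createProvider factory registered in provider.ts, and getModel('openai', ...) in toolAgent/config.ts goes through the same path. A minimal usage sketch follows, assuming OPENAI_API_KEY is set in the environment; the import path is abbreviated, and the model name and getWeather function are placeholders, not part of this changeset:

  import { createProvider } from './core/llm/provider.js';

  const provider = createProvider('openai', 'gpt-4');

  const response = await provider.generateText({
    messages: [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'user', content: 'What is the weather in Paris?' },
    ],
    // formatFunctions() maps these definitions to OpenAI "tools".
    functions: [
      {
        name: 'getWeather', // hypothetical function, for illustration only
        description: 'Look up the current weather for a city',
        parameters: {
          type: 'object',
          properties: { city: { type: 'string' } },
        },
      },
    ],
  });

  console.log(response.text);       // assistant text ('' when only tools are called)
  console.log(response.toolCalls);  // [{ id, name, content }] per the ToolCall shape
  console.log(response.tokenUsage); // TokenUsage with input/output counts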