diff --git a/bin/chat.js b/bin/chat.js index bc4e7697..2789f8ee 100644 --- a/bin/chat.js +++ b/bin/chat.js @@ -1,6 +1,7 @@ -import http from 'http' // TODO: https +import { tools } from './tools.js' -const systemPrompt = 'You are a machine learning web application named "hyperparam". ' + +const systemPrompt = + 'You are a machine learning web application named "hyperparam". ' + 'You assist users with building high quality ML models by introspecting on their training set data. ' + 'The website and api are available at hyperparam.app. ' + 'Hyperparam uses LLMs to analyze their own training set. ' + @@ -9,49 +10,124 @@ const systemPrompt = 'You are a machine learning web application named "hyperpar 'This could be because the data is junk, or because the data requires deeper understanding. ' + 'This is essential for closing the loop on the ML lifecycle. ' + 'The quickest way to get started is to upload a dataset and start exploring.' -const messages = [{ role: 'system', content: systemPrompt }] +/** @type {Message} */ +const systemMessage = { role: 'system', content: systemPrompt } + +const colors = { + system: '\x1b[36m', // cyan + user: '\x1b[33m', // yellow + tool: '\x1b[90m', // gray + error: '\x1b[31m', // red + normal: '\x1b[0m', // reset +} /** + * @import { Message } from './types.d.ts' * @param {Object} chatInput - * @returns {Promise} + * @returns {Promise} */ -function sendToServer(chatInput) { - return new Promise((resolve, reject) => { - const json = JSON.stringify(chatInput) - const options = { - hostname: 'localhost', - port: 3000, - path: '/api/functions/openai/chat', - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Content-Length': json.length, - }, - } - const req = http.request(options, res => { - let responseBody = '' - res.on('data', chunk => { - if (chunk[0] !== 123) return // { - try { - const { content } = JSON.parse(chunk) - responseBody += content +async function sendToServer(chatInput) { + const response = await fetch('http://localhost:3000/api/functions/openai/chat', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(chatInput), + }) + + if (!response.ok) { + throw new Error(`Request failed: ${response.status}`) + } + + // Process the streaming response + const streamResponse = { role: 'assistant', content: '' } + const reader = response.body.getReader() + const decoder = new TextDecoder() + let buffer = '' + + while (true) { + const { done, value } = await reader.read() + if (done) break + buffer += decoder.decode(value, { stream: true }) + const lines = buffer.split('\n') + // Keep the last line in the buffer + buffer = lines.pop() || '' + for (const line of lines) { + if (!line.trim()) continue + try { + const jsonChunk = JSON.parse(line) + const { content, error } = jsonChunk + if (content) { + streamResponse.content += content write(content) - } catch (error) { - reject(error) + } else if (error) { + console.error(error) + throw new Error(error) + } else if (jsonChunk.function) { + streamResponse.tool_calls ??= [] + streamResponse.tool_calls.push(jsonChunk) + } else if (!jsonChunk.key && jsonChunk.content !== '') { + console.log('Unknown chunk', jsonChunk) } - }) - res.on('end', () => { - if (res.statusCode === 200) { - resolve(responseBody) - } else { - reject(new Error(`request failed: ${res.statusCode}`)) - } - }) - }) - req.on('error', reject) - req.write(json) - req.end() - }) + } catch (err) { + console.error('Error parsing chunk', err) + } + } + } + return streamResponse +} + +/** + * Send messages to the server and handle tool calls. + * Will mutate the messages array! + * + * @import { ToolCall, ToolHandler } from './types.d.ts' + * @param {Message[]} messages + * @returns {Promise} + */ +async function sendMessages(messages) { + const chatInput = { + messages, + tools: tools.map(tool => tool.tool), + } + const response = await sendToServer(chatInput) + messages.push(response) + // handle tool results + if (response.tool_calls) { + /** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise }[]} */ + const toolResults = [] + for (const toolCall of response.tool_calls) { + const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name) + if (tool) { + const result = tool.handleToolCall(toolCall) + toolResults.push({ toolCall, tool, result }) + } else { + throw new Error(`Unknown tool: ${toolCall.function.name}`) + } + } + write('\n') + for (const toolResult of toolResults) { + const { toolCall, tool } = toolResult + const result = await toolResult.result + + // Construct function call message + const args = JSON.parse(toolCall.function?.arguments ?? '{}') + const keys = Object.keys(args) + let func = toolCall.function.name + if (keys.length === 0) { + func += '()' + } else if (keys.length === 1) { + func += `(${args[keys[0]]})` + } else { + // transform to (arg1 = 111, arg2 = 222) + const pairs = keys.map(key => `${key} = ${args[key]}`) + func += `(${pairs.join(', ')})` + } + + write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n') + messages.push(result) + } + // send messages with tool results + await sendMessages(messages) + } } /** @@ -62,15 +138,10 @@ function write(...args) { } export function chat() { + /** @type {Message[]} */ + const messages = [systemMessage] process.stdin.setEncoding('utf-8') - const colors = { - system: '\x1b[36m', // cyan - user: '\x1b[33m', // yellow - error: '\x1b[31m', // red - normal: '\x1b[0m', // reset - } - write(colors.system, 'question: ', colors.normal) process.stdin.on('data', async (/** @type {string} */ input) => { @@ -81,8 +152,7 @@ export function chat() { try { write(colors.user, 'answer: ', colors.normal) messages.push({ role: 'user', content: input.trim() }) - const response = await sendToServer({ messages }) - messages.push({ role: 'assistant', content: response }) + await sendMessages(messages) } catch (error) { console.error(colors.error, '\n' + error) } finally { diff --git a/bin/cli.js b/bin/cli.js index bccbd653..ff0afcd8 100755 --- a/bin/cli.js +++ b/bin/cli.js @@ -49,7 +49,7 @@ if (arg === 'chat') { */ function checkForUpdates() { const currentVersion = packageJson.version - return void fetch('https://registry.npmjs.org/hyperparam/latest') + return fetch('https://registry.npmjs.org/hyperparam/latest') .then(response => response.json()) .then(data => { const latestVersion = data.version diff --git a/bin/tools.js b/bin/tools.js new file mode 100644 index 00000000..8dcd9067 --- /dev/null +++ b/bin/tools.js @@ -0,0 +1,77 @@ +import fs from 'fs/promises' +import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet' + +/** + * @import { Message, ToolCall, ToolHandler } from './types.d.ts' + * @type {ToolHandler[]} + */ +export const tools = [ + { + emoji: '📂', + tool: { + type: 'function', + function: { + name: 'list_files', + description: 'List the files in the current directory.', + parameters: { + type: 'object', + properties: { + path: { type: 'string', description: 'The path to list files from (optional).' }, + }, + }, + }, + }, + /** + * @param {ToolCall} toolCall + * @returns {Promise} + */ + async handleToolCall(toolCall) { + let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}') + if (path.includes('..') || path.includes('~')) { + throw new Error('Invalid path: ' + path) + } + // append to current directory + path = process.cwd() + '/' + path + // list files in the current directory + const filenames = await fs.readdir(path) + return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id } + }, + }, + { + emoji: '📄', + tool: { + type: 'function', + function: { + name: 'read_parquet', + description: 'Read rows from a parquet file. Do not request more than 5 rows.', + parameters: { + type: 'object', + properties: { + filename: { type: 'string', description: 'The name of the parquet file to read.' }, + rowStart: { type: 'integer', description: 'The start row index.' }, + rowEnd: { type: 'integer', description: 'The end row index.' }, + orderBy: { type: 'string', description: 'The column name to sort by.' }, + }, + required: ['filename'], + }, + }, + }, + /** + * @param {ToolCall} toolCall + * @returns {Promise} + */ + async handleToolCall(toolCall) { + const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}') + if (rowEnd - rowStart > 5) { + throw new Error('Do NOT request more than 5 rows.') + } + const file = await asyncBufferFromFile(filename) + const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy }) + let content = '' + for (let i = rowStart; i < rowEnd; i++) { + content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n` + } + return { role: 'tool', content, tool_call_id: toolCall.id } + }, + }, +] diff --git a/bin/types.d.ts b/bin/types.d.ts new file mode 100644 index 00000000..cff9a024 --- /dev/null +++ b/bin/types.d.ts @@ -0,0 +1,43 @@ +export interface ToolCall { + id: string + type: 'function' + function: { + name: string + arguments?: string + } +} + +export type Role = 'system' | 'user' | 'assistant' | 'tool' + +export interface Message { + role: Role + content: string + tool_calls?: ToolCall[] + tool_call_id?: string + error?: string +} + +export interface ToolHandler { + emoji: string + tool: Tool + handleToolCall(toolCall: ToolCall): Promise +} +interface ToolProperty { + type: string + description: string +} + +export interface Tool { + type: 'function' + function: { + name: string + description: string + parameters?: { + type: 'object' + properties: Record + required?: string[] + additionalProperties?: boolean + }, + strict?: boolean + } +} diff --git a/package.json b/package.json index 5a888f74..052cf56f 100644 --- a/package.json +++ b/package.json @@ -48,7 +48,7 @@ }, "dependencies": { "hightable": "0.12.1", - "hyparquet": "1.8.7", + "hyparquet": "1.9.0", "hyparquet-compressors": "1.0.0", "react": "18.3.1", "react-dom": "18.3.1" @@ -56,7 +56,7 @@ "devDependencies": { "@eslint/js": "9.22.0", "@testing-library/react": "16.2.0", - "@types/node": "22.13.9", + "@types/node": "22.13.10", "@types/react": "19.0.10", "@types/react-dom": "19.0.4", "@vitejs/plugin-react": "4.3.4",