diff --git a/bin/chat.js b/bin/chat.js index 2789f8ee..f4be69c3 100644 --- a/bin/chat.js +++ b/bin/chat.js @@ -1,15 +1,14 @@ import { tools } from './tools.js' -const systemPrompt = - 'You are a machine learning web application named "hyperparam". ' + - 'You assist users with building high quality ML models by introspecting on their training set data. ' + - 'The website and api are available at hyperparam.app. ' + - 'Hyperparam uses LLMs to analyze their own training set. ' + - 'It can generate the perplexity, entropy, and other metrics of the training set. ' + - 'This allows users to find segments of their data set which are difficult to model. ' + - 'This could be because the data is junk, or because the data requires deeper understanding. ' + - 'This is essential for closing the loop on the ML lifecycle. ' + - 'The quickest way to get started is to upload a dataset and start exploring.' +const systemPrompt = 'You are a machine learning web application named "Hyperparam".' + + ' You assist users with building high quality ML models by introspecting on their training set data.' + + ' The website and api are available at hyperparam.app.' + + ' Hyperparam uses LLMs to analyze their own training set.' + + ' It can generate the perplexity, entropy, and other metrics of the training set.' + + ' This allows users to find segments of their data set which are difficult to model.' + + ' This could be because the data is junk, or because the data requires deeper understanding.' + + ' This is essential for closing the loop on the ML lifecycle.' + + ' The quickest way to get started is to upload a dataset and start exploring.' /** @type {Message} */ const systemMessage = { role: 'system', content: systemPrompt } @@ -27,7 +26,7 @@ const colors = { * @returns {Promise} */ async function sendToServer(chatInput) { - const response = await fetch('http://localhost:3000/api/functions/openai/chat', { + const response = await fetch('https://hyperparam.app/api/functions/openai/chat', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(chatInput), @@ -38,8 +37,10 @@ async function sendToServer(chatInput) { } // Process the streaming response + /** @type {Message} */ const streamResponse = { role: 'assistant', content: '' } - const reader = response.body.getReader() + const reader = response.body?.getReader() + if (!reader) throw new Error('No response body') const decoder = new TextDecoder() let buffer = '' @@ -53,19 +54,19 @@ async function sendToServer(chatInput) { for (const line of lines) { if (!line.trim()) continue try { - const jsonChunk = JSON.parse(line) - const { content, error } = jsonChunk - if (content) { - streamResponse.content += content - write(content) + const chunk = JSON.parse(line) + const { type, error } = chunk + if (type === 'response.output_text.delta') { + streamResponse.content += chunk.delta + write(chunk.delta) } else if (error) { console.error(error) throw new Error(error) - } else if (jsonChunk.function) { + } else if (chunk.function) { streamResponse.tool_calls ??= [] - streamResponse.tool_calls.push(jsonChunk) - } else if (!jsonChunk.key && jsonChunk.content !== '') { - console.log('Unknown chunk', jsonChunk) + streamResponse.tool_calls.push(chunk) + } else if (!chunk.key) { + console.log('Unknown chunk', chunk) } } catch (err) { console.error('Error parsing chunk', err) @@ -85,6 +86,7 @@ async function sendToServer(chatInput) { */ async function sendMessages(messages) { const chatInput = { + model: 'gpt-4o', messages, tools: tools.map(tool => tool.tool), } @@ -92,12 +94,13 @@ async function sendMessages(messages) { messages.push(response) // handle tool results if (response.tool_calls) { - /** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise }[]} */ + /** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise }[]} */ const toolResults = [] for (const toolCall of response.tool_calls) { const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name) if (tool) { - const result = tool.handleToolCall(toolCall) + const args = JSON.parse(toolCall.function?.arguments ?? '{}') + const result = tool.handleToolCall(args) toolResults.push({ toolCall, tool, result }) } else { throw new Error(`Unknown tool: ${toolCall.function.name}`) @@ -106,24 +109,27 @@ async function sendMessages(messages) { write('\n') for (const toolResult of toolResults) { const { toolCall, tool } = toolResult - const result = await toolResult.result + try { + const content = await toolResult.result - // Construct function call message - const args = JSON.parse(toolCall.function?.arguments ?? '{}') - const keys = Object.keys(args) - let func = toolCall.function.name - if (keys.length === 0) { - func += '()' - } else if (keys.length === 1) { - func += `(${args[keys[0]]})` - } else { - // transform to (arg1 = 111, arg2 = 222) - const pairs = keys.map(key => `${key} = ${args[key]}`) - func += `(${pairs.join(', ')})` - } + // Construct function call message + const args = JSON.parse(toolCall.function?.arguments ?? '{}') + const entries = Object.entries(args) + let func = toolCall.function.name + if (entries.length === 0) { + func += '()' + } else { + // transform to (arg1 = 111, arg2 = 222) + const pairs = entries.map(([key, value]) => `${key} = ${value}`) + func += `(${pairs.join(', ')})` + } - write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n') - messages.push(result) + write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n') + messages.push({ role: 'tool', content, tool_call_id: toolCall.id }) + } catch (error) { + write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal) + messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id }) + } } // send messages with tool results await sendMessages(messages) diff --git a/bin/serve.js b/bin/serve.js index 94a69789..a5f3abaf 100644 --- a/bin/serve.js +++ b/bin/serve.js @@ -123,7 +123,7 @@ async function handleStatic(filePath, range) { if (!stats?.isFile()) { return { status: 404, content: 'not found' } } - const contentLength = stats.size + const fileSize = stats.size // detect content type const extname = path.extname(filePath) @@ -133,8 +133,8 @@ async function handleStatic(filePath, range) { // ranged requests if (range) { const [unit, ranges] = range.split('=') - if (unit === 'bytes') { - const [start, end] = ranges.split('-').map(Number) + if (unit === 'bytes' && ranges) { + const [start = 0, end = fileSize] = ranges.split('-').map(Number) // convert fs.ReadStream to web stream const fsStream = createReadStream(filePath, { start, end }) @@ -151,7 +151,7 @@ async function handleStatic(filePath, range) { } const content = await fs.readFile(filePath) - return { status: 200, content, contentLength, contentType } + return { status: 200, content, contentLength: fileSize, contentType } } /** diff --git a/bin/tools.js b/bin/tools.js index 8dcd9067..3a89d258 100644 --- a/bin/tools.js +++ b/bin/tools.js @@ -1,8 +1,10 @@ import fs from 'fs/promises' import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet' +import { compressors } from 'hyparquet-compressors' +const fileLimit = 20 // limit to 20 files per page /** - * @import { Message, ToolCall, ToolHandler } from './types.d.ts' + * @import { ToolHandler } from './types.d.ts' * @type {ToolHandler[]} */ export const tools = [ @@ -12,29 +14,47 @@ export const tools = [ type: 'function', function: { name: 'list_files', - description: 'List the files in the current directory.', + description: `List the files in a directory. Files are listed recursively up to ${fileLimit} per page.`, parameters: { type: 'object', properties: { - path: { type: 'string', description: 'The path to list files from (optional).' }, + path: { + type: 'string', + description: 'The path to list files from. Optional, defaults to the current directory.', + }, + filetype: { + type: 'string', + description: 'Optional file type to filter by, e.g. "parquet", "csv". If not provided, all files are listed.', + }, + offset: { + type: 'number', + description: 'Skip offset number of files in the listing. Defaults to 0. Optional.', + }, }, }, }, }, /** - * @param {ToolCall} toolCall - * @returns {Promise} + * @param {Record} args + * @returns {Promise} */ - async handleToolCall(toolCall) { - let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}') + async handleToolCall({ path = '.', filetype, offset = 0 }) { + if (typeof path !== 'string') { + throw new Error('Expected path to be a string') + } if (path.includes('..') || path.includes('~')) { throw new Error('Invalid path: ' + path) } - // append to current directory - path = process.cwd() + '/' + path - // list files in the current directory - const filenames = await fs.readdir(path) - return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id } + if (typeof filetype !== 'undefined' && typeof filetype !== 'string') { + throw new Error('Expected filetype to be a string or undefined') + } + const start = validateInteger('offset', offset, 0) + // list files in the directory + const filenames = (await fs.readdir(path)) + .filter(key => !filetype || key.endsWith(`.${filetype}`)) // filter by file type if provided + const limited = filenames.slice(start, start + fileLimit) + const end = start + limited.length + return `Files ${start + 1}..${end} of ${filenames.length}:\n${limited.join('\n')}` }, }, { @@ -42,36 +62,78 @@ export const tools = [ tool: { type: 'function', function: { - name: 'read_parquet', - description: 'Read rows from a parquet file. Do not request more than 5 rows.', + name: 'parquet_get_rows', + description: 'Get up to 5 rows of data from a parquet file.', parameters: { type: 'object', properties: { - filename: { type: 'string', description: 'The name of the parquet file to read.' }, - rowStart: { type: 'integer', description: 'The start row index.' }, - rowEnd: { type: 'integer', description: 'The end row index.' }, - orderBy: { type: 'string', description: 'The column name to sort by.' }, + filename: { + type: 'string', + description: 'The name of the parquet file to read.', + }, + offset: { + type: 'number', + description: 'The starting row index to fetch (0-indexed).', + }, + limit: { + type: 'number', + description: 'The number of rows to fetch. Default 5. Maximum 5.', + }, + orderBy: { + type: 'string', + description: 'The column name to sort by.', + }, }, required: ['filename'], }, }, }, /** - * @param {ToolCall} toolCall - * @returns {Promise} + * @param {Record} args + * @returns {Promise} */ - async handleToolCall(toolCall) { - const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}') - if (rowEnd - rowStart > 5) { - throw new Error('Do NOT request more than 5 rows.') + async handleToolCall({ filename, offset = 0, limit = 5, orderBy }) { + if (typeof filename !== 'string') { + throw new Error('Expected filename to be a string') + } + const rowStart = validateInteger('offset', offset, 0) + const rowEnd = rowStart + validateInteger('limit', limit, 1, 5) + if (typeof orderBy !== 'undefined' && typeof orderBy !== 'string') { + throw new Error('Expected orderBy to be a string') } const file = await asyncBufferFromFile(filename) - const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy }) + const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy, compressors }) let content = '' for (let i = rowStart; i < rowEnd; i++) { - content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n` + content += `Row ${i}: ${stringify(rows[i])}\n` } - return { role: 'tool', content, tool_call_id: toolCall.id } + return content }, }, ] + +/** + * Validates that a value is an integer within the specified range. Max is inclusive. + * @param {string} name - The name of the value being validated. + * @param {unknown} value - The value to validate. + * @param {number} min - The minimum allowed value (inclusive). + * @param {number} [max] - The maximum allowed value (inclusive). + * @returns {number} + */ +function validateInteger(name, value, min, max) { + if (typeof value !== 'number' || isNaN(value)) { + throw new Error(`Invalid number for ${name}: ${value}`) + } + if (!Number.isInteger(value)) { + throw new Error(`Invalid number for ${name}: ${value}. Must be an integer.`) + } + if (value < min || max !== undefined && value > max) { + throw new Error(`Invalid number for ${name}: ${value}. Must be between ${min} and ${max}.`) + } + return value +} + +function stringify(obj, limit = 1000) { + const str = JSON.stringify(toJson(obj)) + return str.length <= limit ? str : str.slice(0, limit) + '…' +} diff --git a/bin/types.d.ts b/bin/types.d.ts index cff9a024..308dc526 100644 --- a/bin/types.d.ts +++ b/bin/types.d.ts @@ -20,7 +20,7 @@ export interface Message { export interface ToolHandler { emoji: string tool: Tool - handleToolCall(toolCall: ToolCall): Promise + handleToolCall(args: Record): Promise } interface ToolProperty { type: string diff --git a/tsconfig.json b/tsconfig.json index 3d90c573..8e8280ec 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,5 +1,7 @@ { "compilerOptions": { + "allowJs": true, + "checkJs": true, "target": "es2022", "useDefineForClassFields": true, "lib": ["es2022", "DOM", "DOM.Iterable"],