diff --git a/bin/chat.js b/bin/chat.js index f4be69c3..181c9a31 100644 --- a/bin/chat.js +++ b/bin/chat.js @@ -1,16 +1,19 @@ import { tools } from './tools.js' -const systemPrompt = 'You are a machine learning web application named "Hyperparam".' - + ' You assist users with building high quality ML models by introspecting on their training set data.' +/** @type {'text' | 'tool'} */ +let outputMode = 'text' // default output mode + +function systemPrompt() { + return 'You are a machine learning web application named "Hyperparam" running on a CLI terminal.' + + '\nYou assist users with analyzing and exploring datasets, particularly in parquet format.' + ' The website and api are available at hyperparam.app.' - + ' Hyperparam uses LLMs to analyze their own training set.' - + ' It can generate the perplexity, entropy, and other metrics of the training set.' - + ' This allows users to find segments of their data set which are difficult to model.' - + ' This could be because the data is junk, or because the data requires deeper understanding.' - + ' This is essential for closing the loop on the ML lifecycle.' - + ' The quickest way to get started is to upload a dataset and start exploring.' + + ' The Hyperparam CLI tool can list and explore local parquet files.' + + '\nYou are on a terminal and can only output: text, emojis, terminal colors, and terminal formatting.' + + ' Don\'t add additional markdown or html formatting unless requested.' + + (process.stdout.isTTY ? ` The terminal width is ${process.stdout.columns} characters.` : '') +} /** @type {Message} */ -const systemMessage = { role: 'system', content: systemPrompt } +const systemMessage = { role: 'system', content: systemPrompt() } const colors = { system: '\x1b[36m', // cyan @@ -43,6 +46,7 @@ async function sendToServer(chatInput) { if (!reader) throw new Error('No response body') const decoder = new TextDecoder() let buffer = '' + const write = writeWithColor() while (true) { const { done, value } = await reader.read() @@ -57,6 +61,11 @@ async function sendToServer(chatInput) { const chunk = JSON.parse(line) const { type, error } = chunk if (type === 'response.output_text.delta') { + // text mode + if (outputMode === 'tool') { + write('\n') + } + outputMode = 'text' streamResponse.content += chunk.delta write(chunk.delta) } else if (error) { @@ -93,7 +102,7 @@ async function sendMessages(messages) { const response = await sendToServer(chatInput) messages.push(response) // handle tool results - if (response.tool_calls) { + if (response.tool_calls?.length) { /** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise }[]} */ const toolResults = [] for (const toolCall of response.tool_calls) { @@ -106,7 +115,11 @@ async function sendMessages(messages) { throw new Error(`Unknown tool: ${toolCall.function.name}`) } } - write('\n') + // tool mode + if (outputMode === 'text') { + write('\n') + } + outputMode = 'tool' // switch to tool output mode for (const toolResult of toolResults) { const { toolCall, tool } = toolResult try { @@ -123,14 +136,14 @@ async function sendMessages(messages) { const pairs = entries.map(([key, value]) => `${key} = ${value}`) func += `(${pairs.join(', ')})` } - - write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n') + write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n') messages.push({ role: 'tool', content, tool_call_id: toolCall.id }) } catch (error) { - write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal) + write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}`, colors.normal) messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id }) } } + // send messages with tool results await sendMessages(messages) } @@ -143,6 +156,45 @@ function write(...args) { args.forEach(s => process.stdout.write(s)) } +/** + * Handle streaming output, but buffer if needed to handle escape codes. + * @returns {(...args: string[]) => void} + */ +function writeWithColor() { + /** @type {string | undefined} */ + let buffer + /** + * @param {string} char + */ + function writeChar(char) { + if (buffer === undefined && char !== '\\' && char !== '\x1b') { + write(char) + } else { + buffer ??= '' + buffer += char + const isEscape = buffer.startsWith('\\x1b[') || buffer.startsWith('\\033[') + // if the buffer is an escape sequence, write it + if (isEscape) { + // convert to terminal escape sequence + const escaped = buffer.replace(/\\x1b\[/g, '\x1b[').replace(/\\033\[/g, '\x1b[') + write(escaped) + buffer = undefined + } else if (buffer.length > 6) { + // no match, just write it + write(buffer) + buffer = undefined + } + } + } + return (...args) => { + for (const arg of args) { + for (const char of arg) { + writeChar(char) + } + } + } +} + export function chat() { /** @type {Message[]} */ const messages = [systemMessage] @@ -157,6 +209,7 @@ export function chat() { } else if (input) { try { write(colors.user, 'answer: ', colors.normal) + outputMode = 'text' // switch to text output mode messages.push({ role: 'user', content: input.trim() }) await sendMessages(messages) } catch (error) {