Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions bin/chat.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import { tools } from './tools.js'

const systemPrompt = 'You are a machine learning web application named "Hyperparam".'
+ ' You assist users with building high quality ML models by introspecting on their training set data.'
/** @type {'text' | 'tool'} */
let outputMode = 'text' // default output mode

function systemPrompt() {
return 'You are a machine learning web application named "Hyperparam" running on a CLI terminal.'
+ '\nYou assist users with analyzing and exploring datasets, particularly in parquet format.'
+ ' The website and api are available at hyperparam.app.'
+ ' Hyperparam uses LLMs to analyze their own training set.'
+ ' It can generate the perplexity, entropy, and other metrics of the training set.'
+ ' This allows users to find segments of their data set which are difficult to model.'
+ ' This could be because the data is junk, or because the data requires deeper understanding.'
+ ' This is essential for closing the loop on the ML lifecycle.'
+ ' The quickest way to get started is to upload a dataset and start exploring.'
+ ' The Hyperparam CLI tool can list and explore local parquet files.'
+ '\nYou are on a terminal and can only output: text, emojis, terminal colors, and terminal formatting.'
+ ' Don\'t add additional markdown or html formatting unless requested.'
+ (process.stdout.isTTY ? ` The terminal width is ${process.stdout.columns} characters.` : '')
}
/** @type {Message} */
const systemMessage = { role: 'system', content: systemPrompt }
const systemMessage = { role: 'system', content: systemPrompt() }

const colors = {
system: '\x1b[36m', // cyan
Expand Down Expand Up @@ -43,6 +46,7 @@ async function sendToServer(chatInput) {
if (!reader) throw new Error('No response body')
const decoder = new TextDecoder()
let buffer = ''
const write = writeWithColor()

while (true) {
const { done, value } = await reader.read()
Expand All @@ -57,6 +61,11 @@ async function sendToServer(chatInput) {
const chunk = JSON.parse(line)
const { type, error } = chunk
if (type === 'response.output_text.delta') {
// text mode
if (outputMode === 'tool') {
write('\n')
}
outputMode = 'text'
streamResponse.content += chunk.delta
write(chunk.delta)
} else if (error) {
Expand Down Expand Up @@ -93,7 +102,7 @@ async function sendMessages(messages) {
const response = await sendToServer(chatInput)
messages.push(response)
// handle tool results
if (response.tool_calls) {
if (response.tool_calls?.length) {
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<string> }[]} */
const toolResults = []
for (const toolCall of response.tool_calls) {
Expand All @@ -106,7 +115,11 @@ async function sendMessages(messages) {
throw new Error(`Unknown tool: ${toolCall.function.name}`)
}
}
write('\n')
// tool mode
if (outputMode === 'text') {
write('\n')
}
outputMode = 'tool' // switch to tool output mode
for (const toolResult of toolResults) {
const { toolCall, tool } = toolResult
try {
Expand All @@ -123,14 +136,14 @@ async function sendMessages(messages) {
const pairs = entries.map(([key, value]) => `${key} = ${value}`)
func += `(${pairs.join(', ')})`
}

write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n')
messages.push({ role: 'tool', content, tool_call_id: toolCall.id })
} catch (error) {
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal)
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}`, colors.normal)
messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id })
}
}

// send messages with tool results
await sendMessages(messages)
}
Expand All @@ -143,6 +156,45 @@ function write(...args) {
args.forEach(s => process.stdout.write(s))
}

/**
* Handle streaming output, but buffer if needed to handle escape codes.
* @returns {(...args: string[]) => void}
*/
function writeWithColor() {
/** @type {string | undefined} */
let buffer
/**
* @param {string} char
*/
function writeChar(char) {
if (buffer === undefined && char !== '\\' && char !== '\x1b') {
write(char)
} else {
buffer ??= ''
buffer += char
const isEscape = buffer.startsWith('\\x1b[') || buffer.startsWith('\\033[')
// if the buffer is an escape sequence, write it
if (isEscape) {
// convert to terminal escape sequence
const escaped = buffer.replace(/\\x1b\[/g, '\x1b[').replace(/\\033\[/g, '\x1b[')
write(escaped)
buffer = undefined
} else if (buffer.length > 6) {
// no match, just write it
write(buffer)
buffer = undefined
}
}
}
return (...args) => {
for (const arg of args) {
for (const char of arg) {
writeChar(char)
}
}
}
}

export function chat() {
/** @type {Message[]} */
const messages = [systemMessage]
Expand All @@ -157,6 +209,7 @@ export function chat() {
} else if (input) {
try {
write(colors.user, 'answer: ', colors.normal)
outputMode = 'text' // switch to text output mode
messages.push({ role: 'user', content: input.trim() })
await sendMessages(messages)
} catch (error) {
Expand Down