Skip to content

Commit e3f8b6e

Browse files
authored
Inform the model about the console (#271)
1 parent 79115bf commit e3f8b6e

File tree

1 file changed

+67
-14
lines changed

1 file changed

+67
-14
lines changed

bin/chat.js

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import { tools } from './tools.js'
22

3-
const systemPrompt = 'You are a machine learning web application named "Hyperparam".'
4-
+ ' You assist users with building high quality ML models by introspecting on their training set data.'
3+
/** @type {'text' | 'tool'} */
4+
let outputMode = 'text' // default output mode
5+
6+
function systemPrompt() {
7+
return 'You are a machine learning web application named "Hyperparam" running on a CLI terminal.'
8+
+ '\nYou assist users with analyzing and exploring datasets, particularly in parquet format.'
59
+ ' The website and api are available at hyperparam.app.'
6-
+ ' Hyperparam uses LLMs to analyze their own training set.'
7-
+ ' It can generate the perplexity, entropy, and other metrics of the training set.'
8-
+ ' This allows users to find segments of their data set which are difficult to model.'
9-
+ ' This could be because the data is junk, or because the data requires deeper understanding.'
10-
+ ' This is essential for closing the loop on the ML lifecycle.'
11-
+ ' The quickest way to get started is to upload a dataset and start exploring.'
10+
+ ' The Hyperparam CLI tool can list and explore local parquet files.'
11+
+ '\nYou are on a terminal and can only output: text, emojis, terminal colors, and terminal formatting.'
12+
+ ' Don\'t add additional markdown or html formatting unless requested.'
13+
+ (process.stdout.isTTY ? ` The terminal width is ${process.stdout.columns} characters.` : '')
14+
}
1215
/** @type {Message} */
13-
const systemMessage = { role: 'system', content: systemPrompt }
16+
const systemMessage = { role: 'system', content: systemPrompt() }
1417

1518
const colors = {
1619
system: '\x1b[36m', // cyan
@@ -43,6 +46,7 @@ async function sendToServer(chatInput) {
4346
if (!reader) throw new Error('No response body')
4447
const decoder = new TextDecoder()
4548
let buffer = ''
49+
const write = writeWithColor()
4650

4751
while (true) {
4852
const { done, value } = await reader.read()
@@ -57,6 +61,11 @@ async function sendToServer(chatInput) {
5761
const chunk = JSON.parse(line)
5862
const { type, error } = chunk
5963
if (type === 'response.output_text.delta') {
64+
// text mode
65+
if (outputMode === 'tool') {
66+
write('\n')
67+
}
68+
outputMode = 'text'
6069
streamResponse.content += chunk.delta
6170
write(chunk.delta)
6271
} else if (error) {
@@ -93,7 +102,7 @@ async function sendMessages(messages) {
93102
const response = await sendToServer(chatInput)
94103
messages.push(response)
95104
// handle tool results
96-
if (response.tool_calls) {
105+
if (response.tool_calls?.length) {
97106
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<string> }[]} */
98107
const toolResults = []
99108
for (const toolCall of response.tool_calls) {
@@ -106,7 +115,11 @@ async function sendMessages(messages) {
106115
throw new Error(`Unknown tool: ${toolCall.function.name}`)
107116
}
108117
}
109-
write('\n')
118+
// tool mode
119+
if (outputMode === 'text') {
120+
write('\n')
121+
}
122+
outputMode = 'tool' // switch to tool output mode
110123
for (const toolResult of toolResults) {
111124
const { toolCall, tool } = toolResult
112125
try {
@@ -123,14 +136,14 @@ async function sendMessages(messages) {
123136
const pairs = entries.map(([key, value]) => `${key} = ${value}`)
124137
func += `(${pairs.join(', ')})`
125138
}
126-
127-
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
139+
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n')
128140
messages.push({ role: 'tool', content, tool_call_id: toolCall.id })
129141
} catch (error) {
130-
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal)
142+
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}`, colors.normal)
131143
messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id })
132144
}
133145
}
146+
134147
// send messages with tool results
135148
await sendMessages(messages)
136149
}
@@ -143,6 +156,45 @@ function write(...args) {
143156
args.forEach(s => process.stdout.write(s))
144157
}
145158

159+
/**
160+
* Handle streaming output, but buffer if needed to handle escape codes.
161+
* @returns {(...args: string[]) => void}
162+
*/
163+
function writeWithColor() {
164+
/** @type {string | undefined} */
165+
let buffer
166+
/**
167+
* @param {string} char
168+
*/
169+
function writeChar(char) {
170+
if (buffer === undefined && char !== '\\' && char !== '\x1b') {
171+
write(char)
172+
} else {
173+
buffer ??= ''
174+
buffer += char
175+
const isEscape = buffer.startsWith('\\x1b[') || buffer.startsWith('\\033[')
176+
// if the buffer is an escape sequence, write it
177+
if (isEscape) {
178+
// convert to terminal escape sequence
179+
const escaped = buffer.replace(/\\x1b\[/g, '\x1b[').replace(/\\033\[/g, '\x1b[')
180+
write(escaped)
181+
buffer = undefined
182+
} else if (buffer.length > 6) {
183+
// no match, just write it
184+
write(buffer)
185+
buffer = undefined
186+
}
187+
}
188+
}
189+
return (...args) => {
190+
for (const arg of args) {
191+
for (const char of arg) {
192+
writeChar(char)
193+
}
194+
}
195+
}
196+
}
197+
146198
export function chat() {
147199
/** @type {Message[]} */
148200
const messages = [systemMessage]
@@ -157,6 +209,7 @@ export function chat() {
157209
} else if (input) {
158210
try {
159211
write(colors.user, 'answer: ', colors.normal)
212+
outputMode = 'text' // switch to text output mode
160213
messages.push({ role: 'user', content: input.trim() })
161214
await sendMessages(messages)
162215
} catch (error) {

0 commit comments

Comments
 (0)