Skip to content

Commit 960ef9c

Browse files
authored
Add tool calls to cli chat (#178)
1 parent b9d46b2 commit 960ef9c

File tree

5 files changed

+241
-51
lines changed

5 files changed

+241
-51
lines changed

bin/chat.js

Lines changed: 118 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import http from 'http' // TODO: https
1+
import { tools } from './tools.js'
22

3-
const systemPrompt = 'You are a machine learning web application named "hyperparam". ' +
3+
const systemPrompt =
4+
'You are a machine learning web application named "hyperparam". ' +
45
'You assist users with building high quality ML models by introspecting on their training set data. ' +
56
'The website and api are available at hyperparam.app. ' +
67
'Hyperparam uses LLMs to analyze their own training set. ' +
@@ -9,49 +10,124 @@ const systemPrompt = 'You are a machine learning web application named "hyperpar
910
'This could be because the data is junk, or because the data requires deeper understanding. ' +
1011
'This is essential for closing the loop on the ML lifecycle. ' +
1112
'The quickest way to get started is to upload a dataset and start exploring.'
12-
const messages = [{ role: 'system', content: systemPrompt }]
13+
/** @type {Message} */
14+
const systemMessage = { role: 'system', content: systemPrompt }
15+
16+
const colors = {
17+
system: '\x1b[36m', // cyan
18+
user: '\x1b[33m', // yellow
19+
tool: '\x1b[90m', // gray
20+
error: '\x1b[31m', // red
21+
normal: '\x1b[0m', // reset
22+
}
1323

1424
/**
25+
* @import { Message } from './types.d.ts'
1526
* @param {Object} chatInput
16-
* @returns {Promise<string>}
27+
* @returns {Promise<Message>}
1728
*/
18-
function sendToServer(chatInput) {
19-
return new Promise((resolve, reject) => {
20-
const json = JSON.stringify(chatInput)
21-
const options = {
22-
hostname: 'localhost',
23-
port: 3000,
24-
path: '/api/functions/openai/chat',
25-
method: 'POST',
26-
headers: {
27-
'Content-Type': 'application/json',
28-
'Content-Length': json.length,
29-
},
30-
}
31-
const req = http.request(options, res => {
32-
let responseBody = ''
33-
res.on('data', chunk => {
34-
if (chunk[0] !== 123) return // {
35-
try {
36-
const { content } = JSON.parse(chunk)
37-
responseBody += content
29+
async function sendToServer(chatInput) {
30+
const response = await fetch('http://localhost:3000/api/functions/openai/chat', {
31+
method: 'POST',
32+
headers: { 'Content-Type': 'application/json' },
33+
body: JSON.stringify(chatInput),
34+
})
35+
36+
if (!response.ok) {
37+
throw new Error(`Request failed: ${response.status}`)
38+
}
39+
40+
// Process the streaming response
41+
const streamResponse = { role: 'assistant', content: '' }
42+
const reader = response.body.getReader()
43+
const decoder = new TextDecoder()
44+
let buffer = ''
45+
46+
while (true) {
47+
const { done, value } = await reader.read()
48+
if (done) break
49+
buffer += decoder.decode(value, { stream: true })
50+
const lines = buffer.split('\n')
51+
// Keep the last line in the buffer
52+
buffer = lines.pop() || ''
53+
for (const line of lines) {
54+
if (!line.trim()) continue
55+
try {
56+
const jsonChunk = JSON.parse(line)
57+
const { content, error } = jsonChunk
58+
if (content) {
59+
streamResponse.content += content
3860
write(content)
39-
} catch (error) {
40-
reject(error)
61+
} else if (error) {
62+
console.error(error)
63+
throw new Error(error)
64+
} else if (jsonChunk.function) {
65+
streamResponse.tool_calls ??= []
66+
streamResponse.tool_calls.push(jsonChunk)
67+
} else if (!jsonChunk.key && jsonChunk.content !== '') {
68+
console.log('Unknown chunk', jsonChunk)
4169
}
42-
})
43-
res.on('end', () => {
44-
if (res.statusCode === 200) {
45-
resolve(responseBody)
46-
} else {
47-
reject(new Error(`request failed: ${res.statusCode}`))
48-
}
49-
})
50-
})
51-
req.on('error', reject)
52-
req.write(json)
53-
req.end()
54-
})
70+
} catch (err) {
71+
console.error('Error parsing chunk', err)
72+
}
73+
}
74+
}
75+
return streamResponse
76+
}
77+
78+
/**
79+
* Send messages to the server and handle tool calls.
80+
* Will mutate the messages array!
81+
*
82+
* @import { ToolCall, ToolHandler } from './types.d.ts'
83+
* @param {Message[]} messages
84+
* @returns {Promise<void>}
85+
*/
86+
async function sendMessages(messages) {
87+
const chatInput = {
88+
messages,
89+
tools: tools.map(tool => tool.tool),
90+
}
91+
const response = await sendToServer(chatInput)
92+
messages.push(response)
93+
// handle tool results
94+
if (response.tool_calls) {
95+
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<Message> }[]} */
96+
const toolResults = []
97+
for (const toolCall of response.tool_calls) {
98+
const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name)
99+
if (tool) {
100+
const result = tool.handleToolCall(toolCall)
101+
toolResults.push({ toolCall, tool, result })
102+
} else {
103+
throw new Error(`Unknown tool: ${toolCall.function.name}`)
104+
}
105+
}
106+
write('\n')
107+
for (const toolResult of toolResults) {
108+
const { toolCall, tool } = toolResult
109+
const result = await toolResult.result
110+
111+
// Construct function call message
112+
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
113+
const keys = Object.keys(args)
114+
let func = toolCall.function.name
115+
if (keys.length === 0) {
116+
func += '()'
117+
} else if (keys.length === 1) {
118+
func += `(${args[keys[0]]})`
119+
} else {
120+
// transform to (arg1 = 111, arg2 = 222)
121+
const pairs = keys.map(key => `${key} = ${args[key]}`)
122+
func += `(${pairs.join(', ')})`
123+
}
124+
125+
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
126+
messages.push(result)
127+
}
128+
// send messages with tool results
129+
await sendMessages(messages)
130+
}
55131
}
56132

57133
/**
@@ -62,15 +138,10 @@ function write(...args) {
62138
}
63139

64140
export function chat() {
141+
/** @type {Message[]} */
142+
const messages = [systemMessage]
65143
process.stdin.setEncoding('utf-8')
66144

67-
const colors = {
68-
system: '\x1b[36m', // cyan
69-
user: '\x1b[33m', // yellow
70-
error: '\x1b[31m', // red
71-
normal: '\x1b[0m', // reset
72-
}
73-
74145
write(colors.system, 'question: ', colors.normal)
75146

76147
process.stdin.on('data', async (/** @type {string} */ input) => {
@@ -81,8 +152,7 @@ export function chat() {
81152
try {
82153
write(colors.user, 'answer: ', colors.normal)
83154
messages.push({ role: 'user', content: input.trim() })
84-
const response = await sendToServer({ messages })
85-
messages.push({ role: 'assistant', content: response })
155+
await sendMessages(messages)
86156
} catch (error) {
87157
console.error(colors.error, '\n' + error)
88158
} finally {

bin/cli.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ if (arg === 'chat') {
4949
*/
5050
function checkForUpdates() {
5151
const currentVersion = packageJson.version
52-
return void fetch('https://registry.npmjs.org/hyperparam/latest')
52+
return fetch('https://registry.npmjs.org/hyperparam/latest')
5353
.then(response => response.json())
5454
.then(data => {
5555
const latestVersion = data.version

bin/tools.js

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import fs from 'fs/promises'
2+
import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet'
3+
4+
/**
5+
* @import { Message, ToolCall, ToolHandler } from './types.d.ts'
6+
* @type {ToolHandler[]}
7+
*/
8+
export const tools = [
9+
{
10+
emoji: '📂',
11+
tool: {
12+
type: 'function',
13+
function: {
14+
name: 'list_files',
15+
description: 'List the files in the current directory.',
16+
parameters: {
17+
type: 'object',
18+
properties: {
19+
path: { type: 'string', description: 'The path to list files from (optional).' },
20+
},
21+
},
22+
},
23+
},
24+
/**
25+
* @param {ToolCall} toolCall
26+
* @returns {Promise<Message>}
27+
*/
28+
async handleToolCall(toolCall) {
29+
let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}')
30+
if (path.includes('..') || path.includes('~')) {
31+
throw new Error('Invalid path: ' + path)
32+
}
33+
// append to current directory
34+
path = process.cwd() + '/' + path
35+
// list files in the current directory
36+
const filenames = await fs.readdir(path)
37+
return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id }
38+
},
39+
},
40+
{
41+
emoji: '📄',
42+
tool: {
43+
type: 'function',
44+
function: {
45+
name: 'read_parquet',
46+
description: 'Read rows from a parquet file. Do not request more than 5 rows.',
47+
parameters: {
48+
type: 'object',
49+
properties: {
50+
filename: { type: 'string', description: 'The name of the parquet file to read.' },
51+
rowStart: { type: 'integer', description: 'The start row index.' },
52+
rowEnd: { type: 'integer', description: 'The end row index.' },
53+
orderBy: { type: 'string', description: 'The column name to sort by.' },
54+
},
55+
required: ['filename'],
56+
},
57+
},
58+
},
59+
/**
60+
* @param {ToolCall} toolCall
61+
* @returns {Promise<Message>}
62+
*/
63+
async handleToolCall(toolCall) {
64+
const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}')
65+
if (rowEnd - rowStart > 5) {
66+
throw new Error('Do NOT request more than 5 rows.')
67+
}
68+
const file = await asyncBufferFromFile(filename)
69+
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy })
70+
let content = ''
71+
for (let i = rowStart; i < rowEnd; i++) {
72+
content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n`
73+
}
74+
return { role: 'tool', content, tool_call_id: toolCall.id }
75+
},
76+
},
77+
]

bin/types.d.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
export interface ToolCall {
2+
id: string
3+
type: 'function'
4+
function: {
5+
name: string
6+
arguments?: string
7+
}
8+
}
9+
10+
export type Role = 'system' | 'user' | 'assistant' | 'tool'
11+
12+
export interface Message {
13+
role: Role
14+
content: string
15+
tool_calls?: ToolCall[]
16+
tool_call_id?: string
17+
error?: string
18+
}
19+
20+
export interface ToolHandler {
21+
emoji: string
22+
tool: Tool
23+
handleToolCall(toolCall: ToolCall): Promise<Message>
24+
}
25+
interface ToolProperty {
26+
type: string
27+
description: string
28+
}
29+
30+
export interface Tool {
31+
type: 'function'
32+
function: {
33+
name: string
34+
description: string
35+
parameters?: {
36+
type: 'object'
37+
properties: Record<string, ToolProperty>
38+
required?: string[]
39+
additionalProperties?: boolean
40+
},
41+
strict?: boolean
42+
}
43+
}

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@
4848
},
4949
"dependencies": {
5050
"hightable": "0.12.1",
51-
"hyparquet": "1.8.7",
51+
"hyparquet": "1.9.0",
5252
"hyparquet-compressors": "1.0.0",
5353
"react": "18.3.1",
5454
"react-dom": "18.3.1"
5555
},
5656
"devDependencies": {
5757
"@eslint/js": "9.22.0",
5858
"@testing-library/react": "16.2.0",
59-
"@types/node": "22.13.9",
59+
"@types/node": "22.13.10",
6060
"@types/react": "19.0.10",
6161
"@types/react-dom": "19.0.4",
6262
"@vitejs/plugin-react": "4.3.4",

0 commit comments

Comments
 (0)