Skip to content

Commit 5813c53

Browse files
authored
Update chat tools (#270)
1 parent ed39b4f commit 5813c53

File tree

5 files changed

+141
-71
lines changed

5 files changed

+141
-71
lines changed

bin/chat.js

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import { tools } from './tools.js'
22

3-
const systemPrompt =
4-
'You are a machine learning web application named "hyperparam". ' +
5-
'You assist users with building high quality ML models by introspecting on their training set data. ' +
6-
'The website and api are available at hyperparam.app. ' +
7-
'Hyperparam uses LLMs to analyze their own training set. ' +
8-
'It can generate the perplexity, entropy, and other metrics of the training set. ' +
9-
'This allows users to find segments of their data set which are difficult to model. ' +
10-
'This could be because the data is junk, or because the data requires deeper understanding. ' +
11-
'This is essential for closing the loop on the ML lifecycle. ' +
12-
'The quickest way to get started is to upload a dataset and start exploring.'
3+
const systemPrompt = 'You are a machine learning web application named "Hyperparam".'
4+
+ ' You assist users with building high quality ML models by introspecting on their training set data.'
5+
+ ' The website and api are available at hyperparam.app.'
6+
+ ' Hyperparam uses LLMs to analyze their own training set.'
7+
+ ' It can generate the perplexity, entropy, and other metrics of the training set.'
8+
+ ' This allows users to find segments of their data set which are difficult to model.'
9+
+ ' This could be because the data is junk, or because the data requires deeper understanding.'
10+
+ ' This is essential for closing the loop on the ML lifecycle.'
11+
+ ' The quickest way to get started is to upload a dataset and start exploring.'
1312
/** @type {Message} */
1413
const systemMessage = { role: 'system', content: systemPrompt }
1514

@@ -27,7 +26,7 @@ const colors = {
2726
* @returns {Promise<Message>}
2827
*/
2928
async function sendToServer(chatInput) {
30-
const response = await fetch('http://localhost:3000/api/functions/openai/chat', {
29+
const response = await fetch('https://hyperparam.app/api/functions/openai/chat', {
3130
method: 'POST',
3231
headers: { 'Content-Type': 'application/json' },
3332
body: JSON.stringify(chatInput),
@@ -38,8 +37,10 @@ async function sendToServer(chatInput) {
3837
}
3938

4039
// Process the streaming response
40+
/** @type {Message} */
4141
const streamResponse = { role: 'assistant', content: '' }
42-
const reader = response.body.getReader()
42+
const reader = response.body?.getReader()
43+
if (!reader) throw new Error('No response body')
4344
const decoder = new TextDecoder()
4445
let buffer = ''
4546

@@ -53,19 +54,19 @@ async function sendToServer(chatInput) {
5354
for (const line of lines) {
5455
if (!line.trim()) continue
5556
try {
56-
const jsonChunk = JSON.parse(line)
57-
const { content, error } = jsonChunk
58-
if (content) {
59-
streamResponse.content += content
60-
write(content)
57+
const chunk = JSON.parse(line)
58+
const { type, error } = chunk
59+
if (type === 'response.output_text.delta') {
60+
streamResponse.content += chunk.delta
61+
write(chunk.delta)
6162
} else if (error) {
6263
console.error(error)
6364
throw new Error(error)
64-
} else if (jsonChunk.function) {
65+
} else if (chunk.function) {
6566
streamResponse.tool_calls ??= []
66-
streamResponse.tool_calls.push(jsonChunk)
67-
} else if (!jsonChunk.key && jsonChunk.content !== '') {
68-
console.log('Unknown chunk', jsonChunk)
67+
streamResponse.tool_calls.push(chunk)
68+
} else if (!chunk.key) {
69+
console.log('Unknown chunk', chunk)
6970
}
7071
} catch (err) {
7172
console.error('Error parsing chunk', err)
@@ -85,19 +86,21 @@ async function sendToServer(chatInput) {
8586
*/
8687
async function sendMessages(messages) {
8788
const chatInput = {
89+
model: 'gpt-4o',
8890
messages,
8991
tools: tools.map(tool => tool.tool),
9092
}
9193
const response = await sendToServer(chatInput)
9294
messages.push(response)
9395
// handle tool results
9496
if (response.tool_calls) {
95-
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<Message> }[]} */
97+
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<string> }[]} */
9698
const toolResults = []
9799
for (const toolCall of response.tool_calls) {
98100
const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name)
99101
if (tool) {
100-
const result = tool.handleToolCall(toolCall)
102+
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
103+
const result = tool.handleToolCall(args)
101104
toolResults.push({ toolCall, tool, result })
102105
} else {
103106
throw new Error(`Unknown tool: ${toolCall.function.name}`)
@@ -106,24 +109,27 @@ async function sendMessages(messages) {
106109
write('\n')
107110
for (const toolResult of toolResults) {
108111
const { toolCall, tool } = toolResult
109-
const result = await toolResult.result
112+
try {
113+
const content = await toolResult.result
110114

111-
// Construct function call message
112-
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
113-
const keys = Object.keys(args)
114-
let func = toolCall.function.name
115-
if (keys.length === 0) {
116-
func += '()'
117-
} else if (keys.length === 1) {
118-
func += `(${args[keys[0]]})`
119-
} else {
120-
// transform to (arg1 = 111, arg2 = 222)
121-
const pairs = keys.map(key => `${key} = ${args[key]}`)
122-
func += `(${pairs.join(', ')})`
123-
}
115+
// Construct function call message
116+
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
117+
const entries = Object.entries(args)
118+
let func = toolCall.function.name
119+
if (entries.length === 0) {
120+
func += '()'
121+
} else {
122+
// transform to (arg1 = 111, arg2 = 222)
123+
const pairs = entries.map(([key, value]) => `${key} = ${value}`)
124+
func += `(${pairs.join(', ')})`
125+
}
124126

125-
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
126-
messages.push(result)
127+
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
128+
messages.push({ role: 'tool', content, tool_call_id: toolCall.id })
129+
} catch (error) {
130+
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal)
131+
messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id })
132+
}
127133
}
128134
// send messages with tool results
129135
await sendMessages(messages)

bin/serve.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async function handleStatic(filePath, range) {
123123
if (!stats?.isFile()) {
124124
return { status: 404, content: 'not found' }
125125
}
126-
const contentLength = stats.size
126+
const fileSize = stats.size
127127

128128
// detect content type
129129
const extname = path.extname(filePath)
@@ -133,8 +133,8 @@ async function handleStatic(filePath, range) {
133133
// ranged requests
134134
if (range) {
135135
const [unit, ranges] = range.split('=')
136-
if (unit === 'bytes') {
137-
const [start, end] = ranges.split('-').map(Number)
136+
if (unit === 'bytes' && ranges) {
137+
const [start = 0, end = fileSize] = ranges.split('-').map(Number)
138138

139139
// convert fs.ReadStream to web stream
140140
const fsStream = createReadStream(filePath, { start, end })
@@ -151,7 +151,7 @@ async function handleStatic(filePath, range) {
151151
}
152152

153153
const content = await fs.readFile(filePath)
154-
return { status: 200, content, contentLength, contentType }
154+
return { status: 200, content, contentLength: fileSize, contentType }
155155
}
156156

157157
/**

bin/tools.js

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import fs from 'fs/promises'
22
import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet'
3+
import { compressors } from 'hyparquet-compressors'
34

5+
const fileLimit = 20 // limit to 20 files per page
46
/**
5-
* @import { Message, ToolCall, ToolHandler } from './types.d.ts'
7+
* @import { ToolHandler } from './types.d.ts'
68
* @type {ToolHandler[]}
79
*/
810
export const tools = [
@@ -12,66 +14,126 @@ export const tools = [
1214
type: 'function',
1315
function: {
1416
name: 'list_files',
15-
description: 'List the files in the current directory.',
17+
description: `List the files in a directory. Files are listed recursively up to ${fileLimit} per page.`,
1618
parameters: {
1719
type: 'object',
1820
properties: {
19-
path: { type: 'string', description: 'The path to list files from (optional).' },
21+
path: {
22+
type: 'string',
23+
description: 'The path to list files from. Optional, defaults to the current directory.',
24+
},
25+
filetype: {
26+
type: 'string',
27+
description: 'Optional file type to filter by, e.g. "parquet", "csv". If not provided, all files are listed.',
28+
},
29+
offset: {
30+
type: 'number',
31+
description: 'Skip offset number of files in the listing. Defaults to 0. Optional.',
32+
},
2033
},
2134
},
2235
},
2336
},
2437
/**
25-
* @param {ToolCall} toolCall
26-
* @returns {Promise<Message>}
38+
* @param {Record<string, unknown>} args
39+
* @returns {Promise<string>}
2740
*/
28-
async handleToolCall(toolCall) {
29-
let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}')
41+
async handleToolCall({ path = '.', filetype, offset = 0 }) {
42+
if (typeof path !== 'string') {
43+
throw new Error('Expected path to be a string')
44+
}
3045
if (path.includes('..') || path.includes('~')) {
3146
throw new Error('Invalid path: ' + path)
3247
}
33-
// append to current directory
34-
path = process.cwd() + '/' + path
35-
// list files in the current directory
36-
const filenames = await fs.readdir(path)
37-
return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id }
48+
if (typeof filetype !== 'undefined' && typeof filetype !== 'string') {
49+
throw new Error('Expected filetype to be a string or undefined')
50+
}
51+
const start = validateInteger('offset', offset, 0)
52+
// list files in the directory
53+
const filenames = (await fs.readdir(path))
54+
.filter(key => !filetype || key.endsWith(`.${filetype}`)) // filter by file type if provided
55+
const limited = filenames.slice(start, start + fileLimit)
56+
const end = start + limited.length
57+
return `Files ${start + 1}..${end} of ${filenames.length}:\n${limited.join('\n')}`
3858
},
3959
},
4060
{
4161
emoji: '📄',
4262
tool: {
4363
type: 'function',
4464
function: {
45-
name: 'read_parquet',
46-
description: 'Read rows from a parquet file. Do not request more than 5 rows.',
65+
name: 'parquet_get_rows',
66+
description: 'Get up to 5 rows of data from a parquet file.',
4767
parameters: {
4868
type: 'object',
4969
properties: {
50-
filename: { type: 'string', description: 'The name of the parquet file to read.' },
51-
rowStart: { type: 'integer', description: 'The start row index.' },
52-
rowEnd: { type: 'integer', description: 'The end row index.' },
53-
orderBy: { type: 'string', description: 'The column name to sort by.' },
70+
filename: {
71+
type: 'string',
72+
description: 'The name of the parquet file to read.',
73+
},
74+
offset: {
75+
type: 'number',
76+
description: 'The starting row index to fetch (0-indexed).',
77+
},
78+
limit: {
79+
type: 'number',
80+
description: 'The number of rows to fetch. Default 5. Maximum 5.',
81+
},
82+
orderBy: {
83+
type: 'string',
84+
description: 'The column name to sort by.',
85+
},
5486
},
5587
required: ['filename'],
5688
},
5789
},
5890
},
5991
/**
60-
* @param {ToolCall} toolCall
61-
* @returns {Promise<Message>}
92+
* @param {Record<string, unknown>} args
93+
* @returns {Promise<string>}
6294
*/
63-
async handleToolCall(toolCall) {
64-
const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}')
65-
if (rowEnd - rowStart > 5) {
66-
throw new Error('Do NOT request more than 5 rows.')
95+
async handleToolCall({ filename, offset = 0, limit = 5, orderBy }) {
96+
if (typeof filename !== 'string') {
97+
throw new Error('Expected filename to be a string')
98+
}
99+
const rowStart = validateInteger('offset', offset, 0)
100+
const rowEnd = rowStart + validateInteger('limit', limit, 1, 5)
101+
if (typeof orderBy !== 'undefined' && typeof orderBy !== 'string') {
102+
throw new Error('Expected orderBy to be a string')
67103
}
68104
const file = await asyncBufferFromFile(filename)
69-
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy })
105+
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy, compressors })
70106
let content = ''
71107
for (let i = rowStart; i < rowEnd; i++) {
72-
content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n`
108+
content += `Row ${i}: ${stringify(rows[i])}\n`
73109
}
74-
return { role: 'tool', content, tool_call_id: toolCall.id }
110+
return content
75111
},
76112
},
77113
]
114+
115+
/**
116+
* Validates that a value is an integer within the specified range. Max is inclusive.
117+
* @param {string} name - The name of the value being validated.
118+
* @param {unknown} value - The value to validate.
119+
* @param {number} min - The minimum allowed value (inclusive).
120+
* @param {number} [max] - The maximum allowed value (inclusive).
121+
* @returns {number}
122+
*/
123+
function validateInteger(name, value, min, max) {
124+
if (typeof value !== 'number' || isNaN(value)) {
125+
throw new Error(`Invalid number for ${name}: ${value}`)
126+
}
127+
if (!Number.isInteger(value)) {
128+
throw new Error(`Invalid number for ${name}: ${value}. Must be an integer.`)
129+
}
130+
if (value < min || max !== undefined && value > max) {
131+
throw new Error(`Invalid number for ${name}: ${value}. Must be between ${min} and ${max}.`)
132+
}
133+
return value
134+
}
135+
136+
function stringify(obj, limit = 1000) {
137+
const str = JSON.stringify(toJson(obj))
138+
return str.length <= limit ? str : str.slice(0, limit) + '…'
139+
}

bin/types.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export interface Message {
2020
export interface ToolHandler {
2121
emoji: string
2222
tool: Tool
23-
handleToolCall(toolCall: ToolCall): Promise<Message>
23+
handleToolCall(args: Record<string, unknown>): Promise<string>
2424
}
2525
interface ToolProperty {
2626
type: string

tsconfig.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
{
22
"compilerOptions": {
3+
"allowJs": true,
4+
"checkJs": true,
35
"target": "es2022",
46
"useDefineForClassFields": true,
57
"lib": ["es2022", "DOM", "DOM.Iterable"],

0 commit comments

Comments
 (0)