Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 45 additions & 39 deletions bin/chat.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { tools } from './tools.js'

const systemPrompt =
'You are a machine learning web application named "hyperparam". ' +
'You assist users with building high quality ML models by introspecting on their training set data. ' +
'The website and api are available at hyperparam.app. ' +
'Hyperparam uses LLMs to analyze their own training set. ' +
'It can generate the perplexity, entropy, and other metrics of the training set. ' +
'This allows users to find segments of their data set which are difficult to model. ' +
'This could be because the data is junk, or because the data requires deeper understanding. ' +
'This is essential for closing the loop on the ML lifecycle. ' +
'The quickest way to get started is to upload a dataset and start exploring.'
const systemPrompt = 'You are a machine learning web application named "Hyperparam".'
+ ' You assist users with building high quality ML models by introspecting on their training set data.'
+ ' The website and api are available at hyperparam.app.'
+ ' Hyperparam uses LLMs to analyze their own training set.'
+ ' It can generate the perplexity, entropy, and other metrics of the training set.'
+ ' This allows users to find segments of their data set which are difficult to model.'
+ ' This could be because the data is junk, or because the data requires deeper understanding.'
+ ' This is essential for closing the loop on the ML lifecycle.'
+ ' The quickest way to get started is to upload a dataset and start exploring.'
/** @type {Message} */
const systemMessage = { role: 'system', content: systemPrompt }

Expand All @@ -27,7 +26,7 @@ const colors = {
* @returns {Promise<Message>}
*/
async function sendToServer(chatInput) {
const response = await fetch('http://localhost:3000/api/functions/openai/chat', {
const response = await fetch('https://hyperparam.app/api/functions/openai/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(chatInput),
Expand All @@ -38,8 +37,10 @@ async function sendToServer(chatInput) {
}

// Process the streaming response
/** @type {Message} */
const streamResponse = { role: 'assistant', content: '' }
const reader = response.body.getReader()
const reader = response.body?.getReader()
if (!reader) throw new Error('No response body')
const decoder = new TextDecoder()
let buffer = ''

Expand All @@ -53,19 +54,19 @@ async function sendToServer(chatInput) {
for (const line of lines) {
if (!line.trim()) continue
try {
const jsonChunk = JSON.parse(line)
const { content, error } = jsonChunk
if (content) {
streamResponse.content += content
write(content)
const chunk = JSON.parse(line)
const { type, error } = chunk
if (type === 'response.output_text.delta') {
streamResponse.content += chunk.delta
write(chunk.delta)
} else if (error) {
console.error(error)
throw new Error(error)
} else if (jsonChunk.function) {
} else if (chunk.function) {
streamResponse.tool_calls ??= []
streamResponse.tool_calls.push(jsonChunk)
} else if (!jsonChunk.key && jsonChunk.content !== '') {
console.log('Unknown chunk', jsonChunk)
streamResponse.tool_calls.push(chunk)
} else if (!chunk.key) {
console.log('Unknown chunk', chunk)
}
} catch (err) {
console.error('Error parsing chunk', err)
Expand All @@ -85,19 +86,21 @@ async function sendToServer(chatInput) {
*/
async function sendMessages(messages) {
const chatInput = {
model: 'gpt-4o',
messages,
tools: tools.map(tool => tool.tool),
}
const response = await sendToServer(chatInput)
messages.push(response)
// handle tool results
if (response.tool_calls) {
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<Message> }[]} */
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<string> }[]} */
const toolResults = []
for (const toolCall of response.tool_calls) {
const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name)
if (tool) {
const result = tool.handleToolCall(toolCall)
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
const result = tool.handleToolCall(args)
toolResults.push({ toolCall, tool, result })
} else {
throw new Error(`Unknown tool: ${toolCall.function.name}`)
Expand All @@ -106,24 +109,27 @@ async function sendMessages(messages) {
write('\n')
for (const toolResult of toolResults) {
const { toolCall, tool } = toolResult
const result = await toolResult.result
try {
const content = await toolResult.result

// Construct function call message
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
const keys = Object.keys(args)
let func = toolCall.function.name
if (keys.length === 0) {
func += '()'
} else if (keys.length === 1) {
func += `(${args[keys[0]]})`
} else {
// transform to (arg1 = 111, arg2 = 222)
const pairs = keys.map(key => `${key} = ${args[key]}`)
func += `(${pairs.join(', ')})`
}
// Construct function call message
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
const entries = Object.entries(args)
let func = toolCall.function.name
if (entries.length === 0) {
func += '()'
} else {
// transform to (arg1 = 111, arg2 = 222)
const pairs = entries.map(([key, value]) => `${key} = ${value}`)
func += `(${pairs.join(', ')})`
}

write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
messages.push(result)
write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
messages.push({ role: 'tool', content, tool_call_id: toolCall.id })
} catch (error) {
write(colors.error, `\nError calling tool ${toolCall.function.name}: ${error.message}\n\n`, colors.normal)
messages.push({ role: 'tool', content: `Error calling tool ${toolCall.function.name}: ${error.message}`, tool_call_id: toolCall.id })
}
}
// send messages with tool results
await sendMessages(messages)
Expand Down
8 changes: 4 additions & 4 deletions bin/serve.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function handleStatic(filePath, range) {
if (!stats?.isFile()) {
return { status: 404, content: 'not found' }
}
const contentLength = stats.size
const fileSize = stats.size

// detect content type
const extname = path.extname(filePath)
Expand All @@ -133,8 +133,8 @@ async function handleStatic(filePath, range) {
// ranged requests
if (range) {
const [unit, ranges] = range.split('=')
if (unit === 'bytes') {
const [start, end] = ranges.split('-').map(Number)
if (unit === 'bytes' && ranges) {
const [start = 0, end = fileSize] = ranges.split('-').map(Number)

// convert fs.ReadStream to web stream
const fsStream = createReadStream(filePath, { start, end })
Expand All @@ -151,7 +151,7 @@ async function handleStatic(filePath, range) {
}

const content = await fs.readFile(filePath)
return { status: 200, content, contentLength, contentType }
return { status: 200, content, contentLength: fileSize, contentType }
}

/**
Expand Down
116 changes: 89 additions & 27 deletions bin/tools.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import fs from 'fs/promises'
import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet'
import { compressors } from 'hyparquet-compressors'

const fileLimit = 20 // limit to 20 files per page
/**
* @import { Message, ToolCall, ToolHandler } from './types.d.ts'
* @import { ToolHandler } from './types.d.ts'
* @type {ToolHandler[]}
*/
export const tools = [
Expand All @@ -12,66 +14,126 @@ export const tools = [
type: 'function',
function: {
name: 'list_files',
description: 'List the files in the current directory.',
description: `List the files in a directory. Files are listed recursively up to ${fileLimit} per page.`,
parameters: {
type: 'object',
properties: {
path: { type: 'string', description: 'The path to list files from (optional).' },
path: {
type: 'string',
description: 'The path to list files from. Optional, defaults to the current directory.',
},
filetype: {
type: 'string',
description: 'Optional file type to filter by, e.g. "parquet", "csv". If not provided, all files are listed.',
},
offset: {
type: 'number',
description: 'Skip offset number of files in the listing. Defaults to 0. Optional.',
},
},
},
},
},
/**
* @param {ToolCall} toolCall
* @returns {Promise<Message>}
* @param {Record<string, unknown>} args
* @returns {Promise<string>}
*/
async handleToolCall(toolCall) {
let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}')
async handleToolCall({ path = '.', filetype, offset = 0 }) {
if (typeof path !== 'string') {
throw new Error('Expected path to be a string')
}
if (path.includes('..') || path.includes('~')) {
throw new Error('Invalid path: ' + path)
}
// append to current directory
path = process.cwd() + '/' + path
// list files in the current directory
const filenames = await fs.readdir(path)
return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id }
if (typeof filetype !== 'undefined' && typeof filetype !== 'string') {
throw new Error('Expected filetype to be a string or undefined')
}
const start = validateInteger('offset', offset, 0)
// list files in the directory
const filenames = (await fs.readdir(path))
.filter(key => !filetype || key.endsWith(`.${filetype}`)) // filter by file type if provided
const limited = filenames.slice(start, start + fileLimit)
const end = start + limited.length
return `Files ${start + 1}..${end} of ${filenames.length}:\n${limited.join('\n')}`
},
},
{
emoji: '📄',
tool: {
type: 'function',
function: {
name: 'read_parquet',
description: 'Read rows from a parquet file. Do not request more than 5 rows.',
name: 'parquet_get_rows',
description: 'Get up to 5 rows of data from a parquet file.',
parameters: {
type: 'object',
properties: {
filename: { type: 'string', description: 'The name of the parquet file to read.' },
rowStart: { type: 'integer', description: 'The start row index.' },
rowEnd: { type: 'integer', description: 'The end row index.' },
orderBy: { type: 'string', description: 'The column name to sort by.' },
filename: {
type: 'string',
description: 'The name of the parquet file to read.',
},
offset: {
type: 'number',
description: 'The starting row index to fetch (0-indexed).',
},
limit: {
type: 'number',
description: 'The number of rows to fetch. Default 5. Maximum 5.',
},
orderBy: {
type: 'string',
description: 'The column name to sort by.',
},
},
required: ['filename'],
},
},
},
/**
* @param {ToolCall} toolCall
* @returns {Promise<Message>}
* @param {Record<string, unknown>} args
* @returns {Promise<string>}
*/
async handleToolCall(toolCall) {
const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}')
if (rowEnd - rowStart > 5) {
throw new Error('Do NOT request more than 5 rows.')
async handleToolCall({ filename, offset = 0, limit = 5, orderBy }) {
if (typeof filename !== 'string') {
throw new Error('Expected filename to be a string')
}
const rowStart = validateInteger('offset', offset, 0)
const rowEnd = rowStart + validateInteger('limit', limit, 1, 5)
if (typeof orderBy !== 'undefined' && typeof orderBy !== 'string') {
throw new Error('Expected orderBy to be a string')
}
const file = await asyncBufferFromFile(filename)
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy })
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy, compressors })
let content = ''
for (let i = rowStart; i < rowEnd; i++) {
content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n`
content += `Row ${i}: ${stringify(rows[i])}\n`
}
return { role: 'tool', content, tool_call_id: toolCall.id }
return content
},
},
]

/**
 * Validates that a value is an integer within the specified range and returns it.
 *
 * Both bounds are inclusive. When `max` is omitted the range is open-ended
 * and only the lower bound is enforced.
 *
 * @param {string} name - The name of the value being validated (used in error messages).
 * @param {unknown} value - The value to validate.
 * @param {number} min - The minimum allowed value (inclusive).
 * @param {number} [max] - The maximum allowed value (inclusive). Optional.
 * @returns {number} The validated value, unchanged.
 * @throws {Error} If value is not a number, is NaN, is not an integer, or is out of range.
 */
function validateInteger(name, value, min, max) {
  // Number.isInteger rejects NaN as well, but NaN keeps the plain "Invalid number"
  // message because `${NaN}` is interpolated into it the same way.
  if (typeof value !== 'number' || Number.isNaN(value)) {
    throw new Error(`Invalid number for ${name}: ${value}`)
  }
  if (!Number.isInteger(value)) {
    throw new Error(`Invalid number for ${name}: ${value}. Must be an integer.`)
  }
  const hasMax = max !== undefined
  if (value < min || hasMax && value > max) {
    // Do not print "between X and undefined" when there is no upper bound.
    const expected = hasMax ? `between ${min} and ${max}` : `at least ${min}`
    throw new Error(`Invalid number for ${name}: ${value}. Must be ${expected}.`)
  }
  return value
}

/**
 * Renders a value as JSON text, truncated for display.
 *
 * The value is first passed through `toJson` and then `JSON.stringify`.
 * Output longer than `limit` characters is cut to the first `limit`
 * characters and suffixed with an ellipsis.
 *
 * @param {unknown} obj - The value to serialize.
 * @param {number} [limit] - Maximum number of characters kept before truncation. Defaults to 1000.
 * @returns {string} The JSON text, possibly truncated with a trailing '…'.
 */
function stringify(obj, limit = 1000) {
  const json = JSON.stringify(toJson(obj))
  if (json.length > limit) {
    return json.slice(0, limit) + '…'
  }
  return json
}
2 changes: 1 addition & 1 deletion bin/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface Message {
export interface ToolHandler {
emoji: string
tool: Tool
handleToolCall(toolCall: ToolCall): Promise<Message>
handleToolCall(args: Record<string, unknown>): Promise<string>
}
interface ToolProperty {
type: string
Expand Down
2 changes: 2 additions & 0 deletions tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"compilerOptions": {
"allowJs": true,
"checkJs": true,
"target": "es2022",
"useDefineForClassFields": true,
"lib": ["es2022", "DOM", "DOM.Iterable"],
Expand Down