Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 118 additions & 48 deletions bin/chat.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import http from 'http' // TODO: https
import { tools } from './tools.js'

const systemPrompt = 'You are a machine learning web application named "hyperparam". ' +
const systemPrompt =
'You are a machine learning web application named "hyperparam". ' +
'You assist users with building high quality ML models by introspecting on their training set data. ' +
'The website and api are available at hyperparam.app. ' +
'Hyperparam uses LLMs to analyze their own training set. ' +
Expand All @@ -9,49 +10,124 @@ const systemPrompt = 'You are a machine learning web application named "hyperpar
'This could be because the data is junk, or because the data requires deeper understanding. ' +
'This is essential for closing the loop on the ML lifecycle. ' +
'The quickest way to get started is to upload a dataset and start exploring.'
const messages = [{ role: 'system', content: systemPrompt }]
/** @type {Message} */
const systemMessage = { role: 'system', content: systemPrompt }

const colors = {
system: '\x1b[36m', // cyan
user: '\x1b[33m', // yellow
tool: '\x1b[90m', // gray
error: '\x1b[31m', // red
normal: '\x1b[0m', // reset
}

/**
* @import { Message } from './types.d.ts'
* @param {Object} chatInput
* @returns {Promise<string>}
* @returns {Promise<Message>}
*/
function sendToServer(chatInput) {
return new Promise((resolve, reject) => {
const json = JSON.stringify(chatInput)
const options = {
hostname: 'localhost',
port: 3000,
path: '/api/functions/openai/chat',
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Content-Length': json.length,
},
}
const req = http.request(options, res => {
let responseBody = ''
res.on('data', chunk => {
if (chunk[0] !== 123) return // {
try {
const { content } = JSON.parse(chunk)
responseBody += content
async function sendToServer(chatInput) {
const response = await fetch('http://localhost:3000/api/functions/openai/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(chatInput),
})

if (!response.ok) {
throw new Error(`Request failed: ${response.status}`)
}

// Process the streaming response
const streamResponse = { role: 'assistant', content: '' }
const reader = response.body.getReader()
const decoder = new TextDecoder()
let buffer = ''

while (true) {
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
// Keep the last line in the buffer
buffer = lines.pop() || ''
for (const line of lines) {
if (!line.trim()) continue
try {
const jsonChunk = JSON.parse(line)
const { content, error } = jsonChunk
if (content) {
streamResponse.content += content
write(content)
} catch (error) {
reject(error)
} else if (error) {
console.error(error)
throw new Error(error)
} else if (jsonChunk.function) {
streamResponse.tool_calls ??= []
streamResponse.tool_calls.push(jsonChunk)
} else if (!jsonChunk.key && jsonChunk.content !== '') {
console.log('Unknown chunk', jsonChunk)
}
})
res.on('end', () => {
if (res.statusCode === 200) {
resolve(responseBody)
} else {
reject(new Error(`request failed: ${res.statusCode}`))
}
})
})
req.on('error', reject)
req.write(json)
req.end()
})
} catch (err) {
console.error('Error parsing chunk', err)
}
}
}
return streamResponse
}

/**
* Send messages to the server and handle tool calls.
* Will mutate the messages array!
*
* @import { ToolCall, ToolHandler } from './types.d.ts'
* @param {Message[]} messages
* @returns {Promise<void>}
*/
async function sendMessages(messages) {
const chatInput = {
messages,
tools: tools.map(tool => tool.tool),
}
const response = await sendToServer(chatInput)
messages.push(response)
// handle tool results
if (response.tool_calls) {
/** @type {{ toolCall: ToolCall, tool: ToolHandler, result: Promise<Message> }[]} */
const toolResults = []
for (const toolCall of response.tool_calls) {
const tool = tools.find(tool => tool.tool.function.name === toolCall.function.name)
if (tool) {
const result = tool.handleToolCall(toolCall)
toolResults.push({ toolCall, tool, result })
} else {
throw new Error(`Unknown tool: ${toolCall.function.name}`)
}
}
write('\n')
for (const toolResult of toolResults) {
const { toolCall, tool } = toolResult
const result = await toolResult.result

// Construct function call message
const args = JSON.parse(toolCall.function?.arguments ?? '{}')
const keys = Object.keys(args)
let func = toolCall.function.name
if (keys.length === 0) {
func += '()'
} else if (keys.length === 1) {
func += `(${args[keys[0]]})`
} else {
// transform to (arg1 = 111, arg2 = 222)
const pairs = keys.map(key => `${key} = ${args[key]}`)
func += `(${pairs.join(', ')})`
}

write(colors.tool, `${tool.emoji} ${func}`, colors.normal, '\n\n')
messages.push(result)
}
// send messages with tool results
await sendMessages(messages)
}
}

/**
Expand All @@ -62,15 +138,10 @@ function write(...args) {
}

export function chat() {
/** @type {Message[]} */
const messages = [systemMessage]
process.stdin.setEncoding('utf-8')

const colors = {
system: '\x1b[36m', // cyan
user: '\x1b[33m', // yellow
error: '\x1b[31m', // red
normal: '\x1b[0m', // reset
}

write(colors.system, 'question: ', colors.normal)

process.stdin.on('data', async (/** @type {string} */ input) => {
Expand All @@ -81,8 +152,7 @@ export function chat() {
try {
write(colors.user, 'answer: ', colors.normal)
messages.push({ role: 'user', content: input.trim() })
const response = await sendToServer({ messages })
messages.push({ role: 'assistant', content: response })
await sendMessages(messages)
} catch (error) {
console.error(colors.error, '\n' + error)
} finally {
Expand Down
2 changes: 1 addition & 1 deletion bin/cli.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ if (arg === 'chat') {
*/
function checkForUpdates() {
const currentVersion = packageJson.version
return void fetch('https://registry.npmjs.org/hyperparam/latest')
return fetch('https://registry.npmjs.org/hyperparam/latest')
.then(response => response.json())
.then(data => {
const latestVersion = data.version
Expand Down
77 changes: 77 additions & 0 deletions bin/tools.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import fs from 'fs/promises'
import { asyncBufferFromFile, parquetQuery, toJson } from 'hyparquet'

/**
* @import { Message, ToolCall, ToolHandler } from './types.d.ts'
* @type {ToolHandler[]}
*/
export const tools = [
{
emoji: '📂',
tool: {
type: 'function',
function: {
name: 'list_files',
description: 'List the files in the current directory.',
parameters: {
type: 'object',
properties: {
path: { type: 'string', description: 'The path to list files from (optional).' },
},
},
},
},
/**
* @param {ToolCall} toolCall
* @returns {Promise<Message>}
*/
async handleToolCall(toolCall) {
let { path = '.' } = JSON.parse(toolCall.function.arguments || '{}')
if (path.includes('..') || path.includes('~')) {
throw new Error('Invalid path: ' + path)
}
// append to current directory
path = process.cwd() + '/' + path
// list files in the current directory
const filenames = await fs.readdir(path)
return { role: 'tool', content: `Files:\n${filenames.join('\n')}`, tool_call_id: toolCall.id }
},
},
{
emoji: '📄',
tool: {
type: 'function',
function: {
name: 'read_parquet',
description: 'Read rows from a parquet file. Do not request more than 5 rows.',
parameters: {
type: 'object',
properties: {
filename: { type: 'string', description: 'The name of the parquet file to read.' },
rowStart: { type: 'integer', description: 'The start row index.' },
rowEnd: { type: 'integer', description: 'The end row index.' },
orderBy: { type: 'string', description: 'The column name to sort by.' },
},
required: ['filename'],
},
},
},
/**
* @param {ToolCall} toolCall
* @returns {Promise<Message>}
*/
async handleToolCall(toolCall) {
const { filename, rowStart = 0, rowEnd = 5, orderBy } = JSON.parse(toolCall.function.arguments || '{}')
if (rowEnd - rowStart > 5) {
throw new Error('Do NOT request more than 5 rows.')
}
const file = await asyncBufferFromFile(filename)
const rows = await parquetQuery({ file, rowStart, rowEnd, orderBy })
let content = ''
for (let i = rowStart; i < rowEnd; i++) {
content += `Row ${i}: ${JSON.stringify(toJson(rows[i]))}\n`
}
return { role: 'tool', content, tool_call_id: toolCall.id }
},
},
]
43 changes: 43 additions & 0 deletions bin/types.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
export interface ToolCall {
id: string
type: 'function'
function: {
name: string
arguments?: string
}
}

export type Role = 'system' | 'user' | 'assistant' | 'tool'

export interface Message {
role: Role
content: string
tool_calls?: ToolCall[]
tool_call_id?: string
error?: string
}

export interface ToolHandler {
emoji: string
tool: Tool
handleToolCall(toolCall: ToolCall): Promise<Message>
}
interface ToolProperty {
type: string
description: string
}

export interface Tool {
type: 'function'
function: {
name: string
description: string
parameters?: {
type: 'object'
properties: Record<string, ToolProperty>
required?: string[]
additionalProperties?: boolean
},
strict?: boolean
}
}
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
},
"dependencies": {
"hightable": "0.12.1",
"hyparquet": "1.8.7",
"hyparquet": "1.9.0",
"hyparquet-compressors": "1.0.0",
"react": "18.3.1",
"react-dom": "18.3.1"
},
"devDependencies": {
"@eslint/js": "9.22.0",
"@testing-library/react": "16.2.0",
"@types/node": "22.13.9",
"@types/node": "22.13.10",
"@types/react": "19.0.10",
"@types/react-dom": "19.0.4",
"@vitejs/plugin-react": "4.3.4",
Expand Down