diff --git a/agents/build.ts b/agents/build.ts index d339112f7..eb56ab74e 100644 --- a/agents/build.ts +++ b/agents/build.ts @@ -5,7 +5,7 @@ import dts from 'bun-plugin-dts'; await Bun.build({ - entrypoints: ['./src/index.ts', './src/tts/index.ts', './src/stt/index.ts'], + entrypoints: ['./src/index.ts', './src/tts/index.ts', './src/stt/index.ts', './src/cli.ts'], outdir: './dist', target: 'bun', // https://github.com/oven-sh/bun/blob/main/src/bundler/bundle_v2.zig#L2667 sourcemap: 'external', diff --git a/agents/package.json b/agents/package.json index 82ebe2cde..f661dc680 100644 --- a/agents/package.json +++ b/agents/package.json @@ -15,6 +15,7 @@ }, "devDependencies": { "@types/bun": "latest", + "@types/ws": "^8.5.10", "bun-plugin-dts": "^0.2.1", "eslint": "^8.57.0", "eslint-config-prettier": "^9.1.0", @@ -25,10 +26,11 @@ "typescript": "^5.0.0" }, "dependencies": { - "@livekit/protocol": "^1.12.0", + "@livekit/protocol": "^1.13.0", "commander": "^12.0.0", "livekit-server-sdk": "^2.1.2", "pino": "^8.19.0", - "pino-pretty": "^11.0.0" + "pino-pretty": "^11.0.0", + "ws": "^8.16.0" } } diff --git a/agents/src/cli.ts b/agents/src/cli.ts index c7610abf2..5388110e6 100644 --- a/agents/src/cli.ts +++ b/agents/src/cli.ts @@ -2,44 +2,77 @@ // // SPDX-License-Identifier: Apache-2.0 -import { version } from './index'; +import { version } from '.'; import { Option, Command } from 'commander'; +import { WorkerOptions, Worker } from './worker'; +import { EventEmitter } from 'events'; +import { log } from './log'; -const program = new Command(); -program - .name('agents') - .description('LiveKit Agents CLI') - .version(version) - .addOption( - new Option('--log-level', 'Set the logging level').choices([ - 'DEBUG', - 'INFO', - 'WARNING', - 'ERROR', - 'CRITICAL', - ]), - ); - -program - .command('start') - .description('Start the worker') - .addOption( - new Option('--url ', 'LiveKit server or Cloud project websocket URL') - .makeOptionMandatory(true) - .env('LIVEKIT_URL'), - ) - .addOption( - new Option('--api-key ', "LiveKit server or Cloud project's API key") - .makeOptionMandatory(true) - .env('LIVEKIT_API_KEY'), - ) - .addOption( - new Option('--api-secret ', "LiveKit server or Cloud project's API secret") - .makeOptionMandatory(true) - .env('LIVEKIT_API_SECRET'), - ) - .action(() => { - return; +type CliArgs = { + opts: WorkerOptions; + logLevel: string; + production: boolean; + watch: boolean; + event?: EventEmitter; +}; + +const runWorker = async (args: CliArgs) => { + log.level = args.logLevel; + const worker = new Worker(args.opts); + + process.on('SIGINT', async () => { + await worker.close(); + process.exit(130); // SIGINT exit code }); -program.parse(); + try { + await worker.run(); + } catch { + log.fatal('worker failed'); + } +}; + +export const runApp = (opts: WorkerOptions) => { + const program = new Command() + .name('agents') + .description('LiveKit Agents CLI') + .version(version) + .addOption( + new Option('--log-level ', 'Set the logging level') + .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) + .default('trace'), + ) + .addOption( + new Option('--url ', 'LiveKit server or Cloud project websocket URL') + .makeOptionMandatory(true) + .env('LIVEKIT_URL'), + ) + .addOption( + new Option('--api-key ', "LiveKit server or Cloud project's API key") + .makeOptionMandatory(true) + .env('LIVEKIT_API_KEY'), + ) + .addOption( + new Option('--api-secret ', "LiveKit server or Cloud project's API secret") + .makeOptionMandatory(true) + .env('LIVEKIT_API_SECRET'), + ) + + program + .command('start') + .description('Start the worker in production mode') + .action(() => { + const options = program.optsWithGlobals() + opts.wsURL = options.url || opts.wsURL; + opts.apiKey = options.apiKey || opts.apiKey; + opts.apiSecret = options.apiSecret || opts.apiSecret; + runWorker({ + opts, + production: true, + watch: false, + logLevel: options.logLevel, + }); + }); + + program.parse(); +}; diff --git a/agents/src/index.ts b/agents/src/index.ts index 4bb696411..6db2e8fd8 100644 --- a/agents/src/index.ts +++ b/agents/src/index.ts @@ -5,3 +5,6 @@ export * from './vad'; export * from './plugin'; export * from './version'; +export * from './job_context'; +export * from './job_request'; +export * from './worker'; diff --git a/agents/src/ipc/job_process.ts b/agents/src/ipc/job_process.ts index ad4bed7b3..58f6ef60c 100644 --- a/agents/src/ipc/job_process.ts +++ b/agents/src/ipc/job_process.ts @@ -20,10 +20,10 @@ import { runJob } from './job_main'; import { EventEmitter } from 'events'; import { log } from '../log'; -const START_TIMEOUT = 90; -const PING_INTERVAL = 5; -const PING_TIMEOUT = 90; -const HIGH_PING_THRESHOLD = 10; // milliseconds +const START_TIMEOUT = 90 * 1000; +const PING_INTERVAL = 5 * 1000; +const PING_TIMEOUT = 90 * 1000; +const HIGH_PING_THRESHOLD = 10; export class JobProcess { #job: Job; @@ -103,6 +103,7 @@ export class JobProcess { const delay = Date.now() - msg.timestamp; if (delay > HIGH_PING_THRESHOLD) { this.logger.warn(`job is unresponsive (${delay}ms delay)`); + // @ts-expect-error: this actually works fine types/bun doesn't have a typedecl for it yet pongTimeout.refresh(); } } else if (msg instanceof UserExit || msg instanceof ShutdownResponse) { diff --git a/agents/src/ipc/protocol.ts b/agents/src/ipc/protocol.ts index d176e0cd5..1137f7b90 100644 --- a/agents/src/ipc/protocol.ts +++ b/agents/src/ipc/protocol.ts @@ -31,7 +31,7 @@ export class StartJobRequest implements Message { export class StartJobResponse implements Message { static MSG_ID = 1; - err: Error | undefined; + err?: Error; get MSG_ID(): number { return StartJobResponse.MSG_ID; diff --git a/agents/src/job_context.ts b/agents/src/job_context.ts index 1a39775eb..13abf41a5 100644 --- a/agents/src/job_context.ts +++ b/agents/src/job_context.ts @@ -9,7 +9,7 @@ import { EventEmitter } from 'events'; export class JobContext { #job: Job; #room: Room; - #publisher: RemoteParticipant | undefined; + #publisher?: RemoteParticipant; tx: EventEmitter; constructor( diff --git a/agents/src/job_request.ts b/agents/src/job_request.ts index 4f2d9def9..a2ecb08c2 100644 --- a/agents/src/job_request.ts +++ b/agents/src/job_request.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { JobContext } from './job_context'; -import { VideoGrant } from 'livekit-server-sdk'; import { Job, ParticipantInfo, Room } from '@livekit/protocol'; import { log } from './log'; import { EventEmitter } from 'events'; @@ -35,16 +34,15 @@ export type AcceptData = { entry: AgentEntry; autoSubscribe: AutoSubscribe; autoDisconnect: AutoDisconnect; - grants: VideoGrant; name: string; identity: string; metadata: string; assign: EventEmitter; }; -type AvailRes = { +export type AvailRes = { avail: boolean; - data: AcceptData | undefined; + data?: AcceptData; }; export class JobRequest { @@ -91,7 +89,6 @@ export class JobRequest { entry: AgentEntry, autoSubscribe: AutoSubscribe = AutoSubscribe.SUBSCRIBE_ALL, autoDisconnect: AutoDisconnect = AutoDisconnect.ROOM_EMPTY, - grants: VideoGrant, name: string = '', identity: string = '', metadata: string = '', @@ -110,7 +107,6 @@ export class JobRequest { entry, autoSubscribe, autoDisconnect, - grants, name, identity, metadata, diff --git a/agents/src/stt/stream_adapter.ts b/agents/src/stt/stream_adapter.ts index 29dd97331..e1dc8487d 100644 --- a/agents/src/stt/stream_adapter.ts +++ b/agents/src/stt/stream_adapter.ts @@ -12,7 +12,7 @@ export class StreamAdapterWrapper extends SpeechStream { stt: STT; vadStream: VADStream; eventQueue: (SpeechEvent | undefined)[]; - language: string | undefined; + language?: string; task: { run: Promise; cancel: () => void; diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 8b9204118..c58051bee 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -49,7 +49,7 @@ export abstract class STT { this.#streamingSupported = streamingSupported; } - abstract recognize(buffer: AudioBuffer, language: string | undefined): Promise; + abstract recognize(buffer: AudioBuffer, language?: string): Promise; abstract stream(language: string | undefined): SpeechStream; diff --git a/agents/src/tokenize.ts b/agents/src/tokenize.ts index ff7f658d6..0dee52c28 100644 --- a/agents/src/tokenize.ts +++ b/agents/src/tokenize.ts @@ -7,7 +7,7 @@ export interface SegmentedSentence { } export abstract class SentenceTokenizer { - abstract tokenize(text: string, language: string | undefined): SegmentedSentence[]; + abstract tokenize(text: string, language?: string): SegmentedSentence[]; abstract stream(language: string | undefined): SentenceStream; } diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index ddb18ff73..c352f61ae 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -17,7 +17,7 @@ export enum SynthesisEventType { export class SynthesisEvent { type: SynthesisEventType; - audio: SynthesizedAudio | undefined; + audio?: SynthesizedAudio; constructor(type: SynthesisEventType, audio: SynthesizedAudio | undefined = undefined) { this.type = type; @@ -26,7 +26,7 @@ export class SynthesisEvent { } export abstract class SynthesizeStream implements IterableIterator { - abstract pushText(token: string | undefined): void; + abstract pushText(token?: string): void; markSegmentEnd() { this.pushText(undefined); diff --git a/agents/src/worker.ts b/agents/src/worker.ts new file mode 100644 index 000000000..339b1e488 --- /dev/null +++ b/agents/src/worker.ts @@ -0,0 +1,357 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +import os from 'os'; +import { WebSocket } from 'ws'; +import { AvailRes, JobRequest } from './job_request'; +import { + JobType, + Job, + WorkerMessage, + ParticipantPermission, + ServerMessage, + JobAssignment, +} from '@livekit/protocol'; +import { AcceptData } from './job_request'; +import { HTTPServer } from './http_server'; +import { log } from './log'; +import { version } from './version'; +import { AccessToken } from 'livekit-server-sdk'; +import { EventEmitter } from 'events'; +import { JobProcess } from './ipc/job_process'; + +const MAX_RECONNECT_ATTEMPTS = 10; +const ASSIGNMENT_TIMEOUT = 15 * 1000; +const LOAD_INTERVAL = 5 * 1000; + +const cpuLoad = (): number => + (os + .cpus() + .reduce( + (acc, x) => acc + x.times.idle / Object.values(x.times).reduce((acc, x) => acc + x, 0), + 0, + ) / + os.cpus().length) * + 100; + +class WorkerPermissions { + canPublish: boolean; + canSubscribe: boolean; + canPublishData: boolean; + canUpdateMetadata: boolean; + hidden: boolean; + + constructor( + canPublish = true, + canSubscribe = true, + canPublishData = true, + canUpdateMetadata = true, + hidden = false, + ) { + this.canPublish = canPublish; + this.canSubscribe = canSubscribe; + this.canPublishData = canPublishData; + this.canUpdateMetadata = canUpdateMetadata; + this.hidden = hidden; + } +} + +export class WorkerOptions { + requestFunc: (arg: JobRequest) => Promise; + cpuLoadFunc: () => number; + namespace: string; + permissions: WorkerPermissions; + workerType: JobType; + maxRetry: number; + wsURL: string; + apiKey?: string; + apiSecret?: string; + host: string; + port: number; + + constructor({ + requestFunc, + cpuLoadFunc = cpuLoad, + namespace = 'default', + permissions = new WorkerPermissions(), + workerType = JobType.JT_PUBLISHER, + maxRetry = MAX_RECONNECT_ATTEMPTS, + wsURL = 'ws://localhost:7880', + apiKey = undefined, + apiSecret = undefined, + host = 'localhost', + port = 8081, + }: { + requestFunc: (arg: JobRequest) => Promise; + cpuLoadFunc?: () => number; + namespace?: string; + permissions?: WorkerPermissions; + workerType?: JobType; + maxRetry?: number; + wsURL?: string; + apiKey?: string; + apiSecret?: string; + host?: string; + port?: number; + }) { + this.requestFunc = requestFunc; + this.cpuLoadFunc = cpuLoadFunc; + this.namespace = namespace; + this.permissions = permissions; + this.workerType = workerType; + this.maxRetry = maxRetry; + this.wsURL = wsURL; + this.apiKey = apiKey; + this.apiSecret = apiSecret; + this.host = host; + this.port = port; + } +} + +class ActiveJob { + job: Job; + acceptData: AcceptData; + + constructor(job: Job, acceptData: AcceptData) { + this.job = job; + this.acceptData = acceptData; + } +} + +export class Worker { + opts: WorkerOptions; + #id = 'unregistered'; + session: WebSocket | undefined = undefined; + closed = false; + httpServer: HTTPServer; + logger = log.child({ version }); + event = new EventEmitter(); + pending: { [id: string]: { value?: JobAssignment } } = {}; + processes: { [id: string]: { proc: JobProcess; activeJob: ActiveJob } } = {}; + + constructor(opts: WorkerOptions) { + opts.wsURL = opts.wsURL || process.env.LIVEKIT_URL || ''; + opts.apiKey = opts.apiKey || process.env.LIVEKIT_API_KEY || ''; + opts.apiSecret = opts.apiSecret || process.env.LIVEKIT_API_SECRET || ''; + + this.opts = opts; + this.httpServer = new HTTPServer(opts.host, opts.port); + } + + get id(): string { + return this.#id; + } + + async run() { + this.logger.info('starting worker'); + + if (this.opts.wsURL === '') throw new Error('--url is required, or set LIVEKIT_URL env var'); + if (this.opts.apiKey === '') + throw new Error('--api-key is required, or set LIVEKIT_API_KEY env var'); + if (this.opts.apiSecret === '') + throw new Error('--api-secret is required, or set LIVEKIT_API_SECRET env var'); + + const workerWS = async () => { + // const retries = 0; + while (!this.closed) { + const token = new AccessToken(this.opts.apiKey, this.opts.apiSecret); + token.addGrant({ agent: true }); + const jwt = await token.toJwt(); + + const url = new URL(this.opts.wsURL); + url.protocol = url.protocol.replace('http', 'ws'); + this.session = new WebSocket(url + 'agent', { + headers: { authorization: 'Bearer ' + jwt }, + }); + this.session.on('open', () => { + this.session!.removeAllListeners('close'); + this.runWS(this.session!); + }); + return; + + // TODO(nbsp): retries that actually work + // if (this.session.readyState !== WebSocket.OPEN) { + // if (this.closed) return; + // if (retries >= this.opts.maxRetry) { + // throw new Error(`failed to connect to LiveKit server after ${retries} attempts: ${e}`); + // } + + // const delay = Math.min(retries * 2, 10); + // retries++; + + // this.logger.warn( + // `failed to connect to LiveKit server, retrying in ${delay} seconds: ${e} (${retries}/${this.opts.maxRetry})`, + // ); + // await new Promise((resolve) => setTimeout(resolve, delay)); + // } + } + }; + + await Promise.all([workerWS(), this.httpServer.run()]); + } + + startProcess(job: Job, url: string, token: string, acceptData: AcceptData) { + const proc = new JobProcess(job, url, token, acceptData.entry); + this.processes[job.id] = { proc, activeJob: new ActiveJob(job, acceptData) }; + new Promise((_, reject) => { + try { + proc.run(); + } catch (e) { + proc.logger.error(`error running job process ${proc.job.id}`); + reject(e); + } finally { + delete this.processes[job.id]; + } + }); + } + + runWS(ws: WebSocket) { + let closingWS = false; + + const send = (msg: WorkerMessage) => { + if (closingWS) { + this.event.off('worker_msg', send); + return; + } + ws.send(msg.toBinary()); + }; + this.event.on('worker_msg', send); + + ws.addEventListener('close', () => { + closingWS = true; + if (!this.closed) throw new Error('worker connection closed unexpectedly'); + }); + + ws.addEventListener('message', (event) => { + if (event.type !== 'message') { + this.logger.warn('unexpected message type: ' + event.type); + return; + } + + const msg = new ServerMessage(); + msg.fromBinary(event.data as Uint8Array); + switch (msg.message.case) { + case 'register': { + this.#id = msg.message.value.workerId; + log + .child({ id: this.id, server_info: msg.message.value.serverInfo }) + .info('registered worker'); + break; + } + case 'availability': { + const tx = new EventEmitter(); + const req = new JobRequest(msg.message.value.job!, tx); + this.event.on('recv', (av: AvailRes) => { + const msg = new WorkerMessage({ + message: { + case: 'availability', + value: { + available: av.avail, + jobId: req.id, + participantIdentity: av.data?.identity, + participantName: av.data?.name, + participantMetadata: av.data?.metadata, + }, + }, + }); + + this.pending[req.id] = { value: undefined }; + this.event.emit('worker_msg', msg); + + new Promise((_, reject) => { + const timer = setTimeout(() => { + reject(new Error(`assignment for job ${req.id} timed out`)); + }, ASSIGNMENT_TIMEOUT); + Promise.resolve(this.pending[req.id].value).then((value) => { + clearTimeout(timer); + const url = value?.url || this.opts.wsURL; + + try { + this.opts.requestFunc(req); + } catch (e) { + log.child({ req }).error(`user request hadnler for job ${req.id} failed`); + } finally { + if (!req.answered) { + log + .child({ req }) + .error(`no answer for job ${req.id}, automatically rejecting the job`); + this.event.emit( + 'worker_msg', + new WorkerMessage({ + message: { + case: 'availability', + value: { + available: false, + }, + }, + }), + ); + } + } + + this.startProcess(value!.job!, url, value!.token, av.data!); + }); + }); + }); + + break; + } + case 'assignment': { + const job = msg.message.value.job!; + if (job.id in this.pending) { + const task = this.pending[job.id]; + delete this.pending[job.id]; + task.value = msg.message.value; + } else { + log.child({ job }).warn('received assignment for unknown job ' + job.id); + } + break; + } + } + }); + + this.event.emit( + 'worker_msg', + new WorkerMessage({ + message: { + case: 'register', + value: { + type: this.opts.workerType, + namespace: this.opts.namespace, + allowedPermissions: new ParticipantPermission({ + canPublish: this.opts.permissions.canPublish, + canSubscribe: this.opts.permissions.canSubscribe, + canPublishData: this.opts.permissions.canPublishData, + hidden: this.opts.permissions.hidden, + agent: true, + }), + version, + }, + }, + }), + ); + + const loadMonitor = setInterval(() => { + if (closingWS) clearInterval(loadMonitor); + this.event.emit( + 'worker_msg', + new WorkerMessage({ + message: { + case: 'updateWorker', + value: { + load: cpuLoad(), + }, + }, + }), + ); + }, LOAD_INTERVAL); + } + + async close() { + if (this.closed) return; + this.logger.info('shutting down worker'); + await this.httpServer.close(); + this.session; + } +} diff --git a/bun.lockb b/bun.lockb index f1bd9a6df..42a048d4d 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/examples/minimal.ts b/examples/minimal.ts new file mode 100644 index 000000000..b78671331 --- /dev/null +++ b/examples/minimal.ts @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +import { runApp } from '../agents/src/cli' +import { JobContext, JobRequest, WorkerOptions } from '../agents/src' + +const requestFunc = async (req: JobRequest) => { + console.log('received request', req) + await req.accept(async (_: JobContext) => { + console.log('starting voice assistant...') + + // etc + }) +} + +runApp(new WorkerOptions({ requestFunc }))