Skip to content

Commit c54c21e

Browse files
authored
brianyin/ajs-314-preemptive-generation (#798)
1 parent ac8214d commit c54c21e

File tree

12 files changed

+430
-75
lines changed

12 files changed

+430
-75
lines changed

.changeset/fine-buckets-sink.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@livekit/agents': patch
3+
---
4+
5+
Add preemptive generation

agents/src/inference/stt.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
//
33
// SPDX-License-Identifier: Apache-2.0
44
import { type AudioFrame } from '@livekit/rtc-node';
5-
import { type RawData, WebSocket } from 'ws';
5+
import type { WebSocket } from 'ws';
6+
import { type RawData } from 'ws';
67
import { APIError, APIStatusError } from '../_exceptions.js';
78
import { AudioByteStream } from '../audio.js';
89
import { log } from '../log.js';

agents/src/inference/tts.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ import { AudioByteStream } from '../audio.js';
88
import { log } from '../log.js';
99
import { createStreamChannel } from '../stream/stream_channel.js';
1010
import { basic as tokenizeBasic } from '../tokenize/index.js';
11-
import {
12-
SynthesizeStream as BaseSynthesizeStream,
13-
TTS as BaseTTS,
14-
ChunkedStream,
15-
} from '../tts/index.js';
11+
import type { ChunkedStream } from '../tts/index.js';
12+
import { SynthesizeStream as BaseSynthesizeStream, TTS as BaseTTS } from '../tts/index.js';
1613
import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js';
1714
import { shortuuid } from '../utils.js';
1815
import {

agents/src/llm/tool_context.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,50 @@ export type ToolContext<UserData = UnknownUserData> = {
187187
[name: string]: FunctionTool<any, UserData, any>;
188188
};
189189

190+
/**
 * Check whether two tool contexts expose the same tools.
 *
 * Two contexts are treated as the same when they have exactly the same tool
 * names and, for every name, both entries are present and carry an identical
 * `description`. Nothing else on the tools is compared.
 *
 * @param ctx1 - First tool context.
 * @param ctx2 - Second tool context.
 * @returns `true` when names and descriptions match, `false` otherwise.
 */
export function isSameToolContext(ctx1: ToolContext, ctx2: ToolContext): boolean {
  const leftNames = Object.keys(ctx1);
  const rightNames = new Set(Object.keys(ctx2));

  // Different tool counts can never match.
  if (leftNames.length !== rightNames.size) {
    return false;
  }

  return leftNames.every((toolName) => {
    if (!rightNames.has(toolName)) {
      return false;
    }

    const left = ctx1[toolName];
    const right = ctx2[toolName];

    // A key that maps to a missing tool on either side counts as a mismatch.
    if (!left || !right) {
      return false;
    }

    return left.description === right.description;
  });
}
217+
218+
/**
 * Check whether two tool-choice values are equivalent.
 *
 * Identical values (including two `null`s or two equal strings) are the same.
 * Two object-form choices are the same when their `type` and `function.name`
 * match. Every other combination is treated as different.
 *
 * @param choice1 - First tool choice, or `null`.
 * @param choice2 - Second tool choice, or `null`.
 * @returns `true` when both choices are equivalent.
 */
export function isSameToolChoice(choice1: ToolChoice | null, choice2: ToolChoice | null): boolean {
  // Covers null/null, equal strings and the same object reference.
  if (choice1 === choice2) {
    return true;
  }

  // Exactly one side is null at this point.
  if (choice1 === null || choice2 === null) {
    return false;
  }

  // Two strings that reached here are unequal; a string never equals an object form.
  if (typeof choice1 !== 'object' || typeof choice2 !== 'object') {
    return false;
  }

  return choice1.type === choice2.type && choice1.function.name === choice2.function.name;
}
233+
190234
/**
191235
* Create a function tool with inferred parameters from the schema.
192236
*/

agents/src/metrics/base.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ export type EOUMetrics = {
9191
* Time taken to invoke the user's `Agent.onUserTurnCompleted` callback.
9292
*/
9393
onUserTurnCompletedDelayMs: number;
94+
/**
95+
* Time at which the user last stopped speaking, in milliseconds (0 when unknown).
96+
*/
97+
lastSpeakingTimeMs: number;
98+
/**
99+
* The ID of the speech handle.
100+
*/
94101
speechId?: string;
95102
};
96103

agents/src/stt/stt.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ export enum SpeechEventType {
3838
END_OF_SPEECH = 3,
3939
/** Usage event, emitted periodically to indicate usage metrics. */
4040
RECOGNITION_USAGE = 4,
41+
/**
42+
* Preflight transcript, emitted before final transcript when STT has high confidence
43+
* but hasn't fully committed yet. Includes all pre-committed transcripts including
44+
* final transcript from the previous STT run.
45+
*/
46+
PREFLIGHT_TRANSCRIPT = 5,
4147
}
4248

4349
/** SpeechData contains metadata about this {@link SpeechEvent}. */

agents/src/voice/agent_activity.ts

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
type ToolContext,
2323
} from '../llm/index.js';
2424
import type { LLMError } from '../llm/llm.js';
25+
import { isSameToolChoice, isSameToolContext } from '../llm/tool_context.js';
2526
import { log } from '../log.js';
2627
import type {
2728
EOUMetrics,
@@ -43,6 +44,7 @@ import { type AgentSession, type TurnDetectionMode } from './agent_session.js';
4344
import {
4445
AudioRecognition,
4546
type EndOfTurnInfo,
47+
type PreemptiveGenerationInfo,
4648
type RecognitionHooks,
4749
type _TurnDetector,
4850
} from './audio_recognition.js';
@@ -71,6 +73,16 @@ import { SpeechHandle } from './speech_handle.js';
7173
// equivalent to Python's contextvars
7274
const speechHandleStorage = new AsyncLocalStorage<SpeechHandle>();
7375

76+
/**
 * Bookkeeping for a reply that was started before the end of the user turn was
 * confirmed. It records the request parameters used so they can be compared
 * once the turn completes.
 */
interface PreemptiveGeneration {
  /** Handle of the reply created with `scheduleSpeech: false`. */
  speechHandle: SpeechHandle;
  /** User message built from the transcript that triggered the generation. */
  userMessage: ChatMessage;
  /** Recognition info that triggered the generation. */
  info: PreemptiveGenerationInfo;
  /** Copy of the chat context the reply was generated with. */
  chatCtx: ChatContext;
  /** Shallow copy of the tools available when the generation started. */
  tools: ToolContext;
  /** Tool choice in effect when the generation started. */
  toolChoice: ToolChoice | null;
  /** `Date.now()` value captured when the generation was started. */
  createdAt: number;
}
85+
7486
export class AgentActivity implements RecognitionHooks {
7587
private static readonly REPLY_TASK_CANCEL_TIMEOUT = 5000;
7688
private started = false;
@@ -87,6 +99,7 @@ export class AgentActivity implements RecognitionHooks {
8799
private audioStream = new DeferredReadableStream<AudioFrame>();
88100
// default to null as None, which maps to the default provider tool choice value
89101
private toolChoice: ToolChoice | null = null;
102+
private _preemptiveGeneration?: PreemptiveGeneration;
90103

91104
agent: Agent;
92105
agentSession: AgentSession;
@@ -589,8 +602,12 @@ export class AgentActivity implements RecognitionHooks {
589602
this.agentSession._updateUserState('speaking');
590603
}
591604

592-
onEndOfSpeech(_ev: VADEvent): void {
593-
this.agentSession._updateUserState('listening');
605+
/**
 * VAD hook called when the user stops speaking.
 *
 * Moves the user state to `listening` and passes along an estimate of when
 * speech ended: the current time minus the event's `silenceDuration`.
 * NOTE(review): assumes `silenceDuration` uses the same unit as `Date.now()`
 * (milliseconds) — confirm against the VAD implementation.
 *
 * @param ev - The VAD event that reported the end of speech.
 */
onEndOfSpeech(ev: VADEvent): void {
  const now = Date.now();
  // Fall back to "now" when no event is supplied.
  const speechEndTime = ev ? now - ev.silenceDuration : now;
  this.agentSession._updateUserState('listening', speechEndTime);
}
595612

596613
onVADInferenceDone(ev: VADEvent): void {
@@ -664,6 +681,55 @@ export class AgentActivity implements RecognitionHooks {
664681
);
665682
}
666683

684+
/**
 * Recognition hook that starts generating a reply before the end of the user
 * turn has been confirmed.
 *
 * Does nothing when preemptive generation is disabled, the activity is
 * draining, a non-interrupted speech is currently active, or the configured
 * LLM is not an `LLM` instance. Otherwise it cancels any previous preemptive
 * generation, creates an unscheduled reply for the new transcript, and records
 * the parameters used for that reply.
 *
 * @param info - Transcript information that triggered the generation.
 */
onPreemptiveGeneration(info: PreemptiveGenerationInfo): void {
  if (!this.agentSession.options.preemptiveGeneration || this.draining) {
    return;
  }

  // An active speech that has not been interrupted blocks a new generation.
  const speechInProgress = this._currentSpeech !== undefined && !this._currentSpeech.interrupted;
  if (speechInProgress || !(this.llm instanceof LLM)) {
    return;
  }

  // Only one preemptive generation may be pending at a time.
  this.cancelPreemptiveGeneration();

  const { newTranscript, transcriptConfidence } = info;
  this.logger.info({ newTranscript, transcriptConfidence }, 'starting preemptive generation');

  const userMessage = ChatMessage.create({ role: 'user', content: info.newTranscript });
  const chatCtx = this.agent.chatCtx.copy();

  // Create the reply without scheduling it.
  const speechHandle = this.generateReply({ userMessage, chatCtx, scheduleSpeech: false });

  // Snapshot the request parameters alongside the handle.
  const snapshot: PreemptiveGeneration = {
    speechHandle,
    userMessage,
    info,
    chatCtx: chatCtx.copy(),
    tools: { ...this.tools },
    toolChoice: this.toolChoice,
    createdAt: Date.now(),
  };
  this._preemptiveGeneration = snapshot;
}
725+
726+
/**
 * Cancel the pending preemptive generation, if any, and clear the stored record.
 * Does nothing when none is pending.
 */
private cancelPreemptiveGeneration(): void {
  const pending = this._preemptiveGeneration;
  if (pending === undefined) {
    return;
  }

  pending.speechHandle._cancel();
  this._preemptiveGeneration = undefined;
}
732+
667733
private createSpeechTask(options: {
668734
task: Task<void>;
669735
ownedSpeechHandle?: SpeechHandle;
@@ -694,6 +760,7 @@ export class AgentActivity implements RecognitionHooks {
694760

695761
async onEndOfTurn(info: EndOfTurnInfo): Promise<boolean> {
696762
if (this.draining) {
763+
this.cancelPreemptiveGeneration();
697764
this.logger.warn({ user_input: info.newTranscript }, 'skipping user input, task is draining');
698765
// copied from python:
699766
// TODO(shubhra): should we "forward" this new turn to the next agent/activity?
@@ -710,6 +777,7 @@ export class AgentActivity implements RecognitionHooks {
710777
info.newTranscript.split(' ').length < this.agentSession.options.minInterruptionWords
711778
) {
712779
// avoid interruption if the new_transcript is too short
780+
this.cancelPreemptiveGeneration();
713781
this.logger.info('skipping user input, new_transcript is too short');
714782
return false;
715783
}
@@ -775,13 +843,15 @@ export class AgentActivity implements RecognitionHooks {
775843
instructions?: string;
776844
toolChoice?: ToolChoice | null;
777845
allowInterruptions?: boolean;
846+
scheduleSpeech?: boolean;
778847
}): SpeechHandle {
779848
const {
780849
userMessage,
781850
chatCtx,
782851
instructions: defaultInstructions,
783852
toolChoice: defaultToolChoice,
784853
allowInterruptions: defaultAllowInterruptions,
854+
scheduleSpeech = true,
785855
} = options;
786856

787857
let instructions = defaultInstructions;
@@ -871,7 +941,9 @@ export class AgentActivity implements RecognitionHooks {
871941
task.finally(() => this.onPipelineReplyDone());
872942
}
873943

874-
this.scheduleSpeech(handle, SpeechHandle.SPEECH_PRIORITY_NORMAL);
944+
if (scheduleSpeech) {
945+
this.scheduleSpeech(handle, SpeechHandle.SPEECH_PRIORITY_NORMAL);
946+
}
875947
return handle;
876948
}
877949

@@ -977,16 +1049,48 @@ export class AgentActivity implements RecognitionHooks {
9771049
return;
9781050
}
9791051

980-
// Ensure the new message is passed to generateReply
981-
// This preserves the original message id, making it easier for users to track responses
982-
const speechHandle = this.generateReply({ userMessage, chatCtx });
1052+
let speechHandle: SpeechHandle | undefined;
1053+
if (this._preemptiveGeneration !== undefined) {
1054+
const preemptive = this._preemptiveGeneration;
1055+
// make sure the onUserTurnCompleted didn't change some request parameters
1056+
// otherwise invalidate the preemptive generation
1057+
if (
1058+
preemptive.info.newTranscript === userMessage?.textContent &&
1059+
preemptive.chatCtx.isEquivalent(chatCtx) &&
1060+
isSameToolContext(preemptive.tools, this.tools) &&
1061+
isSameToolChoice(preemptive.toolChoice, this.toolChoice)
1062+
) {
1063+
speechHandle = preemptive.speechHandle;
1064+
this.scheduleSpeech(speechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL);
1065+
this.logger.debug(
1066+
{
1067+
preemptiveLeadTime: Date.now() - preemptive.createdAt,
1068+
},
1069+
'using preemptive generation',
1070+
);
1071+
} else {
1072+
this.logger.warn(
1073+
'preemptive generation enabled but chat context or tools have changed after `onUserTurnCompleted`',
1074+
);
1075+
preemptive.speechHandle._cancel();
1076+
}
1077+
1078+
this._preemptiveGeneration = undefined;
1079+
}
1080+
1081+
if (speechHandle === undefined) {
1082+
// Ensure the new message is passed to generateReply
1083+
// This preserves the original message id, making it easier for users to track responses
1084+
speechHandle = this.generateReply({ userMessage, chatCtx });
1085+
}
9831086

9841087
const eouMetrics: EOUMetrics = {
9851088
type: 'eou_metrics',
9861089
timestamp: Date.now(),
9871090
endOfUtteranceDelayMs: info.endOfUtteranceDelay,
9881091
transcriptionDelayMs: info.transcriptionDelay,
9891092
onUserTurnCompletedDelayMs: callbackDuration,
1093+
lastSpeakingTimeMs: info.stoppedSpeakingAt ?? 0,
9901094
speechId: speechHandle.id,
9911095
};
9921096

@@ -1139,10 +1243,9 @@ export class AgentActivity implements RecognitionHooks {
11391243

11401244
chatCtx = chatCtx.copy();
11411245

1246+
// Insert new message into temporary chat context for LLM inference
11421247
if (newMessage) {
11431248
chatCtx.insert(newMessage);
1144-
this.agent._chatCtx.insert(newMessage);
1145-
this.agentSession._conversationItemAdded(newMessage);
11461249
}
11471250

11481251
if (instructions) {
@@ -1157,7 +1260,6 @@ export class AgentActivity implements RecognitionHooks {
11571260
}
11581261
}
11591262

1160-
this.agentSession._updateAgentState('thinking');
11611263
const tasks: Array<Task<void>> = [];
11621264
const [llmTask, llmGenData] = performLLMInference(
11631265
// preserve `this` context in llmNode
@@ -1185,6 +1287,12 @@ export class AgentActivity implements RecognitionHooks {
11851287

11861288
await speechHandle.waitIfNotInterrupted([speechHandle._waitForScheduled()]);
11871289

1290+
// Add new message to actual chat context if the speech is scheduled
1291+
if (newMessage && speechHandle.scheduled) {
1292+
this.agent._chatCtx.insert(newMessage);
1293+
this.agentSession._conversationItemAdded(newMessage);
1294+
}
1295+
11881296
if (speechHandle.interrupted) {
11891297
replyAbortController.abort();
11901298
await cancelAndWait(tasks, AgentActivity.REPLY_TASK_CANCEL_TIMEOUT);
@@ -1917,6 +2025,7 @@ export class AgentActivity implements RecognitionHooks {
19172025
try {
19182026
if (this._draining) return;
19192027

2028+
this.cancelPreemptiveGeneration();
19202029
this.createSpeechTask({
19212030
task: Task.from(() => this.agent.onExit()),
19222031
name: 'AgentActivity_onExit',
@@ -1937,6 +2046,7 @@ export class AgentActivity implements RecognitionHooks {
19372046
this.logger.warn('task closing without draining');
19382047
}
19392048

2049+
this.cancelPreemptiveGeneration();
19402050
// Unregister event handlers to prevent duplicate metrics
19412051
if (this.llm instanceof LLM) {
19422052
this.llm.off('metrics_collected', this.onMetricsCollected);

agents/src/voice/agent_session.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export interface VoiceOptions {
5757
minEndpointingDelay: number;
5858
maxEndpointingDelay: number;
5959
maxToolSteps: number;
60+
preemptiveGeneration: boolean;
6061
}
6162

6263
const defaultVoiceOptions: VoiceOptions = {
@@ -67,6 +68,7 @@ const defaultVoiceOptions: VoiceOptions = {
6768
minEndpointingDelay: 500,
6869
maxEndpointingDelay: 6000,
6970
maxToolSteps: 3,
71+
preemptiveGeneration: false,
7072
} as const;
7173

7274
export type TurnDetectionMode = 'stt' | 'vad' | 'realtime_llm' | 'manual' | _TurnDetector;
@@ -421,7 +423,7 @@ export class AgentSession<
421423
}
422424

423425
/** @internal */
424-
_updateUserState(state: UserState) {
426+
_updateUserState(state: UserState, _lastSpeakingTime?: number) {
425427
if (this.userState === state) {
426428
return;
427429
}

0 commit comments

Comments
 (0)