Skip to content

Commit 6c793a4

Browse files
committed
Add STT and utils
1 parent 159bff8 commit 6c793a4

File tree

5 files changed

+209
-1
lines changed

5 files changed

+209
-1
lines changed

agents/build.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dts from 'bun-plugin-dts';
66

77
await Bun.build({
8-
entrypoints: ['./src/index.ts'],
8+
entrypoints: ['./src/index.ts', './src/tts/index.ts', './src/stt/index.ts'],
99
outdir: './dist',
1010
target: 'node',
1111
sourcemap: 'external',

agents/src/stt/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
export { STT, SpeechEvent, SpeechEventType, SpeechStream } from './stt';
6+
export { StreamAdapter, StreamAdapterWrapper } from './stream_adapter';

agents/src/stt/stream_adapter.ts

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
import { STT, SpeechEvent, SpeechEventType, SpeechStream } from './stt';
6+
import { VADEventType, VADStream } from '../vad';
7+
import { AudioFrame } from '@livekit/rtc-node';
8+
import { AudioBuffer, mergeFrames } from '../utils';
9+
10+
export class StreamAdapterWrapper extends SpeechStream {
11+
closed: boolean;
12+
stt: STT;
13+
vadStream: VADStream;
14+
eventQueue: (SpeechEvent | undefined)[];
15+
language: string | undefined;
16+
task: {
17+
run: Promise<void>;
18+
cancel: () => void;
19+
};
20+
21+
constructor(stt: STT, vadStream: VADStream, language: string | undefined = undefined) {
22+
super();
23+
this.closed = false;
24+
this.stt = stt;
25+
this.vadStream = vadStream;
26+
this.eventQueue = [];
27+
this.language = language;
28+
this.task = {
29+
run: new Promise((_, reject) => {
30+
this.run(reject);
31+
}),
32+
cancel: () => {},
33+
};
34+
}
35+
36+
async run(reject: (arg: Error) => void) {
37+
this.task.cancel = () => {
38+
this.closed = true;
39+
reject(new Error('cancelled'));
40+
};
41+
42+
for (const event of this.vadStream) {
43+
if (event.type == VADEventType.START_OF_SPEECH) {
44+
const startEvent = new SpeechEvent(SpeechEventType.START_OF_SPEECH);
45+
this.eventQueue.push(startEvent);
46+
} else if (event.type == VADEventType.END_OF_SPEECH) {
47+
const mergedFrames = mergeFrames(event.speech);
48+
const endEvent = await this.stt.recognize(mergedFrames, this.language);
49+
this.eventQueue.push(endEvent);
50+
}
51+
}
52+
53+
this.eventQueue.push(undefined);
54+
}
55+
56+
pushFrame(frame: AudioFrame) {
57+
if (this.closed) {
58+
throw new TypeError('cannot push frame to closed stream');
59+
}
60+
61+
this.vadStream.pushFrame(frame);
62+
}
63+
64+
async close(wait: boolean = true): Promise<void> {
65+
this.closed = true;
66+
67+
if (!wait) {
68+
this.task.cancel();
69+
}
70+
71+
await this.vadStream.close(wait);
72+
await this.task.run;
73+
}
74+
75+
next(): IteratorResult<SpeechEvent> {
76+
const item = this.eventQueue.shift();
77+
if (item) {
78+
return { done: false, value: item };
79+
} else {
80+
return { done: true, value: undefined };
81+
}
82+
}
83+
}
84+
85+
export class StreamAdapter extends STT {
86+
stt: STT;
87+
vadStream: VADStream;
88+
89+
constructor(stt: STT, vadStream: VADStream) {
90+
super(true);
91+
this.stt = stt;
92+
this.vadStream = vadStream;
93+
}
94+
95+
async recognize(
96+
buffer: AudioBuffer,
97+
language: string | undefined = undefined,
98+
): Promise<SpeechEvent> {
99+
return await this.stt.recognize(buffer, language);
100+
}
101+
102+
stream(language: string | undefined = undefined) {
103+
return new StreamAdapterWrapper(this.stt, this.vadStream, language);
104+
}
105+
}

agents/src/stt/stt.ts

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
import { AudioFrame } from '@livekit/rtc-node';
6+
import { AudioBuffer } from '../utils';
7+
8+
export enum SpeechEventType {
9+
START_OF_SPEECH = 0,
10+
INTERIM_TRANSCRIPT = 1,
11+
FINAL_TRANSCRIPT = 2,
12+
END_OF_SPEECH = 3,
13+
}
14+
15+
export interface SpeechData {
16+
language: string;
17+
text: string;
18+
startTime: number;
19+
endTime: number;
20+
confidence: number;
21+
}
22+
23+
export class SpeechEvent {
24+
type: SpeechEventType;
25+
alternatives: SpeechData[];
26+
27+
constructor(type: SpeechEventType, alternatives: SpeechData[] = []) {
28+
this.type = type;
29+
this.alternatives = alternatives;
30+
}
31+
}
32+
33+
export abstract class SpeechStream implements IterableIterator<SpeechEvent> {
34+
abstract pushFrame(token: AudioFrame): void;
35+
36+
abstract close(wait: boolean): Promise<void>;
37+
38+
abstract next(): IteratorResult<SpeechEvent>;
39+
40+
[Symbol.iterator](): SpeechStream {
41+
return this;
42+
}
43+
}
44+
45+
export abstract class STT {
46+
#streamingSupported: boolean;
47+
48+
constructor(streamingSupported: boolean) {
49+
this.#streamingSupported = streamingSupported;
50+
}
51+
52+
abstract recognize(buffer: AudioBuffer, language: string | undefined): Promise<SpeechEvent>;
53+
54+
abstract stream(language: string | undefined): SpeechStream;
55+
56+
get streamingSupported(): boolean {
57+
return this.#streamingSupported;
58+
}
59+
}

agents/src/utils.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
import { AudioFrame } from '@livekit/rtc-node';
6+
7+
export type AudioBuffer = AudioFrame[] | AudioFrame;
8+
9+
export const mergeFrames = (buffer: AudioBuffer): AudioFrame => {
10+
if (Array.isArray(buffer)) {
11+
buffer = buffer as AudioFrame[];
12+
if (buffer.length == 0) {
13+
throw new TypeError('buffer is empty');
14+
}
15+
16+
const sampleRate = buffer[0].sampleRate;
17+
const channels = buffer[0].channels;
18+
let samplesPerChannel = 0;
19+
let data = new Uint16Array();
20+
21+
for (const frame of buffer) {
22+
if (frame.sampleRate !== sampleRate) {
23+
throw new TypeError('sample rate mismatch');
24+
}
25+
26+
if (frame.channels !== channels) {
27+
throw new TypeError('channel count mismatch');
28+
}
29+
30+
data = new Uint16Array([...data, ...frame.data]);
31+
samplesPerChannel += frame.samplesPerChannel;
32+
}
33+
34+
return new AudioFrame(data, sampleRate, channels, samplesPerChannel);
35+
}
36+
37+
return buffer;
38+
};

0 commit comments

Comments
 (0)