
Commit a39bd25

yuekaizhang and Yuekai Zhang authored
[Ready] Add whisper TensorRT-LLM (#42)
* add whisper trt-llm
* add vad module
* remove vad
* remove vad files
* remove convert_checkpoint
* code clean

---------

Co-authored-by: Yuekai Zhang <[email protected]>
1 parent 9bdea31 commit a39bd25

4 files changed (+734, −0)
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
tiktoken
jiwer
tensorrt-llm==0.15.0.dev2024101500
soundfile
librosa
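
These pinned requirements can be installed with pip. Dev builds of tensorrt-llm are distributed through NVIDIA's package index, so an install command along the following lines is likely needed (the file path and the extra index URL here are assumptions, not shown in this commit):

pip install -r requirements.txt --extra-index-url https://pypi.nvidia.com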

tensorrtllm/run_eval.py

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
import argparse
import os
import torch
from tensorrt_llm.runtime import ModelRunnerCpp
from tensorrt_llm.bindings import GptJsonConfig
import numpy as np

from whisper_utils import log_mel_spectrogram, get_tokenizer
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm
from pathlib import Path
import re
from concurrent.futures import ThreadPoolExecutor

wer_metric = evaluate.load("wer")

class WhisperTRTLLM(object):

    def __init__(self,
                 engine_dir,
                 assets_dir="assets",
                 batch_size=64):
        tokenizer_name = "multilingual"
        assert (Path(assets_dir) / "multilingual.tiktoken").exists(
        ), "multilingual.tiktoken file does not exist in assets_dir"

        self.tokenizer = get_tokenizer(name=tokenizer_name,
                                       num_languages=100,
                                       tokenizer_dir=assets_dir)
        self.eot_id = self.tokenizer.encode(
            "<|endoftext|>",
            allowed_special=self.tokenizer.special_tokens_set)[0]
        json_config = GptJsonConfig.parse_file(Path(engine_dir) / 'decoder' / 'config.json')
        assert json_config.model_config.supports_inflight_batching
        runner_kwargs = dict(engine_dir=engine_dir,
                             is_enc_dec=True,
                             max_batch_size=batch_size,
                             max_input_len=3000,
                             max_output_len=96,
                             max_beam_width=1,
                             debug_mode=False,
                             kv_cache_free_gpu_memory_fraction=0.9)
        self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
        self.n_mels = 128

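    # Runs one generate() call on the C++ runner: the decoder is primed with the
    # prompt ids and cross-attends to the encoder mel features. encoder_output_lengths
    # is mel_input_lengths // 2 because the Whisper encoder downsamples the mel
    # sequence by a factor of two in its convolutional stem. Special tokens such as
    # <|notimestamps|> are stripped from the decoded text before returning.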
    def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths, max_new_tokens):
        outputs = self.model_runner_cpp.generate(
            batch_input_ids=decoder_input_ids,
            encoder_input_features=mel_batch,
            encoder_output_lengths=mel_input_lengths // 2,
            max_new_tokens=max_new_tokens,
            end_id=self.eot_id,
            pad_id=self.eot_id,
            num_beams=1,
            output_sequence_lengths=True,
            return_dict=True
        )

        output_ids = outputs['output_ids'].cpu().numpy().tolist()
        texts = []
        for i in range(len(output_ids)):
            text = self.tokenizer.decode(output_ids[i][0]).strip()
            text = re.sub(r'<\|.*?\|>', '', text)
            texts.append(text)
        return texts

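    # Transcribes a full mel batch: the batch is split into num_threads slices and
    # each slice is submitted to the runner from its own thread, so the in-flight
    # batching asserted in __init__ can overlap the requests.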
    def process_batch(self, mel, mel_input_lengths, text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", num_threads=4, max_new_tokens=96):
        prompt_id = self.tokenizer.encode(
            text_prefix, allowed_special=self.tokenizer.special_tokens_set)
        prompt_id = torch.tensor(prompt_id)
        batch_size = len(mel)
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        with torch.no_grad():
            if isinstance(mel, list):
                mel = torch.stack([m.transpose(1, 2).type(torch.float16).squeeze(0) for m in mel])
            else:
                mel = mel.transpose(1, 2)

            num_threads = min(num_threads, batch_size)
            mel_batches = torch.split(mel, batch_size // num_threads)
            mel_input_lengths_batches = torch.split(mel_input_lengths, batch_size // num_threads)

            texts_list = []
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                futures = []
                for i, mel_batch in enumerate(mel_batches):
                    current_length = mel_batch.size(0)
                    futures.append(executor.submit(
                        self.process_single_batch,
                        mel_batch,
                        decoder_input_ids[:current_length],
                        mel_input_lengths_batches[i],
                        max_new_tokens
                    ))

                for future in futures:
                    texts_list.extend(future.result())

        return texts_list

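# Classic dynamic-programming longest common substring. benchmark() below uses it to
# stitch overlapping chunk transcripts back together: for example, merging
# "the quick brown fox" with "brown fox jumps" finds the common substring
# "brown fox", so the merged text becomes "the quick brown fox jumps".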
def longest_common_substring(s1, s2):
    len1, len2 = len(s1), len(s2)
    dp = [[0] * (len2 + 1) for _ in range(len1 + 1)]

    longest_length = 0
    end_index_s1 = 0

    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            if s1[i - 1] == s2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
                if dp[i][j] > longest_length:
                    longest_length = dp[i][j]
                    end_index_s1 = i
            else:
                dp[i][j] = 0

    return s1[end_index_s1 - longest_length:end_index_s1]

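# Splits audio longer than the model's 30 s window into overlapping chunks. With the
# values used in benchmark() (chunk_length=25, overlap_length=5, 16 kHz audio), a 60 s
# clip becomes three chunks covering roughly 0-25 s, 20-45 s and 40-60 s.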
def chunk_audio(audio, chunk_length, overlap_length, sample_rate):
    chunk_size = int(chunk_length * sample_rate)
    overlap_size = int(overlap_length * sample_rate)

    chunks = []
    start = 0

    while start < len(audio):
        end = min(start + chunk_size, len(audio))
        chunks.append(audio[start:end])
        start += chunk_size - overlap_size

    return chunks

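# main() builds the TRT-LLM model, runs an optional warm-up pass, maps benchmark()
# over the evaluation dataset, and finally reports WER (via the evaluate library)
# and RTFx, i.e. total audio seconds divided by total transcription seconds.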
def main(args):
    asr_model = WhisperTRTLLM(engine_dir=args.model_id)

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        max_duration, sample_rate = 30, 16000
        audios_origin = [audio["array"].astype(np.float32) for audio in batch["audio"]]
        minibatch_size = len(audios_origin)
        audios, audio_index = [], []

        chunk_length = 25
        overlap_length = 5
        for i, audio in enumerate(audios_origin):
            if len(audio) > max_duration * sample_rate:
                audio_chunks = chunk_audio(audio, chunk_length, overlap_length, sample_rate)
                for chunk in audio_chunks:
                    audios.append(chunk)
                    audio_index.append(i)
            else:
                audios.append(audio)
                audio_index.append(i)
        audios = [torch.from_numpy(audio) for audio in audios]

        # START TIMING
        start_time = time.time()
        longest_duration = int(sample_rate * max_duration)

        features = [
            log_mel_spectrogram(wave,
                                asr_model.n_mels,
                                padding=longest_duration - wave.shape[-1],
                                device='cuda').unsqueeze(0)
            for wave in audios
        ]

        features_input_lengths = torch.tensor([f.shape[2] for f in features],
                                              dtype=torch.int32,
                                              device='cuda')

        texts_origin = asr_model.process_batch(features, features_input_lengths, num_threads=4)

        texts = []
        for i in range(minibatch_size):
            text_chunks = []
            for j in range(len(texts_origin)):
                if audio_index[j] == i:
                    text_chunks.append(texts_origin[j])

            if len(text_chunks) > 1:
                merged_text = text_chunks[0]
                for t in text_chunks[1:]:
                    lcs = longest_common_substring(merged_text, t)
                    merged_text += t[len(lcs):]

                texts.append(merged_text)
            else:
                texts.append(text_chunks[0])
        # END TIMING
        runtime = time.time() - start_time

        print(f"Batch size: {minibatch_size}, Time taken: {runtime:.2f} s, texts_origin_len: {len(texts_origin)}, texts_len: {len(texts)}")
        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in texts]
        batch["references"] = batch["norm_text"]
        return batch

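    # Warm-up: run a few batches through the engine first so one-time initialization
    # does not count against the timed evaluation runs.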
    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


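# Example invocation (illustrative only; the engine path and dataset values below are
# hypothetical, not taken from this commit):
#   python run_eval.py --model_id /path/to/whisper_trt_engines \
#       --dataset librispeech_asr --split test --batch_size 64 --no-streaming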
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Path to the TensorRT-LLM whisper engine directory; also used as the model identifier in the results manifest.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    # Set the streaming default before parsing so it actually takes effect.
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
