|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import torch |
| 4 | +from tensorrt_llm.runtime import ModelRunnerCpp |
| 5 | +from tensorrt_llm.bindings import GptJsonConfig |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +from whisper_utils import log_mel_spectrogram, get_tokenizer |
| 9 | +import evaluate |
| 10 | +from normalizer import data_utils |
| 11 | +import time |
| 12 | +from tqdm import tqdm |
| 13 | +from pathlib import Path |
| 14 | +import re |
| 15 | +from concurrent.futures import ThreadPoolExecutor |
| 16 | + |
| 17 | +wer_metric = evaluate.load("wer") |
| 18 | + |
| 19 | +class WhisperTRTLLM(object): |
| 20 | + |
| 21 | + def __init__(self, |
| 22 | + engine_dir, |
| 23 | + assets_dir="assets", |
| 24 | + batch_size=64): |
| 25 | + tokenizer_name = "multilingual" |
| 26 | + assert (Path(assets_dir) / "multilingual.tiktoken").exists( |
| 27 | + ), "multilingual.tiktoken file is not existed in assets_dir" |
| 28 | + |
| 29 | + self.tokenizer = get_tokenizer(name=tokenizer_name, |
| 30 | + num_languages=100, |
| 31 | + tokenizer_dir=assets_dir) |
| 32 | + self.eot_id = self.tokenizer.encode( |
| 33 | + "<|endoftext|>", |
| 34 | + allowed_special=self.tokenizer.special_tokens_set)[0] |
| 35 | + json_config = GptJsonConfig.parse_file(Path(engine_dir) / 'decoder' / 'config.json') |
| 36 | + assert json_config.model_config.supports_inflight_batching |
| 37 | + runner_kwargs = dict(engine_dir=engine_dir, |
| 38 | + is_enc_dec=True, |
| 39 | + max_batch_size=batch_size, |
| 40 | + max_input_len=3000, |
| 41 | + max_output_len=96, |
| 42 | + max_beam_width=1, |
| 43 | + debug_mode=False, |
| 44 | + kv_cache_free_gpu_memory_fraction=0.9) |
| 45 | + self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs) |
| 46 | + self.n_mels = 128 |
| 47 | + |
| 48 | + def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths, max_new_tokens): |
| 49 | + outputs = self.model_runner_cpp.generate( |
| 50 | + batch_input_ids=decoder_input_ids, |
| 51 | + encoder_input_features=mel_batch, |
| 52 | + encoder_output_lengths=mel_input_lengths // 2, |
| 53 | + max_new_tokens=max_new_tokens, |
| 54 | + end_id=self.eot_id, |
| 55 | + pad_id=self.eot_id, |
| 56 | + num_beams=1, |
| 57 | + output_sequence_lengths=True, |
| 58 | + return_dict=True |
| 59 | + ) |
| 60 | + |
| 61 | + output_ids = outputs['output_ids'].cpu().numpy().tolist() |
| 62 | + texts = [] |
| 63 | + for i in range(len(output_ids)): |
| 64 | + text = self.tokenizer.decode(output_ids[i][0]).strip() |
| 65 | + text = re.sub(r'<\|.*?\|>', '', text) |
| 66 | + texts.append(text) |
| 67 | + return texts |
| 68 | + |
| 69 | + def process_batch(self, mel, mel_input_lengths, text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", num_threads=4, max_new_tokens=96): |
| 70 | + prompt_id = self.tokenizer.encode( |
| 71 | + text_prefix, allowed_special=self.tokenizer.special_tokens_set) |
| 72 | + prompt_id = torch.tensor(prompt_id) |
| 73 | + batch_size = len(mel) |
| 74 | + decoder_input_ids = prompt_id.repeat(batch_size, 1) |
| 75 | + |
| 76 | + with torch.no_grad(): |
| 77 | + if isinstance(mel, list): |
| 78 | + mel = torch.stack([m.transpose(1, 2).type(torch.float16).squeeze(0) for m in mel]) |
| 79 | + else: |
| 80 | + mel = mel.transpose(1, 2) |
| 81 | + |
| 82 | + num_threads = min(num_threads, batch_size) |
| 83 | + mel_batches = torch.split(mel, batch_size // num_threads) |
| 84 | + mel_input_lengths_batches = torch.split(mel_input_lengths, batch_size // num_threads) |
| 85 | + |
| 86 | + texts_list = [] |
| 87 | + with ThreadPoolExecutor(max_workers=num_threads) as executor: |
| 88 | + futures = [] |
| 89 | + for i, mel_batch in enumerate(mel_batches): |
| 90 | + current_length = mel_batch.size(0) |
| 91 | + futures.append(executor.submit( |
| 92 | + self.process_single_batch, |
| 93 | + mel_batch, |
| 94 | + decoder_input_ids[:current_length], |
| 95 | + mel_input_lengths_batches[i], |
| 96 | + max_new_tokens |
| 97 | + )) |
| 98 | + |
| 99 | + for future in futures: |
| 100 | + texts_list.extend(future.result()) |
| 101 | + |
| 102 | + return texts_list |
| 103 | + |
| 104 | +def longest_common_substring(s1, s2): |
| 105 | + len1, len2 = len(s1), len(s2) |
| 106 | + dp = [[0] * (len2 + 1) for _ in range(len1 + 1)] |
| 107 | + |
| 108 | + longest_length = 0 |
| 109 | + end_index_s1 = 0 |
| 110 | + |
| 111 | + for i in range(1, len1 + 1): |
| 112 | + for j in range(1, len2 + 1): |
| 113 | + if s1[i - 1] == s2[j - 1]: |
| 114 | + dp[i][j] = dp[i - 1][j - 1] + 1 |
| 115 | + if dp[i][j] > longest_length: |
| 116 | + longest_length = dp[i][j] |
| 117 | + end_index_s1 = i |
| 118 | + else: |
| 119 | + dp[i][j] = 0 |
| 120 | + |
| 121 | + return s1[end_index_s1 - longest_length:end_index_s1] |
| 122 | + |
| 123 | +def chunk_audio(audio, chunk_length, overlap_length, sample_rate): |
| 124 | + chunk_size = int(chunk_length * sample_rate) |
| 125 | + overlap_size = int(overlap_length * sample_rate) |
| 126 | + |
| 127 | + chunks = [] |
| 128 | + start = 0 |
| 129 | + |
| 130 | + while start < len(audio): |
| 131 | + end = min(start + chunk_size, len(audio)) |
| 132 | + chunks.append(audio[start:end]) |
| 133 | + start += chunk_size - overlap_size |
| 134 | + |
| 135 | + return chunks |
| 136 | + |
| 137 | +def main(args): |
| 138 | + asr_model = WhisperTRTLLM(engine_dir=args.model_id) |
| 139 | + |
| 140 | + def benchmark(batch, min_new_tokens=None): |
| 141 | + # Load audio inputs |
| 142 | + max_duration, sample_rate = 30, 16000 |
| 143 | + audios_origin = [audio["array"].astype(np.float32) for audio in batch["audio"]] |
| 144 | + minibatch_size = len(audios_origin) |
| 145 | + audios, audio_index = [], [] |
| 146 | + |
| 147 | + chunk_length = 25 |
| 148 | + overlap_length = 5 |
| 149 | + for i, audio in enumerate(audios_origin): |
| 150 | + if len(audio) > max_duration * sample_rate: |
| 151 | + audio_chunks = chunk_audio(audio, chunk_length, overlap_length, sample_rate) |
| 152 | + for chunk in audio_chunks: |
| 153 | + audios.append(chunk) |
| 154 | + audio_index.append(i) |
| 155 | + else: |
| 156 | + audios.append(audio) |
| 157 | + audio_index.append(i) |
| 158 | + audios = [torch.from_numpy(audio) for audio in audios] |
| 159 | + |
| 160 | + # START TIMING |
| 161 | + start_time = time.time() |
| 162 | + longest_duration = int(sample_rate * max_duration) |
| 163 | + |
| 164 | + features = [ |
| 165 | + log_mel_spectrogram(wave, |
| 166 | + asr_model.n_mels, |
| 167 | + padding=longest_duration - wave.shape[-1], |
| 168 | + device='cuda').unsqueeze(0) |
| 169 | + for wave in audios |
| 170 | + ] |
| 171 | + |
| 172 | + features_input_lengths = torch.tensor([f.shape[2] for f in features], |
| 173 | + dtype=torch.int32, |
| 174 | + device='cuda') |
| 175 | + |
| 176 | + texts_origin = asr_model.process_batch(features, features_input_lengths, num_threads=4) |
| 177 | + |
| 178 | + texts = [] |
| 179 | + for i in range(minibatch_size): |
| 180 | + text_chunks = [] |
| 181 | + for j in range(len(texts_origin)): |
| 182 | + if audio_index[j] == i: |
| 183 | + text_chunks.append(texts_origin[j]) |
| 184 | + |
| 185 | + if len(text_chunks) > 1: |
| 186 | + merged_text = text_chunks[0] |
| 187 | + for t in text_chunks[1:]: |
| 188 | + lcs = longest_common_substring(merged_text, t) |
| 189 | + merged_text += t[len(lcs):] |
| 190 | + |
| 191 | + texts.append(merged_text) |
| 192 | + else: |
| 193 | + texts.append(text_chunks[0]) |
| 194 | + # END TIMING |
| 195 | + runtime = time.time() - start_time |
| 196 | + |
| 197 | + print(f"Batch size: {minibatch_size}, Time taken: {runtime:.2f} s, texts_origin_len: {len(texts_origin)}, texts_len: {len(texts)}") |
| 198 | + # normalize by minibatch size since we want the per-sample time |
| 199 | + batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size] |
| 200 | + |
| 201 | + # normalize transcriptions with English normalizer |
| 202 | + batch["predictions"] = [data_utils.normalizer(pred) for pred in texts] |
| 203 | + batch["references"] = batch["norm_text"] |
| 204 | + return batch |
| 205 | + |
| 206 | + if args.warmup_steps is not None: |
| 207 | + dataset = data_utils.load_data(args) |
| 208 | + dataset = data_utils.prepare_data(dataset) |
| 209 | + |
| 210 | + num_warmup_samples = args.warmup_steps * args.batch_size |
| 211 | + if args.streaming: |
| 212 | + warmup_dataset = dataset.take(num_warmup_samples) |
| 213 | + else: |
| 214 | + warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset)))) |
| 215 | + warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens})) |
| 216 | + |
| 217 | + for _ in tqdm(warmup_dataset, desc="Warming up..."): |
| 218 | + continue |
| 219 | + |
| 220 | + dataset = data_utils.load_data(args) |
| 221 | + if args.max_eval_samples is not None and args.max_eval_samples > 0: |
| 222 | + print(f"Subsampling dataset to first {args.max_eval_samples} samples!") |
| 223 | + if args.streaming: |
| 224 | + dataset = dataset.take(args.max_eval_samples) |
| 225 | + else: |
| 226 | + dataset = dataset.select(range(min(args.max_eval_samples, len(dataset)))) |
| 227 | + dataset = data_utils.prepare_data(dataset) |
| 228 | + |
| 229 | + dataset = dataset.map( |
| 230 | + benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"], |
| 231 | + ) |
| 232 | + |
| 233 | + all_results = { |
| 234 | + "audio_length_s": [], |
| 235 | + "transcription_time_s": [], |
| 236 | + "predictions": [], |
| 237 | + "references": [], |
| 238 | + } |
| 239 | + result_iter = iter(dataset) |
| 240 | + for result in tqdm(result_iter, desc="Samples..."): |
| 241 | + for key in all_results: |
| 242 | + all_results[key].append(result[key]) |
| 243 | + |
| 244 | + # Write manifest results (WER and RTFX) |
| 245 | + manifest_path = data_utils.write_manifest( |
| 246 | + all_results["references"], |
| 247 | + all_results["predictions"], |
| 248 | + args.model_id, |
| 249 | + args.dataset_path, |
| 250 | + args.dataset, |
| 251 | + args.split, |
| 252 | + audio_length=all_results["audio_length_s"], |
| 253 | + transcription_time=all_results["transcription_time_s"], |
| 254 | + ) |
| 255 | + print("Results saved at path:", os.path.abspath(manifest_path)) |
| 256 | + |
| 257 | + wer = wer_metric.compute( |
| 258 | + references=all_results["references"], predictions=all_results["predictions"] |
| 259 | + ) |
| 260 | + wer = round(100 * wer, 2) |
| 261 | + rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2) |
| 262 | + print("WER:", wer, "%", "RTFx:", rtfx) |
| 263 | + |
| 264 | + |
| 265 | +if __name__ == "__main__": |
| 266 | + parser = argparse.ArgumentParser() |
| 267 | + |
| 268 | + parser.add_argument( |
| 269 | + "--model_id", |
| 270 | + type=str, |
| 271 | + required=True, |
| 272 | + help="Model identifier. Should be loadable with 🤗 Transformers", |
| 273 | + ) |
| 274 | + parser.add_argument( |
| 275 | + "--dataset_path", |
| 276 | + type=str, |
| 277 | + default="esb/datasets", |
| 278 | + help="Dataset path. By default, it is `esb/datasets`", |
| 279 | + ) |
| 280 | + parser.add_argument( |
| 281 | + "--dataset", |
| 282 | + type=str, |
| 283 | + required=True, |
| 284 | + help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names " |
| 285 | + "can be found at `https://huggingface.co/datasets/esb/datasets`", |
| 286 | + ) |
| 287 | + parser.add_argument( |
| 288 | + "--split", |
| 289 | + type=str, |
| 290 | + default="test", |
| 291 | + help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.", |
| 292 | + ) |
| 293 | + parser.add_argument( |
| 294 | + "--device", |
| 295 | + type=int, |
| 296 | + default=-1, |
| 297 | + help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.", |
| 298 | + ) |
| 299 | + parser.add_argument( |
| 300 | + "--batch_size", |
| 301 | + type=int, |
| 302 | + default=16, |
| 303 | + help="Number of samples to go through each streamed batch.", |
| 304 | + ) |
| 305 | + parser.add_argument( |
| 306 | + "--max_eval_samples", |
| 307 | + type=int, |
| 308 | + default=None, |
| 309 | + help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.", |
| 310 | + ) |
| 311 | + parser.add_argument( |
| 312 | + "--no-streaming", |
| 313 | + dest="streaming", |
| 314 | + action="store_false", |
| 315 | + help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.", |
| 316 | + ) |
| 317 | + parser.add_argument( |
| 318 | + "--max_new_tokens", |
| 319 | + type=int, |
| 320 | + default=None, |
| 321 | + help="Maximum number of tokens to generate (for auto-regressive models).", |
| 322 | + ) |
| 323 | + parser.add_argument( |
| 324 | + "--warmup_steps", |
| 325 | + type=int, |
| 326 | + default=10, |
| 327 | + help="Number of warm-up steps to run before launching the timed runs.", |
| 328 | + ) |
| 329 | + args = parser.parse_args() |
| 330 | + parser.set_defaults(streaming=False) |
| 331 | + |
| 332 | + main(args) |
0 commit comments