11import argparse
22import os
33import torch
4+ import json
45from tensorrt_llm .runtime import ModelRunnerCpp
56from tensorrt_llm .bindings import GptJsonConfig
67import numpy as np
7-
8+ from collections import OrderedDict
9+ from pathlib import Path
810from whisper_utils import log_mel_spectrogram , get_tokenizer
911import evaluate
1012from normalizer import data_utils
# Word-error-rate metric from the HuggingFace `evaluate` library, loaded
# once at module import time and shared by all scoring code in this file.
wer_metric = evaluate.load("wer")
def read_config(component, engine_dir):
    """Load and flatten the TensorRT-LLM engine config for one component.

    Reads ``<engine_dir>/<component>/config.json`` and merges its
    ``pretrained_config`` and ``build_config`` sections into a single
    flat mapping.  On duplicate keys, ``build_config`` values win
    because they are applied last.

    Args:
        component: Engine sub-directory name, e.g. ``'encoder'`` or
            ``'decoder'``.
        engine_dir: Root engine directory (``str`` or ``Path``).

    Returns:
        OrderedDict: The merged configuration.

    Raises:
        FileNotFoundError: If ``config.json`` is missing.
        KeyError: If either expected section is absent from the file.
    """
    config_path = Path(engine_dir) / component / 'config.json'
    # JSON is UTF-8 by specification; be explicit instead of relying on
    # the platform default encoding (which can differ, e.g. on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    model_config = OrderedDict()
    model_config.update(config['pretrained_config'])
    model_config.update(config['build_config'])
    return model_config
30+
1931class WhisperTRTLLM (object ):
2032
2133 def __init__ (self ,
2234 engine_dir ,
2335 assets_dir = "assets" ,
2436 batch_size = 64 ):
25- tokenizer_name = "multilingual"
26- assert (Path (assets_dir ) / "multilingual.tiktoken" ).exists (
27- ), "multilingual.tiktoken file is not existed in assets_dir"
28-
37+ encoder_config = read_config ('encoder' , engine_dir )
38+ decoder_config = read_config ('decoder' , engine_dir )
39+ self .n_mels = encoder_config ['n_mels' ]
40+ self .num_languages = encoder_config ['num_languages' ]
41+ is_multilingual = (decoder_config ['vocab_size' ] >= 51865 )
42+ if is_multilingual :
43+ tokenizer_name = "multilingual"
44+ assert (Path (assets_dir ) / "multilingual.tiktoken" ).exists (
45+ ), "multilingual.tiktoken file is not existed in assets_dir"
46+ else :
47+ tokenizer_name = "gpt2"
48+ assert (Path (assets_dir ) / "gpt2.tiktoken" ).exists (
49+ ), "gpt2.tiktoken file is not existed in assets_dir"
50+ self .text_prefix = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" if is_multilingual else "<|startoftranscript|><|notimestamps|>"
2951 self .tokenizer = get_tokenizer (name = tokenizer_name ,
30- num_languages = 100 ,
52+ num_languages = self . num_languages ,
3153 tokenizer_dir = assets_dir )
3254 self .eot_id = self .tokenizer .encode (
3355 "<|endoftext|>" ,
@@ -43,7 +65,6 @@ def __init__(self,
4365 debug_mode = False ,
4466 kv_cache_free_gpu_memory_fraction = 0.9 )
4567 self .model_runner_cpp = ModelRunnerCpp .from_dir (** runner_kwargs )
46- self .n_mels = 128
4768
4869 def process_single_batch (self , mel_batch , decoder_input_ids , mel_input_lengths , max_new_tokens ):
4970 outputs = self .model_runner_cpp .generate (
@@ -66,9 +87,9 @@ def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths,
6687 texts .append (text )
6788 return texts
6889
69- def process_batch (self , mel , mel_input_lengths , text_prefix = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" , num_threads = 4 , max_new_tokens = 96 ):
90+ def process_batch (self , mel , mel_input_lengths , num_threads = 4 , max_new_tokens = 96 ):
7091 prompt_id = self .tokenizer .encode (
71- text_prefix , allowed_special = self .tokenizer .special_tokens_set )
92+ self . text_prefix , allowed_special = self .tokenizer .special_tokens_set )
7293 prompt_id = torch .tensor (prompt_id )
7394 batch_size = len (mel )
7495 decoder_input_ids = prompt_id .repeat (batch_size , 1 )
0 commit comments