Skip to content

Commit 067b21c

Browse files
yuekaizhang and Yuekai Zhang authored
fix whisper .en model (#46)
* fix whisper * add all models --------- Co-authored-by: Yuekai Zhang <[email protected]>
1 parent a39bd25 commit 067b21c

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

tensorrtllm/run_eval.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import argparse
22
import os
33
import torch
4+
import json
45
from tensorrt_llm.runtime import ModelRunnerCpp
56
from tensorrt_llm.bindings import GptJsonConfig
67
import numpy as np
7-
8+
from collections import OrderedDict
9+
from pathlib import Path
810
from whisper_utils import log_mel_spectrogram, get_tokenizer
911
import evaluate
1012
from normalizer import data_utils
@@ -16,18 +18,38 @@
1618

1719
wer_metric = evaluate.load("wer")
1820

21+
def read_config(component, engine_dir):
    """Load one engine component's config.json as a single flat mapping.

    Reads ``<engine_dir>/<component>/config.json`` and merges its
    ``pretrained_config`` and ``build_config`` sections.

    Args:
        component: Name of the sub-directory inside ``engine_dir`` that
            holds the config file (e.g. ``'encoder'`` or ``'decoder'``).
        engine_dir: Path (``str`` or ``Path``) to the engine directory.

    Returns:
        An ``OrderedDict`` with the ``pretrained_config`` entries first,
        followed by the ``build_config`` entries. When a key appears in
        both sections, the ``build_config`` value wins.

    Raises:
        FileNotFoundError: If the config file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If either expected section is missing.
    """
    config_path = Path(engine_dir) / component / 'config.json'
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    model_config = OrderedDict()
    model_config.update(config['pretrained_config'])
    # Applied second on purpose: build-time values override pretrained ones.
    model_config.update(config['build_config'])
    return model_config
30+
1931
class WhisperTRTLLM(object):
2032

2133
def __init__(self,
2234
engine_dir,
2335
assets_dir="assets",
2436
batch_size=64):
25-
tokenizer_name = "multilingual"
26-
assert (Path(assets_dir) / "multilingual.tiktoken").exists(
27-
), "multilingual.tiktoken file is not existed in assets_dir"
28-
37+
encoder_config = read_config('encoder', engine_dir)
38+
decoder_config = read_config('decoder', engine_dir)
39+
self.n_mels = encoder_config['n_mels']
40+
self.num_languages = encoder_config['num_languages']
41+
is_multilingual = (decoder_config['vocab_size'] >= 51865)
42+
if is_multilingual:
43+
tokenizer_name = "multilingual"
44+
assert (Path(assets_dir) / "multilingual.tiktoken").exists(
45+
), "multilingual.tiktoken file is not existed in assets_dir"
46+
else:
47+
tokenizer_name = "gpt2"
48+
assert (Path(assets_dir) / "gpt2.tiktoken").exists(
49+
), "gpt2.tiktoken file is not existed in assets_dir"
50+
self.text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" if is_multilingual else "<|startoftranscript|><|notimestamps|>"
2951
self.tokenizer = get_tokenizer(name=tokenizer_name,
30-
num_languages=100,
52+
num_languages=self.num_languages,
3153
tokenizer_dir=assets_dir)
3254
self.eot_id = self.tokenizer.encode(
3355
"<|endoftext|>",
@@ -43,7 +65,6 @@ def __init__(self,
4365
debug_mode=False,
4466
kv_cache_free_gpu_memory_fraction=0.9)
4567
self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
46-
self.n_mels = 128
4768

4869
def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths, max_new_tokens):
4970
outputs = self.model_runner_cpp.generate(
@@ -66,9 +87,9 @@ def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths,
6687
texts.append(text)
6788
return texts
6889

69-
def process_batch(self, mel, mel_input_lengths, text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", num_threads=4, max_new_tokens=96):
90+
def process_batch(self, mel, mel_input_lengths, num_threads=4, max_new_tokens=96):
7091
prompt_id = self.tokenizer.encode(
71-
text_prefix, allowed_special=self.tokenizer.special_tokens_set)
92+
self.text_prefix, allowed_special=self.tokenizer.special_tokens_set)
7293
prompt_id = torch.tensor(prompt_id)
7394
batch_size = len(mel)
7495
decoder_input_ids = prompt_id.repeat(batch_size, 1)

tensorrtllm/run_whisper.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ download_model() {
1111
wget -nc --directory-prefix=assets "$URL"
1212
wget -nc --directory-prefix=assets assets/mel_filters.npz https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz
1313
wget -nc --directory-prefix=assets https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken
14+
wget -nc --directory-prefix=assets https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken
1415

1516
}
1617

@@ -48,7 +49,7 @@ build_model() {
4849
--gpt_attention_plugin "$INFERENCE_PRECISION"
4950
}
5051

51-
MODEL_IDs=("large-v3-turbo" "large-v3")
52+
MODEL_IDs=("large-v3-turbo" "large-v3" "large-v2" "large-v1" "medium" "base" "small" "tiny" "medium.en" "base.en" "small.en" "tiny.en")
5253
DEVICE_INDEX=0
5354
BATCH_SIZE=64
5455

0 commit comments

Comments
 (0)