
Commit e55f19f

Add OpenAI Whisper support (#125)
Co-Authored-By: MahmoudAshraf97 <[email protected]>
1 parent 92f430f commit e55f19f

11 files changed: +319 −96 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ Model Optimizer Changelog (Linux)
 
 **New Features**
 
+- New model support in the ``llm_ptq`` example: OpenAI Whisper.
 - Blockwise FP8 quantization support in unified model export.
 - Add quantization support to the Transformer Engine Linear module.
 - Add support for SVDQuant. Currently, only simulation is available; real deployment (for example, TensorRT deployment) support is coming soon.

examples/llm_ptq/README.md

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ InternLM2 | Yes | No | Yes | Yes<sup>3</sup> | -
 Exaone | Yes | Yes | Yes | Yes | -
 Minitron | Yes | Yes | Yes | Yes<sup>2</sup> | Yes
 T5 | Yes | Yes | Yes | Yes | -
+Whisper | Yes | No | No | No | -
 
 > *<sup>1.</sup>The w4a8_awq is an experimental quantization scheme that may result in a higher accuracy penalty.*

examples/llm_ptq/example_utils.py

Lines changed: 33 additions & 52 deletions
@@ -21,47 +21,6 @@
 
 from modelopt.torch.utils.image_processor import MllamaImageProcessor
 
-MODEL_NAME_PATTERN_MAP = {
-    "GPT2": "gpt",
-    "Mllama": "mllama",
-    "Llama": "llama",
-    "Mistral": "llama",
-    "GPTJ": "gptj",
-    "FalconForCausalLM": "falcon",
-    "RWForCausalLM": "falcon",
-    "baichuan": "baichuan",
-    "MPT": "mpt",
-    "Bloom": "bloom",
-    "ChatGLM": "chatglm",
-    "QWen": "qwen",
-    "RecurrentGemma": "recurrentgemma",
-    "Gemma2": "gemma2",
-    "Gemma": "gemma",
-    "phi3small": "phi3small",
-    "phi3": "phi3",
-    "PhiMoEForCausalLM": "phi3",
-    "phi": "phi",
-    "TLGv4ForCausalLM": "phi",
-    "MixtralForCausalLM": "llama",
-    "ArcticForCausalLM": "llama",
-    "StarCoder": "gpt",
-    "Dbrx": "dbrx",
-    "T5": "t5",
-    "Bart": "bart",
-    "GLM": "glm",
-    "InternLM2ForCausalLM": "internlm",
-    "ExaoneForCausalLM": "exaone",
-    "Nemotron": "gpt",
-    "Deepseek": "deepseek",
-}
-
-
-def get_model_type(model):
-    for k, v in MODEL_NAME_PATTERN_MAP.items():
-        if k.lower() in type(model).__name__.lower():
-            return v
-    return None
-
 
 def get_mode_type_from_engine_dir(engine_dir_str):
     # Split the path by '/' and get the last part
@@ -106,20 +65,36 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs):
     return tokenizer
 
 
-def get_processor(ckpt_path, device=None, trust_remote_code=False):
+def get_processor(ckpt_path, model_type, device=None, trust_remote_code=False):
     """
     Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object.
     """
-    processor = AutoProcessor.from_pretrained(
-        ckpt_path,
-        padding_side="left",
-        trust_remote_code=trust_remote_code,
-    )
-    if processor.tokenizer.pad_token is None:
-        processor.tokenizer.pad_token = processor.tokenizer.eos_token
-    assert processor.tokenizer.pad_token is not None, f"Pad token for {ckpt_path} cannot be set!"
+    if model_type == "whisper":
+        processor = AutoProcessor.from_pretrained(
+            ckpt_path,
+            padding_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        if processor.tokenizer.pad_token is None:
+            processor.tokenizer.pad_token = processor.tokenizer.eos_token
+        assert processor.tokenizer.pad_token is not None, (
+            f"Pad token for {ckpt_path} cannot be set!"
+        )
 
-    return MllamaImageProcessor(processor, device)
+        return processor
+    elif model_type == "mllama":
+        processor = AutoProcessor.from_pretrained(
+            ckpt_path,
+            padding_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        if processor.tokenizer.pad_token is None:
+            processor.tokenizer.pad_token = processor.tokenizer.eos_token
+        assert processor.tokenizer.pad_token is not None, (
+            f"Pad token for {ckpt_path} cannot be set!"
+        )
+
+        return MllamaImageProcessor(processor, device)
 
 
 def get_dtype(dtype):
@@ -179,6 +154,12 @@ def get_model(ckpt_path, device="cuda", gpu_mem_percentage=0.8, trust_remote_cod
             model = AutoModelForSeq2SeqLM.from_pretrained(
                 ckpt_path, device_map=None, **model_kwargs
             ).to(device)
+        elif hf_config.model_type == "whisper":
+            from transformers import WhisperForConditionalGeneration
+
+            model = WhisperForConditionalGeneration.from_pretrained(
+                ckpt_path, device_map=device_map, **model_kwargs
+            )
         elif hf_config.model_type == "glm":
             from transformers import AutoModelForSeq2SeqLM
 
@@ -246,4 +227,4 @@ def is_model_on_gpu(model) -> bool:
 
 def is_enc_dec(model_type) -> bool:
     """Return if the model is a encoder-decoder model."""
-    return model_type in ["t5", "bart"]
+    return model_type in ["t5", "bart", "whisper"]
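The helpers above tie together as follows for a Whisper checkpoint. A minimal usage sketch, assuming it is run from `examples/llm_ptq` with a CUDA device available and using the public `openai/whisper-tiny` checkpoint as a stand-in:

```python
# Illustrative only: get_processor now takes model_type and, for Whisper, returns the plain
# AutoProcessor (a WhisperProcessor) instead of wrapping it in MllamaImageProcessor; get_model
# dispatches to WhisperForConditionalGeneration; is_enc_dec routes Whisper through the
# encoder-decoder export path.
from example_utils import get_model, get_processor, is_enc_dec

ckpt = "openai/whisper-tiny"  # any Hugging Face Whisper checkpoint
model = get_model(ckpt, device="cuda")
processor = get_processor(ckpt, "whisper", device="cuda")

assert is_enc_dec("whisper")
print(type(model).__name__, type(processor).__name__)
# Expected: WhisperForConditionalGeneration WhisperProcessor
```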

examples/llm_ptq/hf_ptq.py

Lines changed: 66 additions & 15 deletions
@@ -22,7 +22,7 @@
 import numpy as np
 import torch
 from example_utils import get_model, get_processor, get_tokenizer, is_enc_dec, is_model_on_gpu
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, WhisperProcessor
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
@@ -39,6 +39,7 @@
 )
 from modelopt.torch.utils.image_processor import MllamaImageProcessor
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
+from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
 from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
 
 RAND_SEED = 1234
@@ -210,7 +211,19 @@ def main(args):
         elif args.dataset != "scienceqa":
             raise ValueError("Only the scienceqa dataset is supported for the mllama model.")
         processor = get_processor(
-            args.pyt_ckpt_path, device, trust_remote_code=args.trust_remote_code
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
+        )
+    elif model_type == "whisper":
+        if args.dataset is None:
+            args.dataset = "peoples_speech"
+            warnings.warn(
+                "Currently only the peoples_speech dataset is supported for the whisper model. "
+                "Overriding dataset to peoples_speech."
+            )
+        elif args.dataset != "peoples_speech":
+            raise ValueError("Only the peoples_speech dataset is supported for the whisper model.")
+        processor = get_processor(
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
         )
     else:
         if args.dataset is None:
@@ -273,8 +286,25 @@ def main(args):
         # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
         # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
         sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
+        # Whisper model expects mel-spectrogram input features of length 3000
+        # Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
+        # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
+        # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
+        if model_type == "whisper":
+            max_sample_length = 3000
+            num_mel_bins = model.config.num_mel_bins
+            sample_input_single_batch = (
+                torch.ones([1, num_mel_bins, max_sample_length], dtype=torch.float32).to(
+                    model.device
+                )
+                * 100
+            )
+        else:
+            sample_input_single_batch = None
         args.batch_size = get_max_batch_size(
-            model, sample_memory_usage_ratio=sample_memory_usage_ratio
+            model,
+            sample_memory_usage_ratio=sample_memory_usage_ratio,
+            sample_input_single_batch=sample_input_single_batch,
         )
         if args.batch_size > args.calib_size:
             args.batch_size = args.calib_size
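For context on the (batch_size, num_mel_bins, 3000) sample input above: real Whisper inputs of exactly that shape come out of the processor's feature extractor, which pads or truncates audio to 30 seconds of log-mel frames. A minimal sketch using the public `openai/whisper-tiny` checkpoint, with silence as a stand-in for real speech:

```python
import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
audio = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of 16 kHz audio, stand-in for real speech

# The feature extractor always pads/truncates to 3000 mel frames and returns float features,
# matching the dummy torch.ones([1, num_mel_bins, 3000]) batch used for batch-size probing above.
features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
print(features.shape, features.dtype)  # torch.Size([1, 80, 3000]) torch.float32
```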
@@ -292,6 +322,17 @@ def main(args):
                 batch_size=args.batch_size,
                 num_samples=args.calib_size,
             )
+        elif model_type == "whisper":
+            assert processor is not None and isinstance(processor, WhisperProcessor), (
+                "The AutoProcessor must be set."
+            )
+            calib_dataloader, first_text = get_speech_dataset_dataloader(
+                dataset_name=args.dataset,
+                processor=processor,
+                batch_size=args.batch_size,
+                num_samples=args.calib_size,
+                device=device,
+            )
         else:
             assert tokenizer is not None and isinstance(
                 tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
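The speech dataloader returned by `get_speech_dataset_dataloader` is consumed like the text dataloaders: each batch carries float `input_features` (already moved to `device`) that are run through the model while quantizer statistics are collected. A minimal sketch of such a calibration loop, assuming that batch layout; it is not the actual `get_speech_dataset_dataloader` or `quantize_model` implementation:

```python
import torch
import modelopt.torch.quantization as mtq


def make_forward_loop(calib_dataloader):
    def forward_loop(model):
        with torch.no_grad():
            for batch in calib_dataloader:
                # generate() exercises both the Whisper encoder and decoder, so calibration
                # statistics are gathered for every quantized layer in one pass.
                model.generate(batch["input_features"], max_new_tokens=32)

    return forward_loop


# model = mtq.quantize(model, quant_cfg, forward_loop=make_forward_loop(calib_dataloader))
```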
@@ -347,30 +388,40 @@ def main(args):
             quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
         # Only run single sample for preview
-        input_ids = next(iter(calib_dataloader))["input_ids"][0:1]
-        generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)
+        input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
+        with torch.autocast("cuda"):
+            generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)
 
-        model = quantize_model(model, quant_cfg, args, calib_dataloader)
-        if args.compress:
-            mtq.compress(model)
-        # Lets print the quantization summary
-        if args.verbose:
-            mtq.print_quant_summary(model)
+            model = quantize_model(model, quant_cfg, args, calib_dataloader)
+            if args.compress:
+                mtq.compress(model)
+            # Lets print the quantization summary
+            if args.verbose:
+                mtq.print_quant_summary(model)
 
-        # Run some samples
-        generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)
+            # Run some samples
+            generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)
 
         def input_decode(input_ids):
             if processor is not None and isinstance(processor, MllamaImageProcessor):
                 return processor.tokenizer.batch_decode(input_ids)
+            elif processor is not None and isinstance(processor, WhisperProcessor):
+                return first_text
             elif tokenizer is not None:
                 return tokenizer.batch_decode(input_ids)
             else:
                 raise ValueError("The processor or tokenizer must be set")
 
         def output_decode(generated_ids, input_shape):
-            if tokenizer is not None and is_enc_dec(model_type):
-                return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            if is_enc_dec(model_type):
+                if processor is not None and isinstance(processor, WhisperProcessor):
+                    return processor.tokenizer.batch_decode(
+                        generated_ids, skip_special_tokens=True
+                    )[0]
+                elif tokenizer is not None:
+                    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
             elif processor is not None and isinstance(processor, MllamaImageProcessor):
                 return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
             elif tokenizer is not None:
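The preview path above generates from `input_features` and decodes with the processor's tokenizer, keeping only the first transcript string. A minimal end-to-end sketch with stand-in features (a real run would use processor-extracted features as shown earlier, so the generated text here is meaningless):

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

name = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(name)
model = WhisperForConditionalGeneration.from_pretrained(name)

# Stand-in mel features of the shape Whisper expects.
input_features = torch.zeros(1, model.config.num_mel_bins, 3000)
generated_ids = model.generate(input_features, max_new_tokens=20)

# Whisper is encoder-decoder, so the whole output sequence is the transcript;
# [0] keeps the single-sample string, mirroring output_decode above.
text = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)
```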

modelopt/torch/export/layer_utils.py

Lines changed: 35 additions & 14 deletions
@@ -221,6 +221,11 @@ def is_linear(module: nn.Module) -> bool:
     return any([k in type(module).__name__ for k in ["Linear", "Conv1D", "NormHead"]])
 
 
+def is_conv(module: nn.Module) -> bool:
+    """Returns whether the module is a convolutional layer."""
+    return "Conv" in type(module).__name__
+
+
 def is_embedding(module: nn.Module) -> bool:
     """Returns whether the module is an embedding layer."""
     module_type_name = type(module).__name__
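A quick check of the new helper against the layers it is meant to distinguish: the Whisper encoder front end uses two `torch.nn.Conv1d` layers (`conv1`/`conv2`), whose class name contains "Conv", while plain `Linear` projections do not. The shapes below are roughly whisper-tiny's, for illustration only:

```python
import torch.nn as nn
from modelopt.torch.export.layer_utils import is_conv

print(is_conv(nn.Conv1d(80, 384, kernel_size=3, padding=1)))  # True
print(is_conv(nn.Linear(384, 384)))  # False
```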
@@ -644,6 +649,14 @@ def build_attention_config(
     assert k
     assert v
     qkv_modules = [q, k, v]
+    for layer in qkv_modules:
+        # Add the missing zero bias for Whisper model for export purpose
+        if layer.bias is None and q.bias is not None:
+            layer.bias = torch.nn.Parameter(
+                torch.zeros(layer.weight.size(1), device=layer.weight.device),
+                requires_grad=True,
+            )
+            print("Add missing zero bias for qkv modules for export purpose")
 
     config.qkv = build_qkv(qkv_modules, model_metadata_config, ext_config, tp_size=tp_size)
 
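The loop above matters for Whisper because Hugging Face's `WhisperAttention` builds `k_proj` without a bias while `q_proj`/`v_proj` have one, and the export path expects a uniform q/k/v layout. A toy sketch of the same padding on plain `nn.Linear` modules (for these square d_model x d_model projections, `weight.size(1)` equals `weight.size(0)`, so the zero vector has the bias length):

```python
import torch
import torch.nn as nn

d_model = 384  # roughly whisper-tiny's hidden size
q = nn.Linear(d_model, d_model, bias=True)
k = nn.Linear(d_model, d_model, bias=False)  # Whisper's k_proj carries no bias
v = nn.Linear(d_model, d_model, bias=True)

for layer in (q, k, v):
    if layer.bias is None and q.bias is not None:
        # Fill the missing bias with zeros so q/k/v biases can be stacked for export.
        layer.bias = nn.Parameter(torch.zeros(layer.weight.size(1)), requires_grad=True)

print(k.bias.shape)  # torch.Size([384])
```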
@@ -723,7 +736,7 @@ def _split_gate_from_fc(decoder_type, module, fc_name, fc_layer):
             "dense_h_to_4h",  # falcon, chatglm, bloom
             "linear_fc1",
             "w2",  # qwen
-            "fc1",  # phi, gemma
+            "fc1",  # phi, gemma, whisper
             "gate_up_proj",  # phi
             "wi_0",  # t5
             "wi",  # t5
@@ -739,7 +752,7 @@ def _split_gate_from_fc(decoder_type, module, fc_name, fc_layer):
             "down_proj",  # llama, baichuan, mpt, phi, recurrentgemma, nemotron, deepseek
             "linear_fc2",
             "proj",
-            "fc2",  # phi, gemma
+            "fc2",  # phi, gemma, whisper
             "wo",  # t5
         ]
     )
@@ -1288,21 +1301,29 @@ def build_decoder_config(
         for layer in sub_module.children():
             combined_module.append(layer)
         module_layers = dict(combined_module.named_children())
-    elif decoder_type in ["bart"]:
-        # BartEncoderLayer, BartDecoderLayer have MLP component with no Module wrapper.
-        # creating a dummy module so that is_mlp may catch it.
-        bart_mlp_submodule_names = ["fc1", "fc2", "activation_fn"]
+    elif decoder_type in ["bart", "whisper"]:
+        if decoder_type == "whisper":
+            # Add max_position_embeddings for Whisper model
+            if model_metadata_config.get("enc_dec") == "enc":
+                config.max_position_embeddings = module.self_attn.config.max_source_positions
+            else:
+                config.max_position_embeddings = module.self_attn.config.max_target_positions
+        # BartEncoderLayer, BartDecoderLayer, WhisperEncoderLayer, WhisperDecoderLayer
+        # have MLP component with no Module wrapper.
+        # Create a dummy module so that is_mlp may catch it.
+        encdec_mlp_submodule_names = ["fc1", "fc2", "activation_fn", "activation_fn"]
         module_layers = dict(module.named_children())
 
-        class BartMLP(nn.Module):
+        class EncDecMLP(nn.Module):
             def __init__(self):
                 super().__init__()
 
-        bart_mlp_module = BartMLP()
-        for submodule_name in bart_mlp_submodule_names:
-            setattr(bart_mlp_module, submodule_name, getattr(module, submodule_name))
-            module_layers.pop(submodule_name)
-        module_layers.update({"MLP": bart_mlp_module})
+        encdec_mlp_module = EncDecMLP()
+        for submodule_name in encdec_mlp_submodule_names:
+            if submodule_name in module_layers:
+                setattr(encdec_mlp_module, submodule_name, getattr(module, submodule_name))
+                module_layers.pop(submodule_name)
+        module_layers.update({"MLP": encdec_mlp_module})
     else:
         module_layers = dict(module.named_children())
     if decoder_type in ["exaone"]:
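A toy illustration of the regrouping above: on Bart and Whisper layers, `fc1`, `fc2`, and `activation_fn` hang directly off the encoder/decoder layer rather than under an MLP submodule, so they are collected into a dummy wrapper that downstream MLP detection can treat as one component. The layer class below is a stand-in, not the real Hugging Face module:

```python
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Stand-in for a Bart/Whisper decoder layer: MLP pieces live directly on the layer."""

    def __init__(self, d_model=384, d_ff=1536):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, 6)
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation_fn = nn.GELU()


class EncDecMLP(nn.Module):
    """Empty wrapper; only used to group the MLP pieces under one named child."""


layer = ToyDecoderLayer()
module_layers = dict(layer.named_children())

mlp = EncDecMLP()
for name in ["fc1", "fc2", "activation_fn"]:
    if name in module_layers:
        setattr(mlp, name, module_layers.pop(name))
module_layers["MLP"] = mlp

print(list(module_layers))  # ['self_attn', 'MLP']
```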
@@ -1612,13 +1633,13 @@ def get_experts_linear_names(model: torch.nn.Module):
 
 def model_type_is_enc_dec(model_type):
     """Check if model_type is a enc-dec model."""
-    return model_type in ["t5", "bart"]
+    return model_type in ["t5", "bart", "whisper"]
 
 
 def get_enc_dec_models(hf_model, model_type):
     """Get the correct encoder, decoder from hf model."""
     assert model_type_is_enc_dec(model_type), "This encoder decoder model is not supported"
-    if model_type in "bart":
+    if model_type in ["bart", "whisper"]:
         return [("enc", hf_model.model.encoder), ("dec", hf_model.model.decoder)]
     else:
         return [("enc", hf_model.encoder), ("dec", hf_model.decoder)]
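The `["bart", "whisper"]` branch goes through `hf_model.model` because `WhisperForConditionalGeneration` (like Bart) wraps a `WhisperModel` in its `.model` attribute, whereas T5 exposes `.encoder`/`.decoder` directly. A small sketch with the public `openai/whisper-tiny` checkpoint:

```python
from transformers import WhisperForConditionalGeneration

from modelopt.torch.export.layer_utils import get_enc_dec_models

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# Returns the encoder/decoder pair that the exporter walks for an enc-dec model.
print([(name, type(sub).__name__) for name, sub in get_enc_dec_models(model, "whisper")])
# Expected: [('enc', 'WhisperEncoder'), ('dec', 'WhisperDecoder')]
```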

modelopt/torch/export/model_config.py

Lines changed: 4 additions & 0 deletions
@@ -568,6 +568,10 @@ class ModelConfig:
     bos_token_id: int = None
     pad_token_id: int = None
 
+    # For whisper encoder feature extractor
+    conv1: ConvConfig = None
+    conv2: ConvConfig = None
+
     @property
     def vocab_size_padded(self):
         """Returns the padded vocab_size of the model rounds to the tensor_parallel."""

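The new fields hold the Whisper encoder's front-end convolutions: two `nn.Conv1d` layers that embed and downsample the mel-spectrogram before the transformer blocks, and whose weights the exporter records alongside the rest of the model config. A short sketch showing the source modules (the `ConvConfig` field layout itself is not part of this hunk):

```python
from transformers import WhisperForConditionalGeneration

encoder = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").model.encoder
print(encoder.conv1)  # Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
print(encoder.conv2)  # Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
```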
0 commit comments

Comments
 (0)