@@ -22,7 +22,7 @@
 import numpy as np
 import torch
 from example_utils import get_model, get_processor, get_tokenizer, is_enc_dec, is_model_on_gpu
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, WhisperProcessor
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
@@ -39,6 +39,7 @@
 )
 from modelopt.torch.utils.image_processor import MllamaImageProcessor
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
+from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
 from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
 
 RAND_SEED = 1234
@@ -210,7 +211,19 @@ def main(args):
         elif args.dataset != "scienceqa":
             raise ValueError("Only the scienceqa dataset is supported for the mllama model.")
         processor = get_processor(
-            args.pyt_ckpt_path, device, trust_remote_code=args.trust_remote_code
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
+        )
+    elif model_type == "whisper":
+        if args.dataset is None:
+            args.dataset = "peoples_speech"
+            warnings.warn(
+                "Currently only the peoples_speech dataset is supported for the whisper model. "
+                "Overriding dataset to peoples_speech."
+            )
+        elif args.dataset != "peoples_speech":
+            raise ValueError("Only the peoples_speech dataset is supported for the whisper model.")
+        processor = get_processor(
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
         )
     else:
         if args.dataset is None:
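For context, `get_processor` is expected to hand back a Hugging Face `WhisperProcessor` for whisper checkpoints. A minimal sketch of what that object provides, using the public `openai/whisper-small` checkpoint as a stand-in for `args.pyt_ckpt_path`:

```python
from transformers import WhisperProcessor

# Illustrative public checkpoint; in the script, args.pyt_ckpt_path plays
# this role via get_processor().
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

# The processor bundles a feature extractor (raw audio -> log-mel features)
# and a tokenizer (token ids <-> text); both are used later in this diff.
print(type(processor.feature_extractor).__name__)  # e.g. WhisperFeatureExtractor
print(type(processor.tokenizer).__name__)          # e.g. WhisperTokenizer
```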
@@ -273,8 +286,25 @@ def main(args):
         # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
         # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
         sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
+        # Whisper models expect mel-spectrogram input features of length 3000,
+        # i.e. input of shape (batch_size, num_mel_bins, 3000).
+        # As the Whisper encoder has no embedding layer, the input dtype has to be float.
+        # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size().
+        if model_type == "whisper":
+            max_sample_length = 3000
+            num_mel_bins = model.config.num_mel_bins
+            sample_input_single_batch = (
+                torch.ones([1, num_mel_bins, max_sample_length], dtype=torch.float32).to(
+                    model.device
+                )
+                * 100
+            )
+        else:
+            sample_input_single_batch = None
         args.batch_size = get_max_batch_size(
-            model, sample_memory_usage_ratio=sample_memory_usage_ratio
+            model,
+            sample_memory_usage_ratio=sample_memory_usage_ratio,
+            sample_input_single_batch=sample_input_single_batch,
         )
         if args.batch_size > args.calib_size:
             args.batch_size = args.calib_size
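The fixed length of 3000 comes from Whisper's 30-second input window: the feature extractor pads or truncates audio to 30 s at 16 kHz and uses a 160-sample hop, giving 30 * 16000 / 160 = 3000 mel frames. A small self-contained check of the dummy calibration input built above (80 mel bins is an assumption for illustration, matching the smaller Whisper variants; whisper-large-v3 uses 128, and the script reads the real value from `model.config.num_mel_bins`):

```python
import torch

# Rebuild the dummy mel-spectrogram batch used above for batch-size probing.
num_mel_bins = 80  # assumption for illustration
sample = torch.ones([1, num_mel_bins, 3000], dtype=torch.float32) * 100
assert sample.shape == (1, 80, 3000) and sample.dtype == torch.float32
```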
@@ -292,6 +322,17 @@ def main(args):
                 batch_size=args.batch_size,
                 num_samples=args.calib_size,
             )
+        elif model_type == "whisper":
+            assert processor is not None and isinstance(processor, WhisperProcessor), (
+                "The AutoProcessor must be set."
+            )
+            calib_dataloader, first_text = get_speech_dataset_dataloader(
+                dataset_name=args.dataset,
+                processor=processor,
+                batch_size=args.batch_size,
+                num_samples=args.calib_size,
+                device=device,
+            )
         else:
             assert tokenizer is not None and isinstance(
                 tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
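Each calibration batch from `get_speech_dataset_dataloader` presumably carries `input_features` produced by the `WhisperProcessor` (the helper's internals are not shown in this diff, so this is an assumption). A sketch of the per-sample featurization using only the public transformers API:

```python
import numpy as np
from transformers import WhisperProcessor

# Illustrative featurization of one audio clip, as the speech dataloader
# presumably does internally (assumption; the modelopt helper is not shown here).
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
audio = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of silence at 16 kHz
batch = processor(audio, sampling_rate=16000, return_tensors="pt")
print(batch.input_features.shape)  # (1, num_mel_bins, 3000) after padding to 30 s
```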
@@ -347,30 +388,40 @@ def main(args):
             quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
         # Only run single sample for preview
-        input_ids = next(iter(calib_dataloader))["input_ids"][0:1]
-        generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)
+        input_ids = next(iter(calib_dataloader))[
+            "input_features" if model_type == "whisper" else "input_ids"
+        ][0:1]
+        with torch.autocast("cuda"):
+            generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)
 
-        model = quantize_model(model, quant_cfg, args, calib_dataloader)
-        if args.compress:
-            mtq.compress(model)
-        # Lets print the quantization summary
-        if args.verbose:
-            mtq.print_quant_summary(model)
+            model = quantize_model(model, quant_cfg, args, calib_dataloader)
+            if args.compress:
+                mtq.compress(model)
+            # Let's print the quantization summary
+            if args.verbose:
+                mtq.print_quant_summary(model)
 
-        # Run some samples
-        generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)
+            # Run some samples
+            generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)
 
         def input_decode(input_ids):
             if processor is not None and isinstance(processor, MllamaImageProcessor):
                 return processor.tokenizer.batch_decode(input_ids)
+            elif processor is not None and isinstance(processor, WhisperProcessor):
+                return first_text
             elif tokenizer is not None:
                 return tokenizer.batch_decode(input_ids)
             else:
                 raise ValueError("The processor or tokenizer must be set")
 
         def output_decode(generated_ids, input_shape):
-            if tokenizer is not None and is_enc_dec(model_type):
-                return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            if is_enc_dec(model_type):
+                if processor is not None and isinstance(processor, WhisperProcessor):
+                    return processor.tokenizer.batch_decode(
+                        generated_ids, skip_special_tokens=True
+                    )[0]
+                elif tokenizer is not None:
+                    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
             elif processor is not None and isinstance(processor, MllamaImageProcessor):
                 return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
             elif tokenizer is not None:
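Taken together, the whisper path amounts to: build (or load) mel features, generate under autocast, and decode the full generated sequence with the processor's tokenizer. A minimal end-to-end sketch against the public `openai/whisper-small` checkpoint (an assumption standing in for `args.pyt_ckpt_path`; requires a CUDA device):

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Illustrative checkpoint; the script uses args.pyt_ckpt_path instead.
ckpt = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(ckpt)
model = WhisperForConditionalGeneration.from_pretrained(ckpt).to("cuda")

# Dummy 30 s mel-spectrogram input, matching the shape used for calibration.
features = torch.ones(1, model.config.num_mel_bins, 3000, device="cuda")
with torch.autocast("cuda"):
    generated_ids = model.generate(features, max_new_tokens=100)

# Whisper is encoder-decoder, so the whole generated sequence is decoded;
# there is no input prompt prefix to strip as in the decoder-only branch.
print(processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```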