import numpy as np
import torch
from example_utils import get_model, get_processor, get_tokenizer, is_enc_dec, is_model_on_gpu
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, WhisperProcessor

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
)
from modelopt.torch.utils.image_processor import MllamaImageProcessor
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
+from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
RAND_SEED = 1234
@@ -210,7 +211,19 @@ def main(args):
        elif args.dataset != "scienceqa":
            raise ValueError("Only the scienceqa dataset is supported for the mllama model.")
        processor = get_processor(
-            args.pyt_ckpt_path, device, trust_remote_code=args.trust_remote_code
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
+        )
+    elif model_type == "whisper":
+        if args.dataset is None:
+            args.dataset = "peoples_speech"
+            warnings.warn(
+                "Currently only the peoples_speech dataset is supported for the whisper model. "
+                "Overriding dataset to peoples_speech."
+            )
+        elif args.dataset != "peoples_speech":
+            raise ValueError("Only the peoples_speech dataset is supported for the whisper model.")
+        processor = get_processor(
+            args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code
        )
    else:
        if args.dataset is None:
@@ -273,8 +286,25 @@ def main(args):
    # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
    # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
    sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
+    # Whisper expects mel-spectrogram input features of length 3000,
+    # i.e. an input of shape (batch_size, num_mel_bins, 3000).
+    # As the Whisper encoder has no embedding layer, the input dtype must be float.
+    # For non-Whisper (language) models, sample_input is set up inside get_max_batch_size().
+    if model_type == "whisper":
+        max_sample_length = 3000
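+        # 3000 mel frames correspond to Whisper's fixed 30-second audio window (one frame per 10 ms).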
+        num_mel_bins = model.config.num_mel_bins
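+        # The x100 scaling below is presumably just to give the dummy features a non-trivial magnitude for memory profiling.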
+        sample_input_single_batch = (
+            torch.ones([1, num_mel_bins, max_sample_length], dtype=torch.float32).to(
+                model.device
+            )
+            * 100
+        )
+    else:
+        sample_input_single_batch = None
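+    # get_max_batch_size presumably dry-runs the model on the sample input to find the largest batch that fits in memory.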
    args.batch_size = get_max_batch_size(
-        model, sample_memory_usage_ratio=sample_memory_usage_ratio
+        model,
+        sample_memory_usage_ratio=sample_memory_usage_ratio,
+        sample_input_single_batch=sample_input_single_batch,
    )
    if args.batch_size > args.calib_size:
        args.batch_size = args.calib_size
@@ -292,6 +322,17 @@ def main(args):
            batch_size=args.batch_size,
            num_samples=args.calib_size,
        )
+    elif model_type == "whisper":
+        assert processor is not None and isinstance(processor, WhisperProcessor), (
+            "The WhisperProcessor must be set."
+        )
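+        # The speech dataloader also returns first_text, presumably the first sample's reference transcript, used for the preview below.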
+        calib_dataloader, first_text = get_speech_dataset_dataloader(
+            dataset_name=args.dataset,
+            processor=processor,
+            batch_size=args.batch_size,
+            num_samples=args.calib_size,
+            device=device,
+        )
    else:
        assert tokenizer is not None and isinstance(
            tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
@@ -347,30 +388,40 @@ def main(args):
        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

    # Only run single sample for preview
-    input_ids = next(iter(calib_dataloader))["input_ids"][0:1]
-    generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)
+    input_ids = next(iter(calib_dataloader))[
+        "input_features" if model_type == "whisper" else "input_ids"
+    ][0:1]
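+    # Generation runs under autocast, presumably so half-precision checkpoints work with the float32 input features.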
+    with torch.autocast("cuda"):
+        generated_ids_before_ptq = model.generate(input_ids, max_new_tokens=100)

-    model = quantize_model(model, quant_cfg, args, calib_dataloader)
-    if args.compress:
-        mtq.compress(model)
-    # Lets print the quantization summary
-    if args.verbose:
-        mtq.print_quant_summary(model)
+        model = quantize_model(model, quant_cfg, args, calib_dataloader)
+        if args.compress:
+            mtq.compress(model)
+        # Let's print the quantization summary
+        if args.verbose:
+            mtq.print_quant_summary(model)

-    # Run some samples
-    generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)
+        # Run some samples
+        generated_ids_after_ptq = model.generate(input_ids, max_new_tokens=100)

    def input_decode(input_ids):
        if processor is not None and isinstance(processor, MllamaImageProcessor):
            return processor.tokenizer.batch_decode(input_ids)
+        elif processor is not None and isinstance(processor, WhisperProcessor):
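+            # Speech inputs are float mel features with no token ids to decode, so return the stored transcript instead.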
+            return first_text
        elif tokenizer is not None:
            return tokenizer.batch_decode(input_ids)
        else:
            raise ValueError("The processor or tokenizer must be set")

    def output_decode(generated_ids, input_shape):
-        if tokenizer is not None and is_enc_dec(model_type):
-            return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        if is_enc_dec(model_type):
+            if processor is not None and isinstance(processor, WhisperProcessor):
+                return processor.tokenizer.batch_decode(
+                    generated_ids, skip_special_tokens=True
+                )[0]
+            elif tokenizer is not None:
+                return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        elif processor is not None and isinstance(processor, MllamaImageProcessor):
            return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
        elif tokenizer is not None: