 torch.set_float32_matmul_precision('high')

 def main(args):
-    model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
-    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+    model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, force_download=True).to(args.device)
+    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo", force_download=True)
     model_input_name = processor.model_input_names[0]

     if model.can_generate():
-        gen_kwargs = {"max_new_tokens": args.max_new_tokens}
-        # for multilingual Whisper-checkpoints we see a definitive WER boost by setting the language and task args
-        # print(model.generation_config)
-        # if getattr(model.generation_config, "is_multilingual"):
-        #     gen_kwargs["language"] = "en"
-        #     gen_kwargs["task"] = "transcribe"
+        gen_kwargs = {"max_new_tokens": 224}
     elif args.max_new_tokens:
         raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")

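Aside, not part of the diff: the deleted comment block was pointing at a route that still exists, since Whisper's `generate()` accepts `language` and `task` kwargs directly on recent `transformers` releases. A minimal sketch of that spelling, assuming a recent `transformers` and the `openai/whisper-large-v3-turbo` checkpoint (neither is fixed by this commit):

```python
# Sketch only: the language/task route hinted at by the deleted comments.
# Assumptions (not from the diff): a recent transformers release whose Whisper
# generate() takes language/task kwargs, and the large-v3-turbo checkpoint.
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3-turbo", torch_dtype=torch.float16
)

gen_kwargs = {"max_new_tokens": 224}  # same hard cap as this commit
if getattr(model.generation_config, "is_multilingual", False):
    # Pinning language/task skips Whisper's language auto-detection,
    # which typically lowers WER on English-only benchmarks.
    gen_kwargs["language"] = "en"
    gen_kwargs["task"] = "transcribe"
```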

@@ -63,13 +58,14 @@ def benchmark(batch, min_new_tokens=None):
     inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)

     inputs = inputs.to(args.device)
-    inputs[model_input_name] = inputs[model_input_name].to(torch.bfloat16)
+    inputs[model_input_name] = inputs[model_input_name].to(torch.float16)

     # 2. Model Inference
     with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
+        forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
         if model.can_generate():
             # 2.1 Auto-regressive generation for encoder-decoder models
-            pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens)
+            pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens, forced_decoder_ids=forced_decoder_ids)
         else:
             # 2.2. Single forward pass for CTC
             with torch.no_grad():
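For completeness, here is the `forced_decoder_ids` pattern from the second hunk in isolation. This is a minimal sketch under assumptions not in the diff: a CUDA device is available, the checkpoint is `openai/whisper-large-v3-turbo`, and one second of silence stands in for real audio:

```python
# Standalone sketch of the forced_decoder_ids pattern used in this diff.
# Assumptions (not from the diff): CUDA is available and a silent dummy
# clip stands in for real audio.
import numpy as np
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3-turbo", torch_dtype=torch.float16
).to("cuda")

# get_decoder_prompt_ids returns [(position, token_id), ...] for the
# <|en|><|transcribe|> prompt, so generation starts from a fixed
# language/task instead of auto-detecting them.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

audio = np.zeros(16_000, dtype=np.float32)  # 1 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
input_features = inputs.input_features.to("cuda", torch.float16)

pred_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids, max_new_tokens=224)
print(processor.batch_decode(pred_ids, skip_special_tokens=True))
```

Note that recent `transformers` releases have been moving away from `forced_decoder_ids` for Whisper in favor of passing `language`/`task` to `generate()` directly, so the sketch after the first hunk is the more forward-compatible spelling of the same thing.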