Skip to content

Commit 2582589

Browse files
Fix performance issues
1 parent f1c0ef3 commit 2582589

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

liteASR/run_eval.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,12 @@
1212
torch.set_float32_matmul_precision('high')
1313

1414
def main(args):
15-
model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
16-
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
15+
model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, force_download=True).to(args.device)
16+
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo", force_download=True)
1717
model_input_name = processor.model_input_names[0]
1818

1919
if model.can_generate():
20-
gen_kwargs = {"max_new_tokens": args.max_new_tokens}
21-
# for multilingual Whisper-checkpoints we see a definitive WER boost by setting the language and task args
22-
# print(model.generation_config)
23-
# if getattr(model.generation_config, "is_multilingual"):
24-
# gen_kwargs["language"] = "en"
25-
# gen_kwargs["task"] = "transcribe"
20+
gen_kwargs = {"max_new_tokens": 224}
2621
elif args.max_new_tokens:
2722
raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")
2823

@@ -63,13 +58,14 @@ def benchmark(batch, min_new_tokens=None):
6358
inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)
6459

6560
inputs = inputs.to(args.device)
66-
inputs[model_input_name] = inputs[model_input_name].to(torch.bfloat16)
61+
inputs[model_input_name] = inputs[model_input_name].to(torch.float16)
6762

6863
# 2. Model Inference
6964
with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
65+
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
7066
if model.can_generate():
7167
# 2.1 Auto-regressive generation for encoder-decoder models
72-
pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens)
68+
pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens, forced_decoder_ids=forced_decoder_ids)
7369
else:
7470
# 2.2. Single forward pass for CTC
7571
with torch.no_grad():

0 commit comments

Comments (0)