Hallucinations with 8bit Whisper PEFT model - solved with full / half precision #477
-
Hey @sanchit-gandhi, I have noticed that inference speed using QLoRA (4-bit) is relatively slow too. My question is: if we have trained a model using the combination of PEFT + (LoRA or QLoRA), do we have to load the model with the same bit-width at inference, or, given that the adapters are learned in fp32/fp16, can we run inference in half precision with no problem? That would mean we could train a full-precision model + PEFT (using Accelerate for multiple GPUs) and use it with different dtypes at inference time.
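For reference, the LoRA adapter weights are saved as ordinary floating-point tensors even when the frozen base model was quantised during training, so they can be attached to a base loaded in a different precision. A minimal sketch to inspect this, using a hypothetical adapter repo name:

```python
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSpeechSeq2Seq

adapter_id = "your-username/whisper-large-v2-lora"  # hypothetical adapter repo
config = PeftConfig.from_pretrained(adapter_id)

# load the frozen base in its default precision and attach the adapter
base = AutoModelForSpeechSeq2Seq.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base, adapter_id)

# the LoRA matrices are small floating-point tensors, regardless of whether
# the base was quantised to 8-bit/4-bit during training
for name, param in model.named_parameters():
    if "lora_" in name:
        print(name, tuple(param.shape), param.dtype)
```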
-
Hey @sanchit-gandhi, I'm wondering if this is resolved if, when training with 8-bit or 4-bit finishes (at least for QLoRA), the adapter is merged back into the base model at the end:

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# load the LoRA config and the base model in half precision
peft_config = PeftConfig.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    return_dict=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# attach the trained adapter weights
model = PeftModel.from_pretrained(model, output_dir)
model.eval()

# merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained("/opt/ml/model/")

# save tokenizer for easy inference
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
tokenizer.save_pretrained("/opt/ml/model/")
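Once merged, the checkpoint can then be loaded like any regular Transformers model, with no PEFT wrapper needed at inference. A minimal sketch under the assumptions above (the /opt/ml/model/ path and a causal-LM base; the prompt is a hypothetical placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# load the merged checkpoint saved above -- no PEFT wrapper needed any more
model = AutoModelForCausalLM.from_pretrained(
    "/opt/ml/model/",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("/opt/ml/model/")

# hypothetical prompt, just to show the merged model generates as usual
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```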
-
Observed several instances where fine-tuning Whisper with PEFT and then running inference in 8bit precision gives ~5x slower inference speeds vs full precision, and increases Whisper’s propensity to hallucinate considerably
Table for inference speed with batch-size=1:

I'll include code snippets below, and will update these in time to use a fine-tuned PEFT checkpoint and an audio sample (currently both are private):
Code to load PEFT model in 8bit then pass to pipeline:
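In the meantime, a minimal sketch of what this looks like, assuming a hypothetical adapter repo name and a local audio file as placeholders:

```python
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

peft_model_id = "your-username/whisper-large-v2-lora"  # hypothetical adapter repo
peft_config = PeftConfig.from_pretrained(peft_model_id)

# load the frozen base Whisper model in 8-bit (requires bitsandbytes)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    load_in_8bit=True,
    device_map="auto",
)
# attach the fine-tuned LoRA adapter weights
model = PeftModel.from_pretrained(model, peft_model_id)

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path)
pipe = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

# hypothetical local audio file
print(pipe("sample.wav")["text"])
```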
Loading the model weights and PEFT weights in fp32/fp16 for inference drastically improves inference time (with fp16 faster still than fp32), and retains the WER boost we get by fine-tuning with PEFT. There are almost no hallucinations when we run inference in full or half precision.
Code to load PEFT model in fp16 then pass to pipeline:
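Correspondingly, a minimal sketch of the fp16 path, with the same hypothetical placeholders:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

peft_model_id = "your-username/whisper-large-v2-lora"  # hypothetical adapter repo
peft_config = PeftConfig.from_pretrained(peft_model_id)

# load the frozen base Whisper model in half precision instead of 8-bit
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# attach the fine-tuned LoRA adapter weights
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path)
pipe = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=0,  # move the fp16 model to the first GPU
)

# hypothetical local audio file
print(pipe("sample.wav")["text"])
```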
Takeaway: PEFT is great for stable, low-resource training in 8-bit. We can then leverage the fine-tuned checkpoints for fast inference in full or half precision and negate possible hallucinations