FIX Multiple issues with low precision base models
As a user, it should be possible to manually cast the base model to a
lower precision dtype (float16 or bfloat16) and still have the different
PEFT methods work correctly. Currently, this is not the case for many
PEFT methods, as the added tests demonstrate.
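To illustrate the scenario, here is a minimal sketch of what should work after this PR; the model id and target_modules are placeholders and need to be adapted to the actual model (LoRA is used for brevity, but the other PEFT methods should now behave the same way):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# base model manually cast to low precision via torch_dtype
base_model = AutoModelForCausalLM.from_pretrained(
    "some/causal-lm", torch_dtype=torch.bfloat16
)
# placeholder target_modules, adjust to the architecture
config = LoraConfig(target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config)

# token ids are integers, but the hidden states flowing through the model are bfloat16
input_ids = torch.randint(0, 100, (1, 10))
output = peft_model(input_ids=input_ids)
```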
To understand the problem, it helps to take a step back. By default,
PEFT keeps the adapter weights in high precision, i.e. float32. When the
base model is in lower precision, the user needs to pass the inputs in
lower precision too, as self.base_layer(x) would otherwise fail. However,
this low precision input clashes with the high precision adapter
weights.
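Stripped of all PEFT specifics, the clash looks like this in plain PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.float16)  # low precision input, matching the base model
adapter_weight = torch.randn(8, 8)          # adapter weight kept in float32
try:
    F.linear(x, adapter_weight)
except RuntimeError as err:
    print(err)  # dtype mismatch between Half and Float
```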
The solution implemented in this PR is to cast the input up to the
adapter dtype [1]. That way, the whole adapter operation is conducted in
high precision. Only once that has finished is the final result cast
back to the original dtype. This should lead to better results, but it
may require more memory. Note that this is how LoRA is already
implemented, so the changes in this PR bring the other methods in line
with what LoRA does.
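Roughly, the pattern now used in the adjusted forward methods looks like this (a simplified sketch; the class and attribute names are illustrative and not the actual PEFT layer code):

```python
import torch
import torch.nn as nn

class SketchAdapterLayer(nn.Module):
    def __init__(self, base_layer: nn.Linear):
        super().__init__()
        self.base_layer = base_layer
        # adapter weight stays in float32 by default
        self.adapter_weight = nn.Parameter(
            torch.zeros(base_layer.out_features, base_layer.in_features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        previous_dtype = x.dtype
        result = self.base_layer(x)            # runs in the base model dtype
        x = x.to(self.adapter_weight.dtype)    # [1] cast the input up to the adapter dtype
        result = result + nn.functional.linear(x, self.adapter_weight)
        return result.to(previous_dtype)       # cast the final result back to the original dtype
```

With a float16 base_layer, the output dtype stays float16 while the adapter math runs in float32.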
If the user does not want the adapter to be in float32, they can always
pass autocast_adapter_dtype=False when calling get_peft_model or
PeftModel.from_pretrained. This is also tested.
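For example (placeholder model and adapter ids, either path works):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("some/causal-lm", torch_dtype=torch.float16)

# when creating a fresh adapter:
peft_model = get_peft_model(
    base_model, LoraConfig(target_modules=["q_proj", "v_proj"]), autocast_adapter_dtype=False
)

# or, with a freshly loaded base model, when loading a trained adapter:
# peft_model = PeftModel.from_pretrained(base_model, "some/adapter", autocast_adapter_dtype=False)
```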
Besides adjusting the forward methods to account for these changes, the
merge and unmerge methods often had to be adjusted as well, since they
did not correctly account for the base model dtype. Now, these methods
should always preserve the original dtype of the base model.
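The gist of the merge adjustment is the following (a simplified sketch, not the per-method implementations):

```python
import torch

def merge_sketch(base_weight: torch.Tensor, delta_weight: torch.Tensor) -> torch.Tensor:
    orig_dtype = base_weight.dtype                              # e.g. float16 or bfloat16
    merged = base_weight.to(delta_weight.dtype) + delta_weight  # add in the adapter's high precision
    return merged.to(orig_dtype)                                # preserve the original base model dtype
```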
Note that if, for whatever reason, the input casting in [1] is not
desired, users can use the disable_input_dtype_casting context manager
to disable it (more context on this feature can be found in PR
huggingface#2353). I updated the corresponding code to be agnostic to the
specific PEFT method (previously, it only covered LoRA).
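A usage sketch, assuming the import path is peft.helpers and reusing the peft_model and input_ids from the first snippet above:

```python
from peft.helpers import disable_input_dtype_casting

# within this block, the input casting from [1] is skipped for all PEFT methods,
# so inputs and adapter weights must already have compatible dtypes
with disable_input_dtype_casting(peft_model):
    output = peft_model(input_ids=input_ids)
```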