Fix: Multiple PEFT methods have issues with models loaded in float16 or bfloat16 (huggingface#2433)
As a user, it should be possible to manually cast the base model to a
lower precision dtype, float16 or bfloat16, and still have the different
PEFT methods work correctly. Currently, this is not the case for many
PEFT methods, as can be replicated by the added tests.
To understand the problem, it helps to take a step back. By default,
PEFT keeps the adapter weights in high precision, i.e. in float32.
When the base model is in lower precision, the user needs to pass
inputs in lower precision too, as otherwise self.base_layer(x) would
fail. However, this low precision input clashes with the high precision
adapter weights.
The solution implemented in this PR is to cast the input to a higher
dtype [1]. That way, the whole adapter operation is conducted in high
precision. Only once that has finished will the final result be cast to
the original dtype. This should lead to better results, but it may
require more memory. Note that this is how LoRA is implemented, so the
changes in this PR bring the other methods more in line with what LoRA
does.
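
To illustrate, here is a minimal sketch of this casting pattern for a
LoRA-style layer (the class and attribute names are illustrative, not
the actual PEFT internals):

```python
import torch
import torch.nn as nn

class AdapterLayer(nn.Module):
    """Minimal sketch of the casting pattern, not the actual PEFT code."""

    def __init__(self, base_layer: nn.Linear, r: int = 8):
        super().__init__()
        self.base_layer = base_layer  # may be in float16/bfloat16
        # Adapter weights are created in float32 (the PyTorch default).
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)          # low precision, e.g. float16
        previous_dtype = x.dtype
        x = x.to(self.lora_A.weight.dtype)   # upcast the input to float32
        delta = self.lora_B(self.lora_A(x))  # adapter math runs in float32
        return result + delta.to(previous_dtype)  # cast back at the very end
```
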
If the user does not want the adapter to be in float32, they can always
pass autocast_adapter_dtype=False when calling get_peft_model or
PeftModel.from_pretrained. This is also tested.
Besides the forward method, the merge and unmerge methods also often
had to be adjusted, as they did not correctly account for the base
model dtype. Now, those methods should always conserve the original
dtype of the base model.
Note that if, for whatever reason, the input casting in [1] is not
desired, users can use the disable_input_dtype_casting context manager
to disable it (more information on this feature can be found in
PR huggingface#2353). I updated the corresponding code to be agnostic to the
specific PEFT method (beforehand, it was only for LoRA).
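
For reference, its usage looks roughly as follows (assuming the context
manager is importable from peft.helpers, where PR huggingface#2353 added it;
model and inputs stand for a PeftModel and a batch of inputs):

```python
from peft.helpers import disable_input_dtype_casting

# Inputs are passed to the adapter as-is instead of being cast to its dtype.
with disable_input_dtype_casting(model):
    output = model(**inputs)
```
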
Note that model.merge_adapter(safe_merge=True) did not work until now:
even though the argument was documented, it was not actually
implemented. This is now fixed.
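
Schematically, the adjusted merge logic boils down to the following
pattern (a sketch with illustrative names, not the actual PEFT code):

```python
import torch

def merge(self, safe_merge: bool = False) -> None:
    # Sketch: compute the merge in the adapter's high precision (float32),
    # then cast back so the base weight keeps its original dtype.
    base_weight = self.base_layer.weight
    orig_dtype = base_weight.dtype            # e.g. torch.float16
    delta = self.get_delta_weight()           # float32 delta (illustrative helper)
    merged = base_weight.to(delta.dtype) + delta
    if safe_merge and not torch.isfinite(merged).all():
        raise ValueError("NaNs detected in the merged weights")
    base_weight.data = merged.to(orig_dtype)  # conserve the original dtype
```
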
## Dtype-related issues

### ValueError: Attempting to unscale FP16 gradients

This error probably occurred because the model was loaded with `torch_dtype=torch.float16` and then used in an automatic mixed precision (AMP) context, e.g. by setting `fp16=True` in the [`~transformers.Trainer`] class from 🤗 Transformers. The reason is that when using AMP, trainable weights should never use fp16. To make this work without loading the whole model in fp32, add the following to your code:
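
A minimal sketch of this fix, upcasting only the trainable parameters
to float32 (the checkpoint name below is only a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Base model loaded in float16; the checkpoint name is only a placeholder.
base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16
)
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

# Upcast only the trainable (adapter) parameters to float32;
# the frozen base weights stay in float16.
for param in peft_model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

# Then proceed as usual, e.g. with Trainer(..., fp16=True).
```
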
[...] Starting from PEFT version v0.12.0, PEFT automatically promotes the dtype of adapter [...]

</Tip>

### Selecting the dtype of the adapter

Most PEFT methods, like LoRA, work by adding trainable adapter weights. By default, those weights are stored in float32 dtype (fp32), i.e. at a relatively high precision. Therefore, even if the base model is loaded in float16 (fp16) or bfloat16 (bf16), the adapter weights are float32. When the adapter results are calculated during the forward pass, the input will typically be in the dtype of the base model, so it is upcast to float32 if necessary, and the adapter output is then cast back to the original dtype.

If you prefer to have the adapter weights in the lower precision of the base model, i.e. in float16 or bfloat16, you can pass `autocast_adapter_dtype=False` when creating the model ([`~get_peft_model`]) or loading the model ([`~PeftModel.from_pretrained`]). There are some advantages and disadvantages to this:

Advantages of half precision adapter:

- computation slightly faster
- slightly less memory
- smaller file size of checkpoint (half the size)

Disadvantages of half precision adapter:

- slightly worse loss
- higher risk of overflow or underflow

Note that for most use cases, overall runtime and memory cost will be determined by the size of the base model and by the dataset, while the dtype of the PEFT adapter will only have a small impact.
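
For example, to keep the adapter weights in the base model's dtype (the
checkpoint name below is only a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Base model loaded in bfloat16; the checkpoint name is only a placeholder.
base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16
)
config = LoraConfig(task_type="CAUSAL_LM")

# autocast_adapter_dtype=False keeps the adapter weights in bfloat16
# instead of promoting them to float32.
peft_model = get_peft_model(base_model, config, autocast_adapter_dtype=False)
```
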

## Bad results from a loaded PEFT model

There can be several reasons for getting a poor result from a loaded PEFT model, which are listed below. If you're still unable to troubleshoot the problem, see if anyone else had a similar [issue](https://github.com/huggingface/peft/issues) on GitHub, and if you can't find any, open a new issue.