@BenjaminBossan BenjaminBossan commented Mar 18, 2025

As a user, it should be possible to manually cast the base model to a lower precision dtype, float16 or bfloat16, and still have the different PEFT methods work correctly. Currently, this is not the case for many PEFT methods, as can be replicated by the added tests.

To understand the problem, it helps to take a step back. By default, PEFT keeps the adapter weights in high precision, i.e. float32. When the base model is in lower precision, the user needs to pass inputs in lower precision too, as otherwise self.base_layer(x) would fail. However, this low-precision input clashes with the high-precision adapter weights.

The solution implemented in this PR is to cast the input to a higher dtype [1]. That way, the whole adapter operation is conducted in high precision. Only once that has finished will the final result be cast to the original dtype if necessary. This should lead to better results thanks to the high precision, but it may require more memory. Note that this is how LoRA is implemented, so the changes in this PR bring the other methods more in line with what LoRA does.
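
To make this concrete, here is a minimal sketch of the pattern (illustrative only, not the actual PEFT code; adapter_forward and its arguments are made up for the example):

```python
import torch
import torch.nn as nn


def adapter_forward(base_layer: nn.Module, adapter: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # The base layer runs in the model's (possibly low precision) dtype.
    result = base_layer(x)
    orig_dtype = result.dtype

    # Cast the input up to the adapter's dtype (float32 by default) so that
    # the whole adapter computation runs in high precision.
    adapter_dtype = next(adapter.parameters()).dtype
    result = result.to(adapter_dtype) + adapter(x.to(adapter_dtype))

    # Only the final result is cast back to the original dtype.
    return result.to(orig_dtype)
```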

If the user does not want the adapter to be in float32, they can always pass autocast_adapter_dtype=False when calling get_peft_model or PeftModel.from_pretrained etc. This is also tested.
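
For example (the checkpoint name is just a small placeholder, any model works the same way):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
config = LoraConfig(task_type="CAUSAL_LM")

# autocast_adapter_dtype defaults to True (adapter weights in float32);
# pass False to keep the adapter weights in the base model's bfloat16.
model = get_peft_model(base, config, autocast_adapter_dtype=False)
```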

Besides adjusting the forward method to account for these changes, the merge and unmerge methods also often had to be adjusted, as they did not correctly account for the base model dtype. Now, those methods should always conserve the original dtype of the base model.
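
Conceptually, the dtype handling during merging now looks like this (a rough sketch of the dtype handling only; merge_delta is a made-up helper, not a PEFT function):

```python
import torch


def merge_delta(base_weight: torch.Tensor, delta_weight: torch.Tensor) -> torch.Tensor:
    # Remember the base model's dtype so that the merged weight keeps it.
    orig_dtype = base_weight.dtype
    # Compute the merge in the delta's (typically higher) precision ...
    merged = base_weight.to(delta_weight.dtype) + delta_weight
    # ... then cast back, conserving the base model's original dtype.
    return merged.to(orig_dtype)
```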

Note that if, for whatever reason, the input casting in [1] is not desired, users can use the disable_input_dtype_casting context manager to disable it (more context information on this feature can be found in PR #2353). I updated the corresponding code to be agnostic to the specific PEFT method (beforehand, it was only for LoRA). This is why I had to move the _cast_input_dtype method from LoRA to the BaseTunerLayer.
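
A usage sketch, assuming the context manager keeps the interface from PR #2353 (the checkpoint name is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting

model_id = "facebook/opt-125m"  # placeholder checkpoint
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# keep the adapter in bfloat16 too, so that no casting is needed anywhere
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), autocast_adapter_dtype=False)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello world", return_tensors="pt")

# Inside the context manager, the adapter layers receive the input as-is
# instead of upcasting it to the adapter dtype.
with disable_input_dtype_casting(model):
    output = model(**inputs)
```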

An independent bug in VeRA was discovered through the new tests: after merging, the vera_lambda_d and vera_lambda_b weights were re-assigned with an incorrect shape. This is now also fixed.

EDIT

While working on this, I made a few more changes:

  • model.merge_adapter(safe_merge=...) was not supported, even though the safe_merge argument was documented. I added the argument now (see the short usage sketch after this list).
  • LN Tuning had no support for safe_merge; this is now added (but it's a no-op, as there is no merging, just swapping parameters).
  • For some reason, CI failed when using the lewtun/tiny-random-OPTForCausalLM-delta model, so I moved to peft-internal-testing/tiny-opt-lora-revision for those tests. This is better anyway, since we want to rely less on models from private accounts.
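
A short usage sketch of the new argument (model is a PeftModel instance; safe_merge checks the merged weights for non-finite values and raises instead of silently corrupting the base weights):

```python
# Merge the active adapter(s) into the base weights, with the safety check.
model.merge_adapter(safe_merge=True)

# ... run inference on the merged model ...

# Restore the original base weights if needed.
model.unmerge_adapter()
```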


@BenjaminBossan (Member Author)

Pinging the contributors of the different methods if they could be so kind as to review the changes that affect them:

If I forgot anyone, please feel free to ping them.

If there is general agreement, my plan is to merge this shortly after the next PEFT release (v0.15.0), so aiming for next week. That way, if this does introduce regressions, we have more time to fix them.

@BenjaminBossan BenjaminBossan marked this pull request as ready for review March 18, 2025 17:14
@BenjaminBossan BenjaminBossan changed the title from "Fix multiple methods model dtype issues" to "Fix: Multiple PEFT methods have issues with models loaded in float16 or bfloat16" Mar 18, 2025
@yaswanth19 (Contributor) commented Mar 20, 2025

@BenjaminBossan AFAIK the changes seem okay for LoKr 🤗. Also, is precision loss the reason to convert the inputs to float32 rather than downcasting the adapter weights? 🤔

@nbasyl commented Mar 20, 2025

Hi @BenjaminBossan, I recently found an issue with DoRA training where the magnitude vector isn’t being updated when training on A100, although it works fine on H100. The torchtune team also found a similar issue. Could this be related to the casting issue? Also could you point me to any test cases to verify this pull request? Additionally, I recently ran into a problem using DoRA with DeepSpeed ZeRO-3. I’m not sure if this is related as well, but I’ll open an issue soon with steps to reproduce the error. Thanks a lot!

@BenjaminBossan (Member Author)

> AFAIK the changes seem okay for LoKr

Thanks for checking.

> Also, is precision loss the reason to convert the inputs to float32 rather than downcasting the adapter weights?

Exactly, although I'm not sure how much of a difference it really makes. It's just been like that for LoRA from the start and I think it makes sense to consolidate the other methods.

> I recently found an issue with DoRA training where the magnitude vector isn’t being updated when training on A100, although it works fine on H100

I haven't worked with those GPUs, so I can't say if there could be a relationship.

> The torchtune team also found a similar issue

Is it this one: meta-pytorch/torchtune#2250? If yes, and if it's indeed an underflow issue, then this PR could indeed fix it.

> Also could you point me to any test cases to verify this pull request?

The tests I added are here: https://github.com/huggingface/peft/pull/2433/files#diff-8e23036752ec7b8a68889069e72c29171c53652df5e7d390afa4d223f22c77d2

You should be able to check out my branch and then call

```
pytest tests/test_custom_models.py -k "(test_forward_bfloat16 or test_forward_float16) and dora" -v
```

to check them. Note, however, that those are very simple tests, no multi GPU etc.

> Additionally, I recently ran into a problem using DoRA with DeepSpeed ZeRO-3. I’m not sure if this is related as well, but I’ll open an issue soon with steps to reproduce the error.

Okay, feel free to ping me once the issue is done.

@nbasyl commented Mar 24, 2025

Hi @BenjaminBossan, I ran the tests on DoRA using the test file you wrote on an H100, and the following cases did not pass:

[screenshot: list of failing DoRA test cases]

I also wanted to verify that the magnitude vector is being updated correctly, so I tried modifying the function def test_training_works:

[screenshot: modified test_training_works]

I printed out the lora_A weight before and after training and got NaN:

[screenshot: NaN values in the printed weights]

Am I missing something when running pytest tests/test_custom_models.py -k "test_training_works" -v -s, or could you help me write a simple test function to check whether the magnitude vector is being updated correctly? Thank you!

@BenjaminBossan (Member Author)

> I ran the tests on DoRA using the test file you wrote on an H100, and the following cases did not pass:

Hmm, strange, does the error only occur on H100 or was that the only device you tested? Could you please also show the full error message?

> I also wanted to verify that the magnitude vector is being updated correctly, so I tried modifying the function def test_training_works:

Please note that this is a very specific test inside the TestDynamicDispatch class; it should not be used to test DoRA. I checked this test and it is indeed faulty (the lr is too high, leading to nan); I will fix it in a separate PR. But this is totally independent of DoRA.

> could you help me write a simple test function to check whether the magnitude vector is being updated correctly?

To check this, I think the test here and/or the tests below it are best suited:

```python
def test_causal_lm_training_4bit_dora(self):
```

You can run it with:

```
pytest tests/test_gpu_examples.py -k test_causal_lm_training_4bit_dora
```

Just to do a quick check, I added these lines of code:

```python
...
model = get_peft_model(model, config)
# clone the DoRA magnitude vector before training
mv_before = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_magnitude_vector["default"].weight.data.clone()

...

trainer.train()
# after training, the magnitude vector should have changed
mv_after = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_magnitude_vector["default"].weight.data
assert not torch.allclose(mv_before, mv_after, atol=1e-5, rtol=1e-5)
```

WDYT?

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 24, 2025
This specific test used a learning rate that is too high, resulting in
nan weights. Then, when weights are compared to assert that they're
different, the test passes trivially because nan != nan. The lr is now
reduced and there is a sanity check that none of the weights contain
non-finite values.

See discussion in
huggingface#2433 (comment)
ff.
Contributor

@BenjaminBossan The changes for (IA)^3 seem fine!

@Qubitium (Contributor) commented Mar 25, 2025

@BenjaminBossan I have a question regarding LoRA.

After this PR:

The model is bfloat16 and the LoRA is stored as bfloat16. PEFT would now, by default, upcast the x input to float32 and the LoRA to float32, correct?

Before this PR:

PEFT did not do any up/down casting and used the dtypes as is, by default.

Is the above a correct summary?

Also, can you do some prelim tests to see what the performance costs are for the autocast to float32 and the LoRA adapter application, if any? If there is no underflow issue for some adapter use cases, it would be good to let users know in the code/docs that autocast resolves the accuracy issue but at XX cost, so that users are a little bit more informed.

@BenjaminBossan (Member Author) commented Mar 25, 2025

> After this PR:
>
> The model is bfloat16 and the LoRA is stored as bfloat16. PEFT would now, by default, upcast the x input to float32 and the LoRA to float32, correct?
>
> Before this PR:
>
> PEFT did not do any up/down casting and used the dtypes as is, by default.
>
> Is the above a correct summary?

No, for LoRA there is no change in this PR (a few small changes around merging and MHA, but otherwise no change). The situation for LoRA has always been: by default, the LoRA weights are loaded in float32. If the input is lower precision, it is upcast to float32 when the LoRA part is calculated during forward, and afterwards downcast to the previous dtype. See these parts of the current code:

```python
def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
```

```python
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
    result = result + lora_B(lora_A(dropout(x))) * scaling
else:
    if isinstance(dropout, nn.Identity) or not self.training:
        base_result = result
    else:
        x = dropout(x)
        base_result = None
    result = result + self.lora_magnitude_vector[active_adapter](
        x,
        lora_A=lora_A,
        lora_B=lora_B,
        scaling=scaling,
        base_layer=self.get_base_layer(),
        base_result=base_result,
    )
result = result.to(torch_result_dtype)
```

This PR just makes it so that this behavior is consistent between the different layer types and PEFT methods. The reason why I worked on this is because I found many PEFT methods break when loading the model in lower dtype.

> Also, can you do some prelim tests to see what the performance costs are for the autocast to float32 and the LoRA adapter application, if any? If there is no underflow issue for some adapter use cases, it would be good to let users know in the code/docs that autocast resolves the accuracy issue but at XX cost, so that users are a little bit more informed.

I did an experiment with meta-llama/Llama-3.2-3B (unquantized, bfloat16) on a 4090 with LoRA rank 32. What I got:

| Metric | LoRA float32 | LoRA bfloat16 |
|---|---|---|
| train time | 1695 s | 1400 s |
| train loss | 0.6069 | 0.6205 |
| test accuracy | 0.4776 | 0.4685 |
| memory max | 21262 MB | 19782 MB |
| memory reserved avg | 11387 MB | 10789 MB |
| memory reserved 99th perc. | 16916 MB | 15828 MB |
| checkpoint file size | 35.0 MB | 17.5 MB |

(edit: updated values above after finishing full run)

So there are small but noticeable differences, with lower LoRA precision resulting in slightly faster training, less memory, and smaller file size, at the cost of slightly higher loss. Of course, results will vary a lot depending on the use case, and longer/bigger training runs could be more prone to under/overflow.

At the end of the day, users can easily change the behavior by passing autocast_adapter_dtype=False. Given that LoRA is the most commonly used PEFT method by far, I think it is better to make the LoRA behavior the default for the other PEFT methods than vice versa.
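
For reference, the same opt-out when loading an existing adapter with PeftModel.from_pretrained (the checkpoint name and adapter path are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
# Keep the loaded adapter weights in their stored dtype instead of float32.
model = PeftModel.from_pretrained(base, "path/to/adapter", autocast_adapter_dtype=False)
```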

Regarding the documentation, yes, we can certainly highlight this option better. I'll check if I can find a good place in the docs to put this. Edit: Added some documentation.

@Joluck (Contributor) commented Mar 26, 2025

@BenjaminBossan The structure of Bone is very simple, so it shouldn't have much impact. Do I only need to test it with pytest tests/ -k "bone"?

BenjaminBossan added a commit that referenced this pull request Mar 26, 2025
@BenjaminBossan (Member Author)

> The structure of Bone is very simple, so it shouldn't have much impact. Do I only need to test it with pytest tests/ -k "bone"?

No need for you to test, just a quick check if the code changes for Bone look right to you. We already run the tests and I also did a full training run with Bone and a bfloat16 base model and it worked.

@DTennant (Contributor)

Hi,

The changes for LN Tuning seem fine to me.
Thanks.

Note that model.merge_adapter(safe_merge=True) did not work so far: even though
the argument was documented, it was not actually there. This is now fixed.
For some reason, the model used in the config tests caused issues, so I moved to a
different one. The new one is from peft-internal, so that's better anyway.
@githubnemo (Collaborator) left a comment

LGTM, minor nitpicks

```diff
         )

-        self.get_base_layer().weight.data = orig_weight
+        self.get_base_layer().weight.data = orig_weight.to(orig_dtype)
```
Collaborator

Suggested change:

```diff
-        self.get_base_layer().weight.data = orig_weight.to(orig_dtype)
+        base_layer.weight.data = orig_weight.to(orig_dtype)
```

Comment on lines 172 to 179:

```diff
+            self.base_layer.weight.data = orig_weight.to(orig_dtype)
         else:
             if self.bone_fn == "bat":
                 delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
-                self.base_layer.weight.data += delta_weight
+                self.base_layer.weight.data += delta_weight.to(orig_dtype)
             else:
                 delta_weight = self.get_delta_weight_bone(active_adapter, self.base_layer.weight.data)
-                self.base_layer.weight.data = delta_weight
+                self.base_layer.weight.data = delta_weight.to(orig_dtype)
```
Collaborator

We can use base_layer here as well, right?

@BenjaminBossan BenjaminBossan merged commit dfd82f7 into huggingface:main Apr 4, 2025
14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-multiple-methods-model-dtype-issues branch April 4, 2025 10:06
@BenjaminBossan (Member Author)

Thanks everyone for your feedback. The PR is now merged, but the next PEFT release will not be very soon. So if you happen to find an issue, please report it so that we can fix it before the next release.

Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
…or bfloat16 (huggingface#2433)

efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
…or bfloat16 (huggingface#2433)

cyyever pushed a commit to cyyever/peft that referenced this pull request Sep 4, 2025
csqaiub added a commit to csqaiub/peft that referenced this pull request Sep 28, 2025