
Conversation

@BenjaminBossan BenjaminBossan commented Jul 9, 2025

Make it possible to inject the PEFT adapters based on a state_dict instead of the PEFT config.

See huggingface/diffusers#11874 for context.

Description

Right now, when creating a PEFT adapter like LoRA, the adapter layers are injected based on the PEFT config, most notably the entries in target_modules, but other arguments also play into this. Generally, this is a good approach, but it breaks down in some situations. For instance, in diffusers, we often deal with checkpoints that were created without PEFT/diffusers, so there is no PEFT config, only the state_dict. To load these checkpoints in diffusers, the current approach is to reverse-engineer a valid PEFT config from the keys in the state_dict.

Unfortunately, this is error-prone. Moreover, not every combination of state_dict keys can be easily expressed in a PEFT config through a combination of target_modules, exclude_modules, etc. In theory, everything can be expressed by passing target_modules=<regex_pattern>, but reverse-engineering such a regex correctly and efficiently is very hard (and thus currently not done).

This PR implements a completely different approach to injecting adapters. Instead of relying on the PEFT config to determine which layers to target, it takes the state_dict directly as the source of truth. This should make it possible to match exactly what is desired.
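
To make the idea concrete, here is a rough sketch of what state_dict-based injection boils down to; the helper names are hypothetical and this is not the actual PEFT code:

# A minimal sketch, assuming typical LoRA state_dict keys such as
# "transformer_blocks.0.attn.to_q.lora_A.weight"; helper names are made up.
def module_names_from_state_dict(state_dict):
    # drop the trailing "lora_A.weight" / "lora_B.weight" to recover the module path
    return {key.rsplit(".", 2)[0] for key in state_dict}

def inject_from_state_dict(model, state_dict, create_and_replace):
    target_names = module_names_from_state_dict(state_dict)
    for name, module in model.named_modules():
        if name in target_names:
            # create_and_replace stands in for the PEFT logic that swaps the
            # module for its LoRA counterpart (e.g. lora.Linear)
            create_and_replace(model, name, module)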

Implementation details

I took care to implement this change so that, if no state_dict is passed, exactly the same code path as before is taken. The risk of breaking anything should thus be minimal.

Technically, it is not necessary to pass the full state_dict; we are only interested in its keys. I still called the argument state_dict, since that is typically what we have at this point, but this can easily be changed.

I thought it might be a good idea, when the state_dict is used, to still check which modules would have been targeted had we used the PEFT config. The two results are then compared, and a warning is given if they differ. This lets the user see whether the PEFT config is incorrectly specified. While running some diffusers tests, I never encountered this warning, which is good. However, if we plan, for instance, to get rid of all the reverse engineering of the PEFT config in diffusers, it would make more sense not to give this warning.
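
Roughly, that check amounts to something like the following (a sketch with assumed names, not the exact implementation):

import warnings

def warn_on_target_mismatch(config_targets, state_dict_targets):
    # config_targets: module names the PEFT config would have matched
    # state_dict_targets: module names derived from the state_dict keys
    if set(config_targets) != set(state_dict_targets):
        only_in_state_dict = sorted(set(state_dict_targets) - set(config_targets))
        only_in_config = sorted(set(config_targets) - set(state_dict_targets))
        warnings.warn(
            "PEFT config and state_dict disagree on which modules to target: "
            f"only in state_dict: {only_in_state_dict}; only in config: {only_in_config}"
        )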

Caveats

When the original LoRA model was using target_parameters, injecting from state_dict will not work correctly. The problem is that the state_dict looks the same, whether the module or a parameter was targeted. Therefore, we cannot correctly determine the user's intent.

For now, what I decided to do is:

  1. Always assume that target_modules is meant, as it's the far more common occurrence.
  2. When we detect target_parameters while using state_dict for injection, we raise an error.
  3. If we don't detect this, injection might just slip through, resulting in modules being targeted (if they are valid modules) instead of parameters.
  4. Document that these two features don't work together.

I think overall, this is not too concerning, as both features are rather niche and thus unlikely to be used in conjunction.
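
To illustrate the ambiguity with made-up key names: both configs below would produce a state_dict with identical keys, so the keys alone cannot reveal the user's intent.

keys_from_target_modules = {     # e.g. LoraConfig(target_modules=["experts"])
    "block.0.experts.lora_A.weight",
    "block.0.experts.lora_B.weight",
}
keys_from_target_parameters = {  # e.g. LoraConfig(target_parameters=["experts.weight"])
    "block.0.experts.lora_A.weight",
    "block.0.experts.lora_B.weight",
}
assert keys_from_target_modules == keys_from_target_parameters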

Related changes

While working on this PR, I made a couple of related, though not strictly necessary, changes:

  • Refactor tests in test_low_level_api.py to use pytest instead of unittest
  • Add default target modules for LoHa and LoKr (just copying LoRA)
  • Most PEFT methods' model classes, like LoraModel, had an __init__ that effectively just called super() with the same arguments; I removed these __init__ methods.

Testing with Diffusers

To test this PR right now, install this branch directly. Then, inside of diffusers, change this line:

https://github.com/huggingface/diffusers/blob/754fe85cace17ae8615d53578d6d842c9e4a1bd9/src/diffusers/loaders/peft.py#L321

- inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+ inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs)

That's it. I tested the change with SD:

RUN_SLOW=1 pytest tests/lora/test_lora_layers_sd.py -v

and all tests passed. A new test based on the original issue of huggingface/diffusers#11874 should be added to confirm that this PR solves this difficult case.

Comment on lines +545 to +558
result = self._check_target_module_exists(peft_config, key)
if isinstance(result, _ExcludedModule):
    excluded_modules.append(key)
elif not result:
    unmatched_modules.append(key)
else:
    self.targeted_module_names.append(key)
    parent, target, target_name = _get_submodules(model, key)
    self._check_target_module_compatiblity(peft_config, model, target_name)
    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
    with ctx():
        self._create_and_replace(
            peft_config, adapter_name, target, target_name, parent, current_key=key
        )
Member Author

Note to reviewers: This is the exact same code as before, just indented by one level.

@BenjaminBossan
Member Author

FYI @sayakpaul it would be great if you could help with testing.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member

I got the following error:

Traceback (most recent call last):
  File "/fsx/sayak/diffusers/check_flux_lora.py", line 15, in <module>
    pipeline.load_lora_weights("glif/l0w-r3z")
  File "/fsx/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 2198, in load_lora_weights
    self.load_lora_into_transformer(
  File "/fsx/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 2274, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/fsx/sayak/diffusers/src/diffusers/loaders/peft.py", line 321, in load_lora_adapter
    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs)
  File "/fsx/sayak/peft/src/peft/mapping.py", line 82, in inject_adapter_in_model
    peft_model = tuner_cls(
  File "/fsx/sayak/peft/src/peft/tuners/lora/model.py", line 144, in __init__
    super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
  File "/fsx/sayak/peft/src/peft/tuners/tuners_utils.py", line 206, in __init__
    self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict)
  File "/fsx/sayak/peft/src/peft/tuners/tuners_utils.py", line 602, in inject_adapter
    1/0
ZeroDivisionError: division by zero

Code:

from diffusers import DiffusionPipeline
import torch 

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
pipe_kwargs = {
    "prompt": "A heavy metal rock band made up of farm animals on stage at a concert",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}
pipeline.load_lora_weights("glif/l0w-r3z")
image = pipeline(**pipe_kwargs).images[0]
image.save("chicken.png")

@BenjaminBossan
Member Author

Thx for testing. This is a good result, as this is exactly the bug you reported in huggingface/diffusers#11874, where layers are wrongly excluded because exclude_modules contains proj_out.

I removed the 1/0 and re-ran your example with 20 steps and 4bit bnb. This is what I got, does it look right?
(attached output image: output-foo)

@sayakpaul
Member

it does!

@BenjaminBossan
Member Author

LMK if you have other examples I should test, including ones that already work, so that I can ensure that they keep on working. Also, what other diffusers tests should I run?

If you think this is a good direction to take, I will continue with the PR. For the time being, even if this works, I guess we keep the old approach in diffusers in case users have old PEFT versions installed. But do you think in the future, we can completely omit the reverse engineering of target_modules, exclude_modules etc. or are they needed for other purposes?

@sayakpaul
Member

LMK if you have other examples I should test, including ones that already work, so that I can ensure that they keep on working. Also, what other diffusers tests should I run?

There's currently no diffusers test that would catch it as mentioned in huggingface/diffusers#11874. huggingface/diffusers#11874 (comment) is a good enough representation.

If you think this is a good direction to take, I will continue with the PR. For the time being, even if this works, I guess we keep the old approach in diffusers in case users have old PEFT versions installed.

We mandate peft==0.15.0 already:
https://github.com/huggingface/diffusers/blob/2d3d376bc00a11afb9e3c3e51a27bb2ef2f28b11/setup.py#L119

So, in this case, I won't mind lifting that up to the latest release that will contain this change. It will help us get rid of nasty issues anyway.

But do you think in the future, we can completely omit the reverse engineering of target_modules, exclude_modules etc. or are they needed for other purposes?

If we can derive it fully from the state_dict (i.e., PEFT takes care of it internally), then yes, why not! Happy to work with you to facilitate that from diffusers.

I will review the PR tomorrow before going for my time-off starting 11th July.

@sayakpaul sayakpaul left a comment (Member)

Thanks a lot for continuing to give your attention to the diffusers community and fixing the nasty stuff. PR looks very reasonable.

existing_adapter_map[key] = module

# TODO: check if this the most robust way
state_dict_keys = {k.rsplit(".", 2)[0] for k in state_dict} if state_dict is not None else set()
Member

Are these state dict keys or module keys? I think it's the latter no? For example, foo.bar.weight. Here the module and submodule are foo and bar. If we were to just obtain state dict keys, a simple state_dict.keys() would have sufficed. So, I think we should consider renaming it.
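
For reference, this is what the rsplit yields for a typical LoRA key (the key name here is illustrative):

key = "single_transformer_blocks.0.attn.to_q.lora_A.weight"
key.rsplit(".", 2)[0]  # -> "single_transformer_blocks.0.attn.to_q", i.e. the module name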

with ctx():
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
# use the state_dict to match modules instead
if key not in state_dict_keys:
Member

In case state_dict is None, these keys will be empty. Just flagging.

Member Author

We check above if state_dict is None or not, so this is safe here.

Otherwise, users would get the error:

> ValueError: Model was initialised to not save vera_A and vera_B but
config now specifies to save projection! Set `config.save_projection` to
`False`.

which could be confusing.
They simply called super().__init__(*args, **kwargs) so there was no
point in having them.
@BenjaminBossan
Copy link
Member Author

@sayakpaul I updated the PR to be in a pretty mature state. I believe as a next step, we could combine it with the necessary changes in diffusers and do some testing to confirm that the new approach works. If for some reason it is not viable, I don't think we need to proceed with the PR, as I'm not aware of any need for this outside of diffusers.

But there is no hurry, we can work on this after your return.

@sayakpaul
Copy link
Member

Thanks, @BenjaminBossan. I will introduce minimal changes on the diffusers end as a PoC to test this PR. I don't see any reason not to graduate this PR into a merge, honestly.

kkb-code and others added 6 commits July 28, 2025 11:58
When using prompt learning methods, modules_to_save was not correctly
set automatically. This is really bad when using, for instance, sequence
classification tasks, which require the classifier layer to be added to
modules_to_save.

The issue was introduced in huggingface#2220 where it is wrongly assumed that the
PEFT config always has a modules_to_save attribute, which is not true
for prompt learning. In huggingface#2481, this was partly fixed by using getattr to
avoid an error. However, this did not resolve the fundamental issue that
for prompt learning, there is no such attribute, resulting in
modules_to_save not being applied.

This PR proposes to fix this by adding modules_to_save to the prompt
learning configs.
Fixing the error:

GitHub Actions / Deploy "method_comparison" Gradio to Spaces: Invalid workflow file
Check failure on line 11 in .github/workflows/deploy_method_comparison_app.yml
The workflow is not valid.
.github/workflows/deploy_method_comparison_app.yml (Line: 11, Col: 13): A mapping was not expected

The flagged lines in the workflow file:

permissions:
  contents: {}

Normally, nn.Parameter cannot be targeted with LoRA adapters. This can
be problematic, e.g. when there are MoE layers that use nn.Parameter
directly, or when there is nn.Linear but the weight is passed directly
instead of calling forward (e.g. MHA).

It would be possible to craft a solution involving a special LoRA layer
for each of the modules that use nn.Parameter directly (e.g. lora.MHA)
but that doesn't scale. This PR implements a direct way to target
nn.Parameter, making use of torch.nn.utils.parametrize.

Using the feature requires passing target_parameters to the LoraConfig.
During the forward pass, when the parameter is accessed, the LoRA
weights are added to the weights while still ensuring that gradients
flow correctly to the LoRA weights.

Right now, only LoRA supports this feature. Moreover, it is not possible
to target multiple parameters of the same module with the same adapter.
A workaround is to use multiple adapters (i.e. with different names).

---------

Co-authored-by: githubnemo <[email protected]>
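
A minimal sketch of the parametrize-based mechanism described above (illustrative only, not the actual PEFT ParamWrapper; names, ranks, and shapes are made up):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoraParametrization(nn.Module):
    def __init__(self, fan_out, fan_in, r=8, scaling=1.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(r, fan_in))
        self.lora_B = nn.Parameter(torch.zeros(fan_out, r))  # zero init => no-op at start
        nn.init.normal_(self.lora_A, std=0.02)
        self.scaling = scaling

    def forward(self, weight):
        # called every time the parametrized parameter is accessed
        return weight + self.scaling * (self.lora_B @ self.lora_A)

linear = nn.Linear(16, 32)
linear.weight.requires_grad_(False)  # keep the base weight frozen
parametrize.register_parametrization(linear, "weight", LoraParametrization(fan_out=32, fan_in=16))
out = linear(torch.randn(4, 16))  # accessing .weight adds the low-rank delta; gradients flow to lora_A/lora_B
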
Due to huggingface/transformers#38635, several
tests involving prefix tuning broke:

https://github.com/huggingface/peft/actions/runs/16417140904/job/46385751329

This PR fixes this by resolving two issues:

1. The _supports_cache_class attribute was removed, we can now assume
that it is True if the attribute does not exist.

2. We had special handling of past_key_values for GPTBigCodeForCausalLM
which is no longer required (nor valid) after that PR, so it is removed
depending on the transformers version.
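
For point 1, the version-tolerant check can be as simple as the following (attribute name taken from the description above; the surrounding code is assumed):

supports_cache_class = getattr(model, "_supports_cache_class", True)  # missing attribute -> assume True
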
yao-matrix and others added 7 commits July 28, 2025 11:58
…lerators like XPU (huggingface#2610)

make method comparison device agnostic, so it can expand to more
accelerators like XPU

---------

Signed-off-by: YAO Matrix <[email protected]>
…ggingface#2664)

* REFAC Update tokenizer parameter to processing_class in SFTTrainer instances across multiple examples

* REFAC Replace tokenizer parameter with processing_class in Trainer instances across documentation and examples

* Refactor tokenizer parameter to processing_class in various examples

- Updated the Trainer initialization in corda_finetuning.py to use processing_class instead of tokenizer.
- Changed the execution_count to null in image_classification_peft_lora.ipynb.
- Modified the tokenizer parameter to processing_class in image_classification_peft_lora.ipynb.
- Adjusted the tokenizer parameter to processing_class in peft_bnb_whisper_large_v2_training.ipynb.
- Updated the README.md in lorafa_finetune to reflect the change from tokenizer to processing_class in Trainer initialization.

* REFAC Update tokenizer parameter to processing_class in Seq2SeqTrainer instantiation

* REFAC Replace tokenizer parameter with processing_class in README and notebook examples
* Method Comparison: Improve formatting/layout of table

Quick improvement to reduce the dominance of columns like `{peft,train}_config` and make
numbers a bit more readable through proper decimal/thousands formatting.

* Bump gradio version to accommodate required fixes
When the target_parameters feature for LoRA was introduced in huggingface#2638,
there was one gap, namely the possibility to target multiple
nn.Parameters on the same module (there was only a workaround involving
multiple adapters, but that is not user friendly). With this PR, it is
now possible to achieve this.

The mechanism to enable this is a bit crude, namely allowing the nesting of
multiple ParamWrappers. This should generally be fine as long as there
are only a couple of nn.Parameters being targeted on the same module.
When there are dozens or hundreds, this approach could load to slow
downs or other issues.

A side effect of this implementation is that the ParamWrapper, when it
removes the parametrization, now only removes its own parametrization.
When using nn.utils.parametrize.remove_parametrization, it removes all
parametrizations, which is bad when we have nested parametrizations.

Alternative approaches

Some alternative approaches were discussed internally but the chosen one
was considered most practical.

  1. Allow more than one adapted parameter per LoRA layer. This would
     require nested dicts for the LoRA parameters, something like
     self.lora_A[adapter_name][parameter_name]. We don't have this anywhere
     so far, and it would probably break implicit assumptions about PEFT
     layers in many places (like parsing of state_dict keys), requiring
     many adjustments.
  2. Have an auxiliary module that contains the individual LoRA layers
     that target the individual parameters. This could be the cleanest
     solution and would probably be more efficient if there are a huge
     number of targeted parameters per module. However, it also brings
     extra complexity, as it requires implementing the logic of how to
     route the information to the right parameter, and it may be a
     solution to a problem that is irrelevant in practice (a large number
     of targets per module).
- Recommends trainable tokens as first measure
- Clarifies a few things about saving embeddings
- Adds full-finetuning as an option of last resort

---------

Co-authored-by: Benjamin Bossan <[email protected]>

@BenjaminBossan BenjaminBossan marked this pull request as ready for review July 30, 2025 16:18
@githubnemo githubnemo left a comment (Collaborator)

LGTM in general, a few questions/comments

@BenjaminBossan BenjaminBossan left a comment (Member Author)

Thanks for the feedback @githubnemo, your comments should be addressed.

While working on this, I noticed one flaw: When the original LoRA model was using target_parameters, injecting from state_dict will not work correctly. The problem is that the state_dict looks the same, whether the module or a parameter was targeted. Therefore, we cannot correctly determine the user's intent.

For now, what I decided to do is:

  1. Always assume that target_modules is meant, as it's the far more common occurrence.
  2. When we detect target_parameters while using state_dict for injection, we raise an error.
  3. If we don't detect this, injection might just slip through, resulting in modules being targeted (if they are valid modules) instead of parameters.
  4. Document that these two features don't work together.

I think overall, this is not too concerning, as both features are rather niche and thus unlikely to be used in conjunction.

@githubnemo githubnemo left a comment (Collaborator)

LGTM! The way we deal with target parameters is fine :)

@BenjaminBossan BenjaminBossan merged commit 337be05 into huggingface:main Aug 1, 2025
11 of 14 checks passed
@BenjaminBossan BenjaminBossan deleted the enh-inject-adapter-based-on-state_dict branch August 1, 2025 16:39