ENH: Adapter injection based on state_dict (#2637)
Make it possible to inject the PEFT adapters based on a state_dict
instead of the PEFT config.
See huggingface/diffusers#11874 for context.
Description
Right now, when creating a PEFT adapter like LoRA, the adapter layers
are injected based on the PEFT config, most notably the entries in
`target_modules`, but other arguments also play into this. Generally,
this is a good approach, but it breaks down in some situations. For
instance, in diffusers, we often encounter checkpoints that were created without PEFT/diffusers, so there is no PEFT config, only the `state_dict`. To load these checkpoints in diffusers, the current
approach is to reverse-engineer a valid PEFT config based on the keys in
the `state_dict`.
Unfortunately, this is error-prone. Moreover, not every combination of `state_dict` keys can easily be expressed in a PEFT config through a combination of `target_modules`, `exclude_modules`, etc. In theory, everything can be expressed by passing `target_modules=<regex_pattern>`, but reverse-engineering such a regex correctly and efficiently is very hard (and thus currently not done).
This PR implements a completely different approach to injecting adapters. Instead of relying on the PEFT config to determine which layers to target, it takes the `state_dict` directly as the source of truth. This should make it possible to match exactly what is desired.
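To make the proposed usage concrete, here is a minimal sketch based on the documentation change further below (the base model and the checkpoint path are placeholders, not part of this PR):

```python
from safetensors.torch import load_file

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

model = ...  # the base torch.nn.Module that the LoRA checkpoint was trained for
state_dict = load_file("adapter_model.safetensors")  # placeholder path to the checkpoint

# The state_dict keys, not the config, determine which layers get adapters injected,
# so target_modules can be left at its default.
lora_config = LoraConfig()
model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)

# Injection only creates uninitialized adapter layers; the actual weights are loaded separately.
set_peft_model_state_dict(model, state_dict)
```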
Implementation details
I took care to implement this change so that if no `state_dict` is passed, exactly the same code path as before is taken. The risk of breaking anything should thus be minimal.
Technically, it is not necessary to pass the full `state_dict`; we are only interested in its keys. I still called the argument `state_dict`, since that is typically what we have at this point, but this can easily be changed.
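As a hypothetical illustration of that point (reusing the names from the sketch above): since only the keys are consulted during injection, a dict with dummy values should select the same layers, though the real values are of course still needed later to populate the weights.

```python
import torch

# Hypothetical, per the description above: injection only looks at the keys,
# so the tensor values passed here are irrelevant for deciding what to inject.
keys_only = {key: torch.empty(0) for key in state_dict}
model = inject_adapter_in_model(lora_config, model, state_dict=keys_only)
```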
I thought it might be a good idea, when the `state_dict` is used, to still check which modules would have been targeted had we used the PEFT config. The two results are then compared, and a warning is given if they differ. This lets the user see whether the PEFT config is not correctly specified. While running some diffusers tests, I never encountered this warning, which is good. However, if we plan, for instance, to get rid of all the reverse-engineering of the PEFT config in diffusers, it would make more sense not to give this warning.
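For anyone who wants to surface that comparison explicitly, a small sketch using standard `warnings` handling (the exact warning category and message are whatever PEFT emits; nothing here is specific to this PR's internals):

```python
import warnings

# Capture any warnings raised during injection, e.g. a mismatch between what the
# PEFT config would have targeted and what the state_dict keys actually contain.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)

for w in caught:
    print(f"{w.category.__name__}: {w.message}")
```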
Caveats
When the original LoRA model was using `target_parameters`, injecting from a `state_dict` will not work correctly. The problem is that the `state_dict` looks the same whether a module or a parameter was targeted. Therefore, we cannot correctly determine the user's intent.
For now, what I decided to do is:
1. Always assume that `target_modules` is meant, as it's the far more
common occurrence.
2. When we detect `target_parameters` while using `state_dict` for
injection, we raise an error.
3. If we don't detect this, injection might just slip through, resulting
in modules being targeted (if they are valid modules) instead of
parameters.
4. Document that these two features don't work together.
I think overall, this is not too concerning, as both features are rather
niche and thus unlikely to be used in conjunction.
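To illustrate the ambiguity described above, here is a hypothetical sketch (the module and parameter names are made up): per this PR's description, the adapter `state_dict` produced by either config below looks the same, so the keys alone cannot reveal which of the two was intended.

```python
from peft import LoraConfig

# Targeting a module: LoRA wraps the whole nn.Module named "experts".
config_modules = LoraConfig(target_modules=["experts"])

# Targeting a parameter: LoRA targets an nn.Parameter inside that module.
config_params = LoraConfig(target_parameters=["experts.weight"])

# Both yield state_dict keys of the same form, so state_dict-based injection
# assumes target_modules and raises an error if target_parameters is detected.
```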
Related changes
While working on this PR, I made a couple of related, though not
strictly necessary, changes:
- Refactor tests in `test_low_level_api.py` to use pytest instead of
unittest
- Add default target modules for LoHa and LoKr (just copying LoRA)
- Most PEFT methods' model classes, like `LoraModel`, had an `__init__` that effectively just called `super()` with the same arguments. I removed these `__init__` methods (see the sketch below).
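As a rough sketch of the last point (the exact signatures varied per method), the removed methods looked approximately like this and added no behavior over the parent class:

```python
from peft.tuners.tuners_utils import BaseTuner

class LoraModel(BaseTuner):
    # An __init__ like this only forwarded its arguments, so it can be dropped
    # in favor of inheriting BaseTuner.__init__ directly.
    def __init__(self, model, config, adapter_name) -> None:
        super().__init__(model, config, adapter_name)
```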
Changed file: docs/source/developer_guides/low_level_api.md (+23 −1)
@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.

# Adapter injection

-With PEFT, you can inject trainable adapters into any `torch` module which allows you to use adapter methods without relying on the modeling classes in PEFT. Currently, PEFT supports injecting [LoRA](../conceptual_guides/adapter#low-rank-adaptation-lora), [AdaLoRA](../conceptual_guides/adapter#adaptive-low-rank-adaptation-adalora), and [IA3](../conceptual_guides/ia3) into models because for these adapters, inplace modification of the model is sufficient for finetuning it.
+With PEFT, you can inject trainable adapters into any `torch` module which allows you to use adapter methods without relying on the modeling classes in PEFT. This works for all adapters except for those based on prompt learning (e.g. prefix tuning or p-tuning).

Check the table below to see when you should inject adapters.
@@ -87,6 +87,28 @@ DummyModel(
)
```

+### Injection based on a `state_dict`
+
+Sometimes, it is possible that there is a PEFT adapter checkpoint but the corresponding PEFT config is not known for whatever reason. To inject the PEFT layers for this checkpoint, you would usually have to reverse-engineer the corresponding PEFT config, most notably the `target_modules` argument, based on the `state_dict` from the checkpoint. This can be cumbersome and error-prone. To avoid this, it is also possible to call [`inject_adapter_in_model`] and pass the loaded `state_dict` as an argument:
+
+```python
+model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)
+```
+
+In this case, PEFT will use the `state_dict` as the reference for which layers to target instead of using the PEFT config. As a user, you don't have to set the exact `target_modules` of the PEFT config for this to work. However, you should still pass a PEFT config of the right type, in this example `LoraConfig`; you can leave the `target_modules` as `None`.
+
+Be aware that this still only creates the uninitialized PEFT layers; the values from the `state_dict` are not used to populate the model weights. To populate the weights, proceed with calling [`set_peft_model_state_dict`] as described below.
+
+⚠️ Note that if there is a mismatch between what is configured in the PEFT config and what is found in the `state_dict`, PEFT will warn you about this. You can ignore the warning if you know that the PEFT config is not correctly specified.
+
+> [!WARNING]
+> If the original PEFT adapter was using `target_parameters` instead of `target_modules`, injecting from a `state_dict` will not work correctly. In this case, it is mandatory to use the correct PEFT config for injection.
+
## Saving the model

To only save the adapter, use the [`get_peft_model_state_dict`] function: