
Commit 1c853ea

Fix trainable tokens with fsdp (#2681)
When using FSDP with trainable tokens, there was an error when retrieving the `state_dict` of the `TrainableTokensWrapper`. The reason is that the `state_dict` passed to `get_peft_model_state_dict` comes from the already unwrapped model, so its keys don't have the FSDP-specific prefix, whereas the module names used for the lookup still carry it. Since the PEFT code did not remove that prefix before looking up keys in said `state_dict`, the lookup failed. Now the prefix is removed, making the lookup succeed. The same logic applies to `set_peft_model_state_dict`.

I could successfully start training with FSDP and trainable tokens locally by adjusting the examples/sft script to include trainable tokens. Checkpoints could be successfully created and resumed from. The only change I needed to make was to configure `use_orig_params=True` for FSDP.
1 parent c11a9df commit 1c853ea
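
The following is a minimal, self-contained sketch (plain Python, no PEFT or FSDP needed) of the mismatch this commit fixes: under FSDP, module names reported by `named_modules()` on the wrapped model carry the `_fsdp_wrapped_module.` prefix, while the keys of the already unwrapped `state_dict` do not, so a lookup only succeeds once the prefix is stripped. The module name and state_dict key below are invented for illustration.

# Invented example names; actual keys depend on the model and the PEFT wrapper.
wrapped_name = "_fsdp_wrapped_module.model.embed_tokens"                 # as reported by model.named_modules()
state_dict_keys = {"model.embed_tokens.trainable_tokens_delta.default"}  # keys of the unwrapped model's state_dict
suffix = "trainable_tokens_delta.default"

# Building the lookup key from the FSDP-prefixed name misses ...
assert f"{wrapped_name}.{suffix}" not in state_dict_keys

# ... while stripping the FSDP-specific prefix first, as the commit does, makes the lookup succeed.
clean_name = wrapped_name.removeprefix("_fsdp_wrapped_module.")  # str.removeprefix requires Python >= 3.9
assert f"{clean_name}.{suffix}" in state_dict_keys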

File tree

2 files changed: +11 −4 lines changed

src/peft/tuners/lora/config.py
src/peft/utils/save_and_load.py


src/peft/tuners/lora/config.py

Lines changed: 3 additions & 4 deletions
@@ -283,7 +283,7 @@ class LoraConfig(PeftConfig):
             Either you specify a list of indices which will then target the model's input embedding layer (or, if not
             found, `embed_tokens`). Alternatively, you can specify a dictionary where the key is the name of the
             embedding module and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note
-            that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled.
+            that training with FSDP requires `use_orig_params=True` to avoid issues with non-uniform `requires_grad`.
         loftq_config (`Optional[LoftQConfig]`):
             The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
             and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
@@ -465,9 +465,8 @@ class LoraConfig(PeftConfig):
                 "in two ways. Either you specify a list of indices which will then target the model's input embedding "
                 "layer (or, if not found, `embed_tokens`). Alternatively, you can specify a dictionary where the key "
                 "is the name of the embedding module and the values are the list of token indices, e.g. "
-                "`{'embed_tokens': [0, 1, ...]}`. "
-                "Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. "
-                "Also note that models using weight-tying are currently not supported."
+                "`{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP requires `use_orig_params=True` to "
+                "avoid issues with non-uniform `requires_grad`."
             )
         },
     )
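
As a usage note for the docstring change above, here is a rough sketch (not the commit's examples/sft script) of combining `trainable_token_indices` with FSDP and `use_orig_params=True`. It assumes a distributed process group has already been initialized (e.g. via `torchrun`); the model name and target modules are placeholders.

# Rough sketch; assumes torch.distributed has been initialized (e.g. launched via torchrun).
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    # Train a few selected token embeddings (e.g. newly added tokens) alongside the LoRA weights.
    trainable_token_indices={"embed_tokens": [0, 1, 2]},
)
peft_model = get_peft_model(base_model, config)

# Only some embedding rows are trainable, so requires_grad is non-uniform within FSDP's
# flattened parameters; use_orig_params=True lets FSDP handle that.
fsdp_model = FSDP(peft_model, use_orig_params=True)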

src/peft/utils/save_and_load.py

Lines changed: 8 additions & 0 deletions
@@ -219,6 +219,10 @@ def renamed_dora_weights(k):
     # ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
     for name, module in model.named_modules():
         if isinstance(module, AuxiliaryTrainingWrapper):
+            if name.startswith("_fsdp_wrapped_module."):
+                # If FSDP is used, the state_dict is from the unwrapped model, which will result in a key mismatch if we
+                # don't remove the FSDP-specific prefix
+                name = name.removeprefix("_fsdp_wrapped_module.")
             # Compute the module-relative state dict to make it easier for the adapter to fetch the appropriate
             # keys that the module thinks need to be saved. We cannot rely on `.state_dict()` internally of the
             # module since accelerators like DeepSpeed require special handling which is done for the model
@@ -381,6 +385,10 @@ def set_peft_model_state_dict(
             # `modules_to_save.{adapter_name}.` prefix. This prefix must be restored when loading the model from the
             # saved state dict which is why we fetch a load key map from the wrapper.
             key_map = module.adapter_state_dict_load_map(adapter_name)
+            if name.startswith("_fsdp_wrapped_module."):
+                # If FSDP is used, the state_dict is from the unwrapped model, which will result in a key mismatch if we
+                # don't remove the FSDP-specific prefix
+                name = name.removeprefix("_fsdp_wrapped_module.")
             for k in key_map:
                 lookup_key = f"{name}.{k}"
                 store_key = f"{name}.{key_map[k]}"
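
For context on the loading path touched above, this small illustration shows how the cleaned module name is combined with the wrapper's load key map; the map contents below are invented, while in PEFT they come from `module.adapter_state_dict_load_map(adapter_name)`.

# Invented key map; the real one is returned by the AuxiliaryTrainingWrapper.
name = "_fsdp_wrapped_module.model.embed_tokens".removeprefix("_fsdp_wrapped_module.")
key_map = {"trainable_tokens_delta": "token_adapter.trainable_tokens_delta.default"}

for k, mapped in key_map.items():
    lookup_key = f"{name}.{k}"       # key as it appears in the saved adapter state_dict
    store_key = f"{name}.{mapped}"   # key under which the weight lives in the wrapped module
    print(lookup_key, "->", store_key)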
