
Commit d0fa974

ENH Option to automatically tie modules_to_save (#2803)
Adds an option to the LoRA config, ensure_weight_tying, which, if enabled, ensures that when the embedding and the LM head are tied in the base model, they share the same ModulesToSaveWrapper, so that their weights remain correctly tied even after merging.
1 parent 2813b9c commit d0fa974
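
For context, a minimal usage sketch of the new option; the model name and target_modules below are illustrative assumptions, not something prescribed by this commit:

# Hypothetical usage sketch of ensure_weight_tying
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # a model with tie_word_embeddings=True

config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens"],  # a tied layer
    ensure_weight_tying=True,          # new option added by this commit
)
peft_model = get_peft_model(base_model, config)

# embed_tokens and lm_head now share the same ModulesToSaveWrapper copy,
# so merging (e.g. merge_and_unload) keeps the tied weights consistent.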

5 files changed: 266 additions, 5 deletions

src/peft/tuners/lora/config.py

Lines changed: 15 additions & 0 deletions
@@ -663,6 +663,17 @@ class LoraConfig(PeftConfig):
     arrow_config: Optional[ArrowConfig] = field(
         default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
     )
+    ensure_weight_tying: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to tie weights or not after peft initialization. "
+                "This will ensure that the adapters added to the tied layers "
+                "are also tied. This is only applicable for layers passed via "
+                "`modules_to_save`."
+            )
+        },
+    )

     def to_dict(self):
         """
@@ -681,6 +692,10 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+
+        if self.ensure_weight_tying:
+            self.modules_to_tie = None
+
         if isinstance(self.target_parameters, str):
             raise TypeError("`target_parameters` must be a list of strings or None.")

src/peft/tuners/lora/model.py

Lines changed: 6 additions & 0 deletions
@@ -805,3 +805,9 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap
         )

         return tensors_lora
+
+    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
+        modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
+        missing_keys = set(tied_weight_keys) - modules_to_save
+
+        peft_config.modules_to_tie = missing_keys

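For illustration, the set arithmetic in _add_modules_to_tie boils down to the following sketch (the concrete names are hypothetical):

# Hypothetical inputs mirroring the method above: the model reports that
# "lm_head" is tied, and the user only listed "embed_tokens" in modules_to_save.
tied_weight_keys = ["lm_head"]
modules_to_save = {"embed_tokens"}

# modules_to_tie collects the tied layers that are not already saved copies,
# so downstream code knows to attach (and tie) a wrapper for "lm_head" too.
modules_to_tie = set(tied_weight_keys) - modules_to_save
assert modules_to_tie == {"lm_head"}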
src/peft/tuners/tuners_utils.py

Lines changed: 82 additions & 0 deletions
@@ -704,6 +704,8 @@ def inject_adapter(
         # in a bad (half-initialized) state.
         self._check_new_adapter_config(peft_config)

+        self._check_tied_modules(model, peft_config)
+
         model_config = self.get_model_config(model)

         peft_config = self._prepare_adapter_config(peft_config, model_config)
@@ -1168,6 +1170,86 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
                     tied_target_modules.append(target_module)
         return tied_target_modules

+    def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]:
+        """
+        Get the list of modules that need to be tied.
+
+        For example: for models which have `embed_tokens` and `lm_head` as the tied keys, this function will return
+        [`lm_head`].
+
+        From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563
+        """
+        tied_weight_keys = []
+        if getattr(model, "_tied_weights_keys", None) is not None:
+            names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys]
+            tied_weight_keys.extend(names)
+        if getattr(model, "_dynamic_tied_weights_keys", None) is not None:
+            names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys]
+            tied_weight_keys.extend(names)
+        for name, submodule in model.named_children():
+            local_prefix = f"{prefix}.{name}" if prefix else name
+            tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix))
+
+        tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys]
+
+        return tied_weight_keys
+
+    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
+        """
+        This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this
+        method raises a warning, and each tuner class extending `BaseTuner` can choose to implement this.
+        """
+        msg = (
+            "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
+            "but no implementation exists to tie the adapters. "
+            "This can lead to complications, for example when merging the adapter "
+            "or converting your model to formats other than safetensors. "
+            "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
+        )
+        warnings.warn(msg)
+
+    def _check_tied_modules(self, model: nn.Module, peft_config):
+        """
+        Checks if any of the tied layers are targeted via `modules_to_save`. Updates `peft_config.modules_to_tie`
+        with any layers that need to be tied.
+        """
+        modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
+        is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)
+
+        tied_weight_keys = self._get_tied_weight_keys(model)
+
+        if getattr(peft_config, "ensure_weight_tying", False):
+            if is_embedding_to_save and tied_weight_keys:
+                self._add_modules_to_tie(peft_config, tied_weight_keys)
+
+            elif not is_embedding_to_save and tied_weight_keys:
+                warnings.warn(
+                    "You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`"
+                )

+            elif not tied_weight_keys:
+                warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")
+
+        elif is_embedding_to_save and tied_weight_keys:
+            if hasattr(peft_config, "ensure_weight_tying"):
+                msg = (
+                    "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
+                    "but `ensure_weight_tying` is not set to True. "
+                    "This can lead to complications, for example when merging the adapter "
+                    "or converting your model to formats other than safetensors. "
+                    "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
+                )
+                warnings.warn(msg)
+            else:
+                msg = (
+                    "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
+                    "but no implementation exists to tie the adapters. "
+                    "This can lead to complications, for example when merging the adapter "
+                    "or converting your model to formats other than safetensors. "
+                    "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
+                )
+                warnings.warn(msg)
+
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
         try:

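To make the traversal in _get_tied_weight_keys concrete, here is a minimal standalone sketch, assuming a toy model that declares its tied parameter the way transformers models do; the toy model is an assumption for illustration, not part of the commit:

import torch.nn as nn

class ToyLM(nn.Module):
    # transformers models expose tied parameters through this class attribute
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.lm_head = nn.Linear(8, 10, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights

model = ToyLM()
# The recursion collects the parameter key(s) and strips the trailing
# ".weight" component to obtain the module name that needs to be tied.
keys = list(getattr(model, "_tied_weights_keys", []))
module_names = [".".join(k.split(".")[:-1]) for k in keys]
assert module_names == ["lm_head"]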
src/peft/utils/other.py

Lines changed: 25 additions & 5 deletions
@@ -533,10 +533,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     # All names of layers that may contain adapter (trainable) weights
     adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

-    def __init__(self, module_to_save, adapter_name):
-        super().__init__(module_to_save, adapter_name)
+    def __init__(self, module_to_save, adapter_name, tied_module=None):
+        super().__init__(module_to_save, adapter_name, tied_module=tied_module)

-    def init_modules(self, adapter_name):
+    def init_modules(self, adapter_name, **kwargs):
         # we treat each adapter separately, so we have multiple adapters, same (copied) module for each
         self.modules_to_save = torch.nn.ModuleDict({})

@@ -560,7 +560,7 @@ def _hasattr_wrapped(self, name, modules):
     def _getattr_wrapped(self, name, modules):
         return getattr(modules["modules_to_save"][self.active_adapters[0]], name)

-    def update(self, adapter_name, **kwargs):
+    def update(self, adapter_name, tied_module=None, **kwargs):
         super().update(adapter_name)

         context_manager = nullcontext()
@@ -575,7 +575,13 @@ def update(self, adapter_name, **kwargs):

         if adapter_name not in self.modules_to_save:
             with context_manager:
-                self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
+                if tied_module:
+                    new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
+                    new_linear.weight = tied_module.weight
+
+                    self.modules_to_save[adapter_name] = new_linear
+                else:
+                    self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
             old_hook = self.modules_to_save[adapter_name]._hf_hook
@@ -1427,6 +1433,20 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             activate_adapter=activate_adapter,
         )

+    if getattr(peft_config, "modules_to_tie", None) is not None:
+        # Tie the modules if any tied layer is passed in `modules_to_save`.
+        # This should always be called after
+        # `_set_trainable` is called for `modules_to_save`.
+        tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+        _set_trainable(
+            model,
+            adapter_name,
+            inference_mode=peft_config.inference_mode,
+            module_names=getattr(peft_config, "modules_to_tie", None),
+            activate_adapter=activate_adapter,
+            tied_module=tied_module,
+        )
+
     if getattr(peft_config, "trainable_token_indices", None) is not None:
         if isinstance(peft_config.trainable_token_indices, dict):
             target_layers = peft_config.trainable_token_indices

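The tied_module branch in update() above ties the saved copy by reusing the embedding's nn.Parameter instead of deep-copying it; a minimal standalone illustration of that mechanism (not PEFT code):

import torch

embed = torch.nn.Embedding(10, 8)

# Mirror the tied_module branch: build a Linear and make its weight *be* the
# embedding's Parameter object, rather than a deep copy of it.
lm_head = torch.nn.Linear(*embed.weight.shape, bias=False)
lm_head.weight = embed.weight

assert lm_head.weight is embed.weight  # same Parameter, so the two stay in sync
with torch.no_grad():
    embed.weight.add_(1.0)             # e.g. an update applied to the embedding
assert torch.equal(lm_head.weight, embed.weight)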
tests/test_initialization.py

Lines changed: 138 additions & 0 deletions
@@ -69,6 +69,7 @@
 from peft.tuners.lora.layer import LoraLayer
 from peft.utils import infer_device
 from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
+from peft.utils.other import ModulesToSaveWrapper

 from .testing_utils import load_dataset_english_quotes, require_deterministic_for_xpu

@@ -4792,3 +4793,140 @@ def test_key_mapping_save_old_load_new_vblora(self, old_model, new_model, tmp_pa
     def test_key_mapping_save_new_load_old_vblora(self, old_model, new_model, tmp_path):
         # save the new model, load it into the old model, should work without issues (forwards compatibility)
         self.check_vblora_load_no_warning(new_model, old_model, tmp_path)
+
+
+class TestWeightTying:
+    """Test class to check the weight tying of adapters."""
+
+    torch_device = infer_device()
+
+    def get_lm_model(self, tie_weights=True):
+        # Mimicking a LM with embed_tokens and lm_head layers
+        # to test weight tying of adapters
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.embed_tokens = nn.Embedding(1000, 1000)
+                self.linear = nn.Linear(1000, 1000, bias=False)
+
+        class CausalLM(nn.Module):
+            if tie_weights:
+                _tied_weights_keys = ["lm_head.weight"]
+
+            def __init__(self):
+                super().__init__()
+                self.model = MyModule()
+                self.config = {"tie_word_embeddings": tie_weights}
+                self.lm_head = nn.Linear(1000, 1000, bias=False)
+
+                if tie_weights:
+                    self.lm_head.weight = self.model.embed_tokens.weight
+
+            def prepare_inputs_for_generation(self):
+                return
+
+            def get_input_embeddings(self):
+                return self.model.embed_tokens
+
+        return CausalLM().eval().to(self.torch_device)
+
+    def test_weight_tying_tied_model_lora(self):
+        # If weight tying is enabled and `embed_tokens`
+        # is passed as a `modules_to_save`, it needs to be ensured
+        # that lm_head is tied to the adapter added to `embed_tokens`
+
+        model = self.get_lm_model()
+
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
+            "Embed tokens and LM head types are not same"
+        )
+
+        # Validating that all model parameters are same
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tying_non_tied_model_lora(self):
+        model = self.get_lm_model(tie_weights=False)
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+        with pytest.warns(UserWarning, match="no tied modules were found in the model"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.linear"
+        )
+
+    def test_not_weight_tying_tied_model_lora(self):
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=False,
+        )
+        with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.linear"
+        )
+
+    def test_weight_tying_tied_model_no_embed_lora(self):
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+
+        with pytest.warns(UserWarning, match="no tied modules are added in `modules_to_save`"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)
+
+        # Validating that all model parameters are same
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tying_tied_model_lokr(self):
+        model = self.get_lm_model()
+
+        embed_token_config = LoKrConfig(modules_to_save=["embed_tokens"], target_modules=["linear"])
+
+        with pytest.warns(UserWarning, match="no implementation exists to tie the adapters"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.linear"
+        )
