 from peft.tuners.lora.layer import LoraLayer
 from peft.utils import infer_device
 from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
+from peft.utils.other import ModulesToSaveWrapper

 from .testing_utils import load_dataset_english_quotes, require_deterministic_for_xpu

@@ -4792,3 +4793,140 @@ def test_key_mapping_save_old_load_new_vblora(self, old_model, new_model, tmp_path):
     def test_key_mapping_save_new_load_old_vblora(self, old_model, new_model, tmp_path):
         # save the new model, load it into the old model, should work without issues (forwards compatibility)
         self.check_vblora_load_no_warning(new_model, old_model, tmp_path)
+
+
+class TestWeightTying:
+    """Test class to check the weight tying of adapters."""
+
+    torch_device = infer_device()
+
+    def get_lm_model(self, tie_weights=True):
+        # Mimicking a LM with embed_tokens and lm_head layers
+        # to test weight tying of adapters
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.embed_tokens = nn.Embedding(1000, 1000)
+                self.linear = nn.Linear(1000, 1000, bias=False)
+
+        class CausalLM(nn.Module):
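+            # `_tied_weights_keys` mimics the attribute transformers models use
+            # to declare tied weights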
+            if tie_weights:
+                _tied_weights_keys = ["lm_head.weight"]
+
+            def __init__(self):
+                super().__init__()
+                self.model = MyModule()
+                self.config = {"tie_word_embeddings": tie_weights}
+                self.lm_head = nn.Linear(1000, 1000, bias=False)
+
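+                # Tie the actual weights, as transformers does when
+                # `tie_word_embeddings` is True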
+                if tie_weights:
+                    self.lm_head.weight = self.model.embed_tokens.weight
+
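+            # Minimal stubs so that the toy model exposes the causal LM interface
+            # that PEFT expects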
+            def prepare_inputs_for_generation(self):
+                return
+
+            def get_input_embeddings(self):
+                return self.model.embed_tokens
+
+        return CausalLM().eval().to(self.torch_device)
+
+    def test_weight_tying_tied_model_lora(self):
+        # If weight tying is enabled and `embed_tokens` is listed in
+        # `modules_to_save`, `lm_head` must end up tied to the adapter
+        # that was added to `embed_tokens`
+
+        model = self.get_lm_model()
+
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "embed_tokens was not wrapped in ModulesToSaveWrapper"
+        )
+        assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
+            "embed_tokens and lm_head are not wrapped in the same type"
+        )
+
+        # Validate that embed_tokens and lm_head share the exact same parameters
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tying_non_tied_model_lora(self):
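+        # The model is not tied at all, so PEFT should warn that there is nothing to tie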
+        model = self.get_lm_model(tie_weights=False)
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+        with pytest.warns(UserWarning, match="no tied modules were found in the model"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "embed_tokens was not wrapped in ModulesToSaveWrapper"
+        )
+        assert isinstance(model.base_model.model.lm_head, nn.Linear), (
+            "lm_head is not a plain nn.Linear"
+        )
+
+    def test_not_weight_tying_tied_model_lora(self):
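+        # The model is tied, but `ensure_weight_tying` is False, so PEFT should warn
+        # and leave lm_head untouched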
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=False,
+        )
+        with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "embed_tokens was not wrapped in ModulesToSaveWrapper"
+        )
+        assert isinstance(model.base_model.model.lm_head, nn.Linear), (
+            "lm_head is not a plain nn.Linear"
+        )
+
+    def test_weight_tying_tied_model_no_embed_lora(self):
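+        # The model is tied, but `embed_tokens` is not in `modules_to_save`,
+        # so there is nothing for PEFT to tie and it should only warn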
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+
+        with pytest.warns(UserWarning, match="no tied modules are added in `modules_to_save`"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, nn.Embedding)
+        assert isinstance(model.base_model.model.lm_head, nn.Linear)
+
+        # Validate that the original weight tying between embed_tokens and lm_head is intact
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tying_tied_model_lokr(self):
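+        # LoKr has no implementation for tying the adapters, so PEFT should only warn
+        # and leave lm_head untouched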
+        model = self.get_lm_model()
+
+        embed_token_config = LoKrConfig(modules_to_save=["embed_tokens"], target_modules=["linear"])
+
+        with pytest.warns(UserWarning, match="no implementation exists to tie the adapters"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "embed_tokens was not wrapped in ModulesToSaveWrapper"
+        )
+        assert isinstance(model.base_model.model.lm_head, nn.Linear), (
+            "lm_head is not a plain nn.Linear"
+        )