# See the License for the specific language governing permissions and
# limitations under the License.

+import copy

import pytest
import torch
@@ -107,6 +108,90 @@ def test_get_peft_model_revision_warning(tmp_path):
    _ = get_peft_model(base_model, lora_config, revision=overwrite_revision)


+def test_load_multiple_adapters_different_modules_to_save(tmp_path):
+    # This tests the error described in #2422, where loading multiple adapters with different modules_to_save
+    # attributes failed due to a regression from #2376.
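+    # The test saves three adapters, reloads them all onto a single base model, and checks that switching the
+    # active adapter restores exactly the weights that were saved for that adapter.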
+
+    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
+
+    def peft_config(**kwargs):
+        return LoraConfig(target_modules="all-linear", **kwargs)
+
+    original_model = copy.deepcopy(model)
+
+    peft_config_0 = peft_config(modules_to_save=["0.post_attention_layernorm"])
+    peft_config_1 = peft_config(modules_to_save=["0.post_attention_layernorm"])
+    peft_config_2 = peft_config(modules_to_save=["1.post_attention_layernorm"])
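+    # adapters 0 and 1 wrap the same layer (layer 0's norm), while adapter 2 wraps a different one (layer 1's norm)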
+
+    # Save adapter 0; nothing fancy, its modules_to_save weights should be equal to the base model weights
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_0)
+    peft_model.save_pretrained(tmp_path / "adapter_0")
+
+    # Save adapter 1; its modules_to_save weights are modified randomly and should thus be unique to adapter 1
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_1)
+    peft_model.model.model.layers[0].post_attention_layernorm.weight.data = torch.rand_like(
+        peft_model.model.model.layers[0].post_attention_layernorm.weight.data
+    )
+    adapter_1_saved = peft_model.model.model.layers[0].post_attention_layernorm.weight.data.clone()
+    peft_model.save_pretrained(tmp_path / "adapter_1")
+
+    # Save adapter 2; its modules_to_save weights are modified randomly and should thus be unique to adapter 2
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_2)
+    peft_model.model.model.layers[1].post_attention_layernorm.weight.data = torch.rand_like(
+        peft_model.model.model.layers[1].post_attention_layernorm.weight.data
+    )
+    adapter_2_saved = peft_model.model.model.layers[1].post_attention_layernorm.weight.data.clone()
+    peft_model.save_pretrained(tmp_path / "adapter_2")
+
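+    # Drop the last in-memory PEFT model so the checks below can only rely on the adapters reloaded from disk.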
+    del peft_model
+
+    combined_model = PeftModel.from_pretrained(original_model, tmp_path / "adapter_0", adapter_name="adapter_0")
+    combined_model.load_adapter(tmp_path / "adapter_1", adapter_name="adapter_1")
+    combined_model.load_adapter(tmp_path / "adapter_2", adapter_name="adapter_2")
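+    # Before the regression from #2376 was fixed, the load_adapter calls above already failed because the
+    # adapters declare different modules_to_save entries (see #2422).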
+
+    # For adapter 0 we expect every modules_to_save layer mentioned in this test to be equal to the original
+    # model, since we didn't modify them for adapter 0 and only adapter 0 is active.
+    combined_model.set_adapter("adapter_0")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        original_model.model.layers[0].post_attention_layernorm.weight,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        original_model.model.layers[1].post_attention_layernorm.weight,
+    )
+
+    # For adapter 1 we expect its modules_to_save layer (0.post_attention_layernorm) to carry the modified
+    # weights, while the other modules_to_save layers mentioned above stay untouched.
+    combined_model.set_adapter("adapter_1")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        adapter_1_saved,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        original_model.model.layers[1].post_attention_layernorm.weight,
+    )
+
+    # For adapter 2 we expect its modules_to_save layer (1.post_attention_layernorm) to be modified, while the
+    # other modules_to_save weights keep their original values.
+    combined_model.set_adapter("adapter_2")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        original_model.model.layers[0].post_attention_layernorm.weight,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        adapter_2_saved,
+    )
+
+
class TestModulesToSaveAttributeAccess:
    """Test attribute access on the ModulesToSaveWrapper class.
