
Commit 48e0c5d

githubnemo, saeid93, and BenjaminBossan authored
Fix huggingface#2422: Modules to save with multiple adapters (huggingface#2430)
Using multiple adapters with different `modules_to_save` values leads to a scenario where it is implicitly assumed that each `ModulesToSaveWrapper` has a module for every loaded adapter. Since the adapters have different `modules_to_save` values, this is not the case, and retrieving the state dict fails with a key lookup error. In addition, after disabling a `ModulesToSaveWrapper`, setting the adapter as active does not re-enable it.

Co-authored-by: Saeid Ghafouri <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent b2b34fd commit 48e0c5d
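For context, a minimal sketch of the scenario described in the commit message (not part of the commit itself; it reuses the tiny test model and layer names from the test added in tests/test_other.py below, and the adapter save paths are illustrative):

import copy
import tempfile

from transformers import AutoModelForCausalLM

from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
tmp = tempfile.mkdtemp()

# Two adapters whose modules_to_save entries point at different layers.
cfg_a = LoraConfig(target_modules="all-linear", modules_to_save=["0.post_attention_layernorm"])
cfg_b = LoraConfig(target_modules="all-linear", modules_to_save=["1.post_attention_layernorm"])

get_peft_model(copy.deepcopy(base), cfg_a).save_pretrained(f"{tmp}/adapter_a")
get_peft_model(copy.deepcopy(base), cfg_b).save_pretrained(f"{tmp}/adapter_b")

model = PeftModel.from_pretrained(base, f"{tmp}/adapter_a", adapter_name="adapter_a")
# Before this commit, combining adapters like this could fail with a key lookup error, because each
# ModulesToSaveWrapper was implicitly assumed to hold a saved module for every loaded adapter.
model.load_adapter(f"{tmp}/adapter_b", adapter_name="adapter_b")
# The commit also makes set_adapter re-enable a wrapper that was previously disabled.
model.set_adapter("adapter_b")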

File tree

4 files changed: +100 −3 lines changed


src/peft/tuners/tuners_utils.py

Lines changed: 7 additions & 0 deletions
@@ -473,6 +473,13 @@ def inject_adapter(
             if not key:
                 continue
             # Check for modules_to_save in case
+            #
+            # Note that this is redundant with PeftModel.set_additional_trainable_models but might be necessary
+            # when calling inject_adapter without a PEFT model. This is outdated as it only focuses on
+            # ModulesToSaveWrapper and ignores other potentially configured AuxiliaryTrainingWrapper instances.
+            #
+            # TODO: determine if there's a good reason for this and refactor to support AuxiliaryTrainingWrapper,
+            # or remove if superfluous.
             if _check_for_modules_to_save and any(
                 key.endswith(module_to_save) for module_to_save in peft_config.modules_to_save
             ):

src/peft/utils/other.py

Lines changed: 10 additions & 0 deletions
@@ -499,6 +499,10 @@ def update(self, adapter_name, **kwargs):
             add_hook_to_module(self.modules_to_save[adapter_name], new_hook)

         self.original_module.requires_grad_(False)
+
+        # note that there currently cannot be more than one active adapter for the same layer with modules to save
+        # since there would be no clear way to decide which adapter's weights are the correct ones. therefore we
+        # assume that there is only one active adapter. this precondition is enforced by _set_adapter.
         if adapter_name == self.active_adapter:
             self.modules_to_save[adapter_name].requires_grad_(True)

@@ -550,6 +554,10 @@ def adapter_state_dict_load_map(self, adapter_name):
         return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.adapter_state_dict(adapter_name)}

     def adapter_state_dict(self, adapter_name):
+        if adapter_name not in self._adapters:
+            # In case of multiple adapters, each bringing their own modules to save, each
+            # ModulesToSaveWrapper will be queried but not every wrapper is obliged to serve the same adapters.
+            return {}
         return self.modules_to_save[adapter_name].state_dict()

     def unload_and_optionally_merge_module(
@@ -732,6 +740,7 @@ def _set_trainable(
     found_modules = set()
     # disable removal of duplicates to support targeting tied weights
     key_list = [key for key, _ in model.named_modules(remove_duplicate=False)]
+
     for key in key_list:
         target_module_found = any(key.endswith(target_key) for target_key in module_names)
         if target_module_found:
@@ -776,6 +785,7 @@ def check_adapter_name(adapter_name):
            # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
            # module
            if adapter_name in module._adapters:
+               module.enable_adapters(True)
                module.set_adapter(adapter_name)
            else:
                module.enable_adapters(False)

tests/test_custom_models.py

Lines changed: 4 additions & 3 deletions
@@ -1236,6 +1236,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         outputs_base = model(**X)
         if issubclass(config_cls, (FourierFTConfig, TrainableTokensConfig)):
             config_kwargs = config_kwargs.copy()
+            # override the default value and make PEFT operation a no-op
             config_kwargs["init_weights"] = True
         config = config_cls(
             base_model_name_or_path=model_id,
@@ -1255,9 +1256,9 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         model.train()
         # EmbConv1D is slow to learn for some reason
         lr = 0.01 if model_id != "EmbConv1D" else 1.0
-        if isinstance(config_cls, LNTuningConfig):
-            # LayerNorm tuning is slow to learn
-            lr = 1.0
+        if isinstance(config, TrainableTokensConfig):
+            # TrainableTokens is only changing a small subset, so we need a higher lr to see the difference
+            lr = 2.0
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry

tests/test_other.py

Lines changed: 79 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy

 import pytest
 import torch
@@ -107,6 +108,84 @@ def test_get_peft_model_revision_warning(tmp_path):
     _ = get_peft_model(base_model, lora_config, revision=overwrite_revision)


+def test_load_multiple_adapters_different_modules_to_save(tmp_path):
+    # This tests the error described in #2422 where loading multiple adapters with different modules_to_save
+    # attributes fails (due to a regression from #2376).
+
+    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
+
+    def peft_config(**kwargs):
+        return LoraConfig(target_modules="all-linear", **kwargs)
+
+    original_model = copy.deepcopy(model)
+
+    peft_config_0 = peft_config(modules_to_save=["0.post_attention_layernorm"])
+    peft_config_1 = peft_config(modules_to_save=["0.post_attention_layernorm"])
+    peft_config_2 = peft_config(modules_to_save=["1.post_attention_layernorm"])
+
+    # Save adapter 0, nothing fancy, should be equal to base model weights
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_0)
+    peft_model.save_pretrained(tmp_path / "adapter_0")
+
+    # Save adapter 1, modules to save weights are modified randomly, should be unique to adapter 1
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_1)
+    peft_model.model.model.layers[0].post_attention_layernorm.weight.data = torch.rand_like(
+        peft_model.model.model.layers[0].post_attention_layernorm.weight.data
+    )
+    adapter_1_saved = peft_model.model.model.layers[0].post_attention_layernorm.weight.data.clone()
+    peft_model.save_pretrained(tmp_path / "adapter_1")
+
+    # Save adapter 2, modules to save weights are modified randomly, should be unique to adapter 2
+    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_2)
+    peft_model.model.model.layers[1].post_attention_layernorm.weight.data = torch.rand_like(
+        peft_model.model.model.layers[1].post_attention_layernorm.weight.data
+    )
+    adapter_2_saved = peft_model.model.model.layers[1].post_attention_layernorm.weight.data.clone()
+    peft_model.save_pretrained(tmp_path / "adapter_2")
+
+    del peft_model
+
+    combined_model = PeftModel.from_pretrained(original_model, tmp_path / "adapter_0", adapter_name="adapter_0")
+    combined_model.load_adapter(tmp_path / "adapter_1", adapter_name="adapter_1")
+    combined_model.load_adapter(tmp_path / "adapter_2", adapter_name="adapter_2")
+
+    # For adapter 0 we expect every modules to save layer mentioned in this test to be equal to the original model
+    # since we didn't modify it for adapter 0 and only adapter 0 is active.
+    combined_model.set_adapter("adapter_0")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        original_model.model.layers[0].post_attention_layernorm.weight,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        original_model.model.layers[1].post_attention_layernorm.weight,
+    )
+
+    # For adapter 1 we expect that the modified module to save 0.post_attention_layernorm is modified, the other
+    # module to save layers mentioned above should be untouched.
+    combined_model.set_adapter("adapter_1")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        adapter_1_saved,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        original_model.model.layers[1].post_attention_layernorm.weight,
+    )
+
+    # For adapter 2 we expect its module to save layer (1.post_attention_layernorm) to be modified but the other
+    # module to save weights should be kept original.
+    combined_model.set_adapter("adapter_2")
+    assert torch.allclose(
+        combined_model.model.model.layers[0].post_attention_layernorm.weight,
+        original_model.model.layers[0].post_attention_layernorm.weight,
+    )
+    assert torch.allclose(
+        combined_model.model.model.layers[1].post_attention_layernorm.weight,
+        adapter_2_saved,
+    )
+
+
 class TestModulesToSaveAttributeAccess:
     """Test attribute access on the ModulesToSaveWrapper class.