Commit a91ec33

Fix not detecting regex-targeted embedding layer (#2649)
This issue was found in PR #2638 and is described there as follows:

> When calling `get_peft_model_state_dict(..., save_embedding_layers="auto")` we check if the embedding layer is targeted to determine if the embedding layers need saving. This is not done when `PeftConfig.target_modules` is a regex string, potentially failing to save embeddings.

The fix adds a check analogous to the existing one that tests whether `EMBEDDING_LAYER_NAMES` is a subset of the defined target modules, except that it uses the same regex matching as `BaseTuner.inject_adapter`. To avoid code duplication, the matching was moved into its own utility function, `match_target_against_key`.

The main complication was defining the test cases, since it was non-trivial to pin down exactly what `save_embedding_layers="auto"` entails. I've assembled a list of cases that I believe are correct in the corresponding unit tests.
1 parent 25e5c6b commit a91ec33
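
To make the failure mode concrete, here is a minimal, self-contained sketch of the "auto" detection before and after the fix. The function names and the embedding-name list below are simplified stand-ins for illustration, not PEFT's actual implementation.

```python
import re

# Simplified stand-in for PEFT's list of common embedding module names (illustrative only).
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]


def embedding_is_targeted_before(target_modules):
    # Old behavior: only list-style target_modules were inspected, so a regex
    # such as r".*\.embed_tokens" never marked the embedding as targeted.
    if isinstance(target_modules, str):
        return False
    return any(name in target_modules for name in EMBEDDING_LAYER_NAMES)


def embedding_is_targeted_after(target_modules, module_names):
    # New behavior: a regex pattern is full-matched against the model's module
    # names, mirroring the matching used when injecting the adapter.
    if isinstance(target_modules, str):
        return any(re.fullmatch(target_modules, name) for name in module_names)
    return any(name in target_modules for name in EMBEDDING_LAYER_NAMES)


module_names = ["model.embed_tokens", "model.layers.0.self_attn.q_proj"]
print(embedding_is_targeted_before(r".*\.embed_tokens"))               # False: embeddings silently skipped
print(embedding_is_targeted_after(r".*\.embed_tokens", module_names))  # True: embeddings get saved
```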

6 files changed (+156 −58 lines)


src/peft/tuners/tuners_utils.py

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@
     SEQ_CLS_HEAD_NAMES,
 )
 from peft.utils.integrations import init_empty_weights
-from peft.utils.other import AuxiliaryTrainingWrapper, set_additional_trainable_modules
+from peft.utils.other import AuxiliaryTrainingWrapper, match_target_against_key, set_additional_trainable_modules
 from peft.utils.peft_types import PeftType, TaskType

 from ..config import PeftConfig

@@ -1133,7 +1133,7 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
             return False

     if isinstance(config.target_modules, str):
-        target_module_found = re.fullmatch(config.target_modules, key)
+        target_module_found = match_target_against_key(config.target_modules, key)
     elif key in config.target_modules:
         # this module is specified directly in target_modules
         target_module_found = True

src/peft/utils/other.py

Lines changed: 8 additions & 0 deletions
@@ -1229,6 +1229,14 @@ def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Option
     return exists


+def match_target_against_key(target_pattern: str, key: str):
+    """Backing function for `target_modules` config parameter.
+
+    Having this as its own function ensures that target key matching can be implemented in the same way everywhere.
+    """
+    return re.fullmatch(target_pattern, key)
+
+
 def get_pattern_key(pattern_keys: Sequence[str], key_to_match: str) -> str:
     """Match a substring of key_to_match in pattern keys"""
     for key in pattern_keys:
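
As a quick illustration of the full-match semantics this helper relies on (a usage sketch, not part of the commit):

```python
import re


def match_target_against_key(target_pattern: str, key: str):
    # Same one-liner as the helper added above: full match, not prefix match.
    return re.fullmatch(target_pattern, key)


print(bool(match_target_against_key(r".*\.embed_tokens", "model.embed_tokens")))  # True
print(bool(match_target_against_key(r"embed", "model.embed_tokens")))             # False: pattern must cover the whole key
print(bool(re.match(r"model", "model.embed_tokens")))                             # True: re.match would be laxer, hence fullmatch
```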

src/peft/utils/save_and_load.py

Lines changed: 30 additions & 15 deletions
@@ -28,13 +28,15 @@

 from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING

+from .constants import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from .other import (
     EMBEDDING_LAYER_NAMES,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     AuxiliaryTrainingWrapper,
     check_file_exists_on_hf_hub,
     infer_device,
+    match_target_against_key,
 )
 from .peft_types import PeftType

@@ -235,23 +237,35 @@ def renamed_dora_weights(k):
         )

     # DEAL WITH EMBEDDINGS
-    # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
-    is_embedding_in_target_modules = False
+    #
+    # save_embedding_layers="auto" needs to check the following logic:
+    #
+    # - when vocab size was NOT changed, embeddings should be saved only when targeted
+    #   but not when
+    #   - using PeftType.TRAINABLE_TOKENS
+    #   - LoRA using trainable_token_indices (since their goal is to be space-efficient)
+    #   but
+    # - when vocab size was changed, embeddings should be saved automatically regardless to cover this
+    #   scenario: 1) fine-tune embedding, 2) resize embedding, 3) train with trainable tokens
+    #
     embedding_is_targeted = False
     if hasattr(config, "target_modules"):
-        if isinstance(config.target_modules, str):
-            # TODO: implement this; note: this change is not directly related to the PR, the bug already existed b4
-            pass
+        if isinstance(config.target_modules, str) and config.target_modules != INCLUDE_LINEAR_LAYERS_SHORTHAND:
+            embedding_is_targeted = any(
+                match_target_against_key(config.target_modules, k)
+                for k, _ in model.get_base_model().named_modules()
+                if any(re.match(rf"(.*\.)?{e}$", k) for e in EMBEDDING_LAYER_NAMES)
+            )
         elif config.target_modules:
             embedding_is_targeted = any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
-    if (
-        save_embedding_layers == "auto"
-        and hasattr(config, "target_modules")
-        and embedding_is_targeted
-        and config.peft_type != PeftType.TRAINABLE_TOKENS
-    ):
+
+    using_trainable_tokens = (
+        config.peft_type == PeftType.TRAINABLE_TOKENS or getattr(config, "trainable_token_indices", None) is not None
+    )
+
+    if save_embedding_layers == "auto" and embedding_is_targeted and not using_trainable_tokens:
         warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
-        save_embedding_layers = is_embedding_in_target_modules = True
+        save_embedding_layers = True
     elif save_embedding_layers == "auto":
         vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
         model_id = getattr(config, "base_model_name_or_path", None)

@@ -289,9 +303,10 @@ def renamed_dora_weights(k):

     if save_embedding_layers and hasattr(model, "get_input_embeddings"):
         for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
-            if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
-                # support from version >= 0.6.2
-                embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
+            # Either the layer is not targeted, then it must have been resized and needs saving. Or it is targeted and
+            # therefore has a valid base layer, then we'll save it as well.
+            if not embedding_is_targeted or has_valid_embedding_base_layer(layer):
+                embedding_module_name = get_embedding_layer_name(model, layer, embedding_is_targeted)
                 if embedding_module_name:
                     to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
     elif save_embedding_layers:
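
The comment block above describes a small decision table for `save_embedding_layers="auto"`. Here is a hedged sketch of that logic with illustrative boolean inputs; the function name and signature are made up for clarity and are not PEFT API.

```python
def should_save_embeddings_auto(embedding_is_targeted: bool,
                                using_trainable_tokens: bool,
                                vocab_size_changed: bool) -> bool:
    # Targeted embeddings (and no trainable-tokens shortcut) -> save them.
    if embedding_is_targeted and not using_trainable_tokens:
        return True
    # A resized vocabulary would be lost on reload unless the embedding is saved.
    if vocab_size_changed:
        return True
    return False


assert should_save_embeddings_auto(True, False, False)       # embed_tokens hit via list or regex
assert not should_save_embeddings_auto(True, True, False)    # trainable tokens aim to stay small
assert should_save_embeddings_auto(False, True, True)        # resize + trainable tokens -> save anyway
```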

tests/test_custom_models.py

Lines changed: 50 additions & 40 deletions
@@ -2392,52 +2392,62 @@ def test_non_existing_model_card(self):
         assert len(model_card) > 1000

     @pytest.mark.parametrize("save_embedding_layers", ["auto", True, False])
-    def test_targeting_lora_to_embedding_layer(self, save_embedding_layers):
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            (LoraConfig(target_modules=["lin0", "embed_tokens"], init_lora_weights=False)),
+            (LoraConfig(target_modules=r"^embed_tokens", init_lora_weights=False)),
+        ],
+    )
+    def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_layers, tmp_path, peft_config):
         model = ModelEmbWithEmbeddingUtils()
-        config = LoraConfig(target_modules=["embed_tokens", "lin0"], init_lora_weights=False)
-        model = get_peft_model(model, config)
+        model = get_peft_model(model, peft_config)

-        with tempfile.TemporaryDirectory() as tmp_dirname:
-            if save_embedding_layers == "auto":
-                # assert warning
-                msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
-                with pytest.warns(UserWarning, match=msg_start):
-                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            else:
-                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            from safetensors.torch import load_file as safe_load_file
-
-            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
-            if save_embedding_layers in ["auto", True]:
-                assert "base_model.model.embed_tokens.base_layer.weight" in state_dict
-                assert torch.allclose(
-                    model.base_model.model.embed_tokens.base_layer.weight,
-                    state_dict["base_model.model.embed_tokens.base_layer.weight"],
-                )
-            else:
-                assert "base_model.model.embed_tokens.base_layer.weight" not in state_dict
-            del state_dict
+        if save_embedding_layers == "auto":
+            # assert warning
+            msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
+            with pytest.warns(UserWarning, match=msg_start):
+                model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+        else:
+            model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+        contains_embedding = "base_model.model.embed_tokens.base_layer.weight" in state_dict
+
+        if save_embedding_layers in ["auto", True]:
+            assert contains_embedding
+            assert torch.allclose(
+                model.base_model.model.embed_tokens.base_layer.weight,
+                state_dict["base_model.model.embed_tokens.base_layer.weight"],
+            )
+        else:
+            assert not contains_embedding

     @pytest.mark.parametrize("save_embedding_layers", ["auto", True, False])
-    def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding_layers):
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            (LoraConfig(target_modules=["lin0", "emb"], init_lora_weights=False)),
+            (LoraConfig(target_modules=r"^emb", init_lora_weights=False)),
+        ],
+    )
+    def test_save_pretrained_targeting_lora_to_embedding_layer_non_transformers(
+        self, save_embedding_layers, tmp_path, peft_config
+    ):
         model = ModelEmbConv1D()
-        config = LoraConfig(target_modules=["emb", "lin0"], init_lora_weights=False)
-        model = get_peft_model(model, config)
-
-        with tempfile.TemporaryDirectory() as tmp_dirname:
-            if save_embedding_layers is True:
-                with pytest.warns(
-                    UserWarning,
-                    match=r"Could not identify embedding layer\(s\) because the model is not a 🤗 transformers model\.",
-                ):
-                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            else:
-                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            from safetensors.torch import load_file as safe_load_file
+        model = get_peft_model(model, peft_config)
+
+        if save_embedding_layers is True:
+            with pytest.warns(
+                UserWarning,
+                match=r"Could not identify embedding layer\(s\) because the model is not a 🤗 transformers model\.",
+            ):
+                model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+        else:
+            model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)

-            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
-            assert "base_model.model.emb.base_layer.weight" not in state_dict
-            del state_dict
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+        assert "base_model.model.emb.base_layer.weight" not in state_dict

     def test_load_resized_embedding_ignore_mismatched_sizes(self):
         # issue #1605

tests/test_decoder_models.py

Lines changed: 37 additions & 1 deletion
@@ -17,6 +17,7 @@

 import pytest
 import torch
+from safetensors.torch import load_file as safe_load_file
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,

@@ -46,7 +47,7 @@
     get_peft_model,
 )

-from .testing_common import PeftCommonTester
+from .testing_common import PeftCommonTester, hub_online_once
 from .testing_utils import device_count, load_dataset_english_quotes, set_init_weights_false


@@ -680,3 +681,38 @@ def process(samples):
             data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
         )
         trainer.train()
+
+    @pytest.mark.parametrize("save_embedding_layers", ["auto", True, False])
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            (LoraConfig(target_modules=["lin0", "embed_tokens"], init_lora_weights=False)),
+            (LoraConfig(target_modules=r".*\.embed_tokens", init_lora_weights=False)),
+        ],
+    )
+    def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_layers, tmp_path, peft_config):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+            model = get_peft_model(model, peft_config)
+
+        if save_embedding_layers == "auto":
+            # assert warning
+            msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
+            with pytest.warns(UserWarning, match=msg_start):
+                model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+        else:
+            model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+        contains_embedding = "base_model.model.model.embed_tokens.base_layer.weight" in state_dict
+
+        if save_embedding_layers in ["auto", True]:
+            assert contains_embedding
+            assert torch.allclose(
+                model.base_model.model.model.embed_tokens.base_layer.weight,
+                state_dict["base_model.model.model.embed_tokens.base_layer.weight"],
+            )
+        else:
+            assert not contains_embedding

tests/test_trainable_tokens.py

Lines changed: 29 additions & 0 deletions
@@ -18,6 +18,7 @@

 import pytest
 import torch
+from safetensors.torch import load_file as safe_load_file
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

 from peft import AutoPeftModel, LoraConfig, PeftModel, TrainableTokensConfig, get_peft_model

@@ -885,3 +886,31 @@ def test_embedding_name_is_used_when_given_combined(self, model_embed_multiple):

         assert isinstance(peft_model.model.embed_in_2, TrainableTokensWrapper)
         assert not isinstance(peft_model.model.embed_in, TrainableTokensWrapper)
+
+    @pytest.mark.parametrize("resize_embedding", [True, False])
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            LoraConfig(target_modules="all-linear", trainable_token_indices=[1, 2, 3]),
+            TrainableTokensConfig(target_modules=None, token_indices=[1, 2, 3]),
+        ],
+    )
+    def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_path):
+        # make sure that embeddings are saved alongside trainable token weights but only when
+        # we detect the embedding to be resized (as detected by save_embedding_layers="auto")
+        if resize_embedding:
+            model.resize_token_embeddings(model.config.vocab_size + 2)
+        peft_model = get_peft_model(model, peft_config)
+
+        peft_model.save_pretrained(tmp_path, save_embedding_layers="auto")
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+
+        if isinstance(peft_config, TrainableTokensConfig):
+            contains_embedding = "base_model.model.model.embed_tokens.base_layer.weight" in state_dict
+        else:
+            contains_embedding = "base_model.model.model.embed_tokens.token_adapter.base_layer.weight" in state_dict
+
+        if resize_embedding:
+            assert contains_embedding
+        else:
+            assert not contains_embedding
