
Commit 7320bb9

FIX AutoPeftModels never reduce embedding size (huggingface#2427)
Resolves huggingface#2415 There was a bug in AutoPeftModels where the embedding was always resized to the vocab size of the tokenizer when the tokenizer was found. This makes sense if the vocabulary was extended, but some models like Qwen already start out with "spare" embeddings, i.e. the embedding size is larger than the vocab size. This could result in the embedding being shrunk, which in turn resulted in an error when loading the weights.
1 parent 2f063e6 commit 7320bb9
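
As background (not part of the commit), a minimal sketch of how the "spare" embeddings described above can be observed, assuming the transformers library and the same Qwen/Qwen2-0.5B checkpoint used in the new test; the row counts in the comment are taken from the error message quoted in that test:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Number of rows in the input embedding matrix vs. number of tokens known to
# the tokenizer. For this checkpoint the embedding is larger (151936 vs. 151646
# per the error quoted in the test), so unconditionally calling
# model.resize_token_embeddings(len(tokenizer)) would shrink the embedding
# instead of extending it.
embedding_size = model.get_input_embeddings().weight.shape[0]
print(embedding_size, len(tokenizer))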

2 files changed (+29 / -4 lines changed)


src/peft/auto.py

Lines changed: 5 additions & 2 deletions
@@ -130,11 +130,14 @@ def from_pretrained(
                 token=token,
             )
 
-        if tokenizer_exists:
+        if tokenizer_exists and hasattr(base_model, "get_input_embeddings"):
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
             )
-            base_model.resize_token_embeddings(len(tokenizer))
+            embedding_size = base_model.get_input_embeddings().weight.shape[0]
+            if len(tokenizer) > embedding_size:
+                # only resize if the tokenizer has a larger vocab size than there are embeddings
+                base_model.resize_token_embeddings(len(tokenizer))
 
         return cls._target_peft_class.from_pretrained(
             base_model,
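
To see why the old unconditional resize broke loading, here is a small self-contained sketch (not part of the commit) of the shrink-then-load failure, using a plain nn.Embedding as a stand-in for the model's lm_head / input embedding; the row counts come from the error message quoted in the new test, and the embedding dimension is reduced to 8 to keep the example cheap:

from torch import nn

# Stand-in for the checkpoint on disk: saved with the full 151936 embedding rows.
saved = nn.Embedding(151936, 8)
state_dict = saved.state_dict()

# Stand-in for what the old code built: the embedding already resized down to
# the tokenizer's 151646 tokens before the saved weights are loaded.
shrunk = nn.Embedding(151646, 8)

try:
    shrunk.load_state_dict(state_dict)
except RuntimeError as exc:
    # Raises a size-mismatch error analogous to the one in the test below:
    # copying a param with shape torch.Size([151936, 8]) from checkpoint,
    # the shape in current model is torch.Size([151646, 8]).
    print(exc)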

tests/test_auto.py

Lines changed: 24 additions & 2 deletions
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tempfile
-import unittest
 
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from peft import (
     AutoPeftModel,
@@ -24,18 +24,20 @@
     AutoPeftModelForSeq2SeqLM,
     AutoPeftModelForSequenceClassification,
     AutoPeftModelForTokenClassification,
+    LoraConfig,
     PeftModel,
     PeftModelForCausalLM,
     PeftModelForFeatureExtraction,
     PeftModelForQuestionAnswering,
     PeftModelForSeq2SeqLM,
     PeftModelForSequenceClassification,
     PeftModelForTokenClassification,
+    get_peft_model,
 )
 from peft.utils import infer_device
 
 
-class PeftAutoModelTester(unittest.TestCase):
+class TestPeftAutoModel:
     dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16
 
     def test_peft_causal_lm(self):
@@ -207,3 +209,23 @@ def test_peft_whisper(self):
         is_trainable = False
         # This should work
         _ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)
+
+    def test_embedding_size_not_reduced_if_greater_vocab_size(self, tmp_path):
+        # See 2415
+        # There was a bug in AutoPeftModels where the embedding was always resized to the vocab size of the tokenizer
+        # when the tokenizer was found. This makes sense if the vocabulary was extended, but some models like Qwen
+        # already start out with "spare" embeddings, i.e. the embedding size is larger than the vocab size. This could
+        # result in the embedding being shrunk, which in turn resulted in an error when loading the weights.
+
+        # first create a checkpoint; it is important that the tokenizer is also saved in the same location
+        model_id = "Qwen/Qwen2-0.5B"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head", "embed_token"]))
+        model.save_pretrained(tmp_path)
+        tokenizer.save_pretrained(tmp_path)
+
+        # does not raise; without the fix, it raises:
+        # > size mismatch for base_model.model.lm_head.modules_to_save.default.weight: copying a param with shape
+        # torch.Size([151936, 896]) from checkpoint, the shape in current model is torch.Size([151646, 896]).
+        AutoPeftModelForCausalLM.from_pretrained(tmp_path)
