diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index bd0ca16f865f..8afff36eb086 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -235,13 +235,29 @@ def load_adapter(
         )
 
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            err_msg = ""
+            origin_name = peft_model_id if peft_model_id is not None else "state_dict"
+            # Check for unexpected keys.
             if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
-                logger.warning(
-                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
-                    f" {incompatible_keys.unexpected_keys}. "
+                err_msg = (
+                    f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
+                    f"{', '.join(incompatible_keys.unexpected_keys)}. "
                 )
 
+            # Check for missing keys.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                # Filter missing keys specific to the current adapter, as missing base model keys are expected.
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    err_msg += (
+                        f"Loading adapter weights from {origin_name} led to missing keys in the model: "
+                        f"{', '.join(lora_missing_keys)}"
+                    )
+
+            if err_msg:
+                logger.warning(err_msg)
+
         # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
         if (
             (getattr(self, "hf_device_map", None) is not None)
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index a80919dc61cf..aebf2b295267 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -20,8 +20,9 @@
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
 from transformers.testing_utils import (
+    CaptureLogger,
     require_bitsandbytes,
     require_peft,
     require_torch,
@@ -72,9 +73,15 @@ def test_peft_from_pretrained(self):
         This checks if we pass a remote folder that contains an adapter config and adapter weights, it should
         correctly load a model that has adapters injected on it.
         """
+        logger = logging.get_logger("transformers.integrations.peft")
+
         for model_id in self.peft_test_model_ids:
             for transformers_class in self.transformers_test_model_classes:
-                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with CaptureLogger(logger) as cl:
+                    peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                # ensure that under normal circumstances, there are no warnings about keys
+                self.assertNotIn("unexpected keys", cl.out)
+                self.assertNotIn("missing keys", cl.out)
 
                 self.assertTrue(self._check_lora_correctly_converted(peft_model))
                 self.assertTrue(peft_model._hf_peft_config_loaded)
@@ -548,3 +555,70 @@ def test_peft_from_pretrained_hub_kwargs(self):
 
         model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
         self.assertTrue(self._check_lora_correctly_converted(model))
+
+    def test_peft_from_pretrained_unexpected_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with unexpected keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # add unexpected key
+                dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                    )
+
+                msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
+                self.assertIn(msg, cl.out)
+
+    def test_peft_from_pretrained_missing_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with missing keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # remove a key so that we have missing keys
+                key = next(iter(dummy_state_dict.keys()))
+                del dummy_state_dict[key]
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=False,
+                        adapter_name="other",
+                    )
+
+                # Here we need to adjust the key name a bit to account for PEFT-specific naming.
+                # 1. Remove PEFT-specific prefix
+                # If merged after dropping Python 3.8, we can use: key = key.removeprefix(peft_prefix)
+                peft_prefix = "base_model.model."
+                key = key[len(peft_prefix) :]
+                # 2. Insert adapter name
+                prefix, _, suffix = key.rpartition(".")
+                key = f"{prefix}.other.{suffix}"
+
+                msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
+                self.assertIn(msg, cl.out)
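
As a quick way to see the aggregated warning this patch introduces, here is a minimal sketch outside the test suite. It assumes `peft` and `torch` are installed and uses "facebook/opt-350m" purely as a placeholder base model (any causal LM for which PEFT knows default LoRA target modules should behave the same). The state dict handed to `load_adapter` contains one bogus key and none of the real LoRA weights, so the single warning emitted by the patched code should list both the unexpected key and the missing adapter keys.

import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

# Placeholder base model; substitute any causal LM with a default LoRA target-module mapping in PEFT.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# A state dict with one key the model does not expect ("foobar") and none of the
# LoRA weights the injected adapter does expect.
bad_state_dict = {"foobar": torch.zeros(1)}

# With this patch, load_adapter emits a single logger.warning mentioning the unexpected
# key ("foobar") and the missing "lora_" keys of the freshly injected "default" adapter.
model.load_adapter(adapter_state_dict=bad_state_dict, peft_config=LoraConfig())

Running the same call with a well-formed adapter state dict should stay silent, which is what the updated test_peft_from_pretrained now asserts via CaptureLogger.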