diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 3b7e636416..47d6f72c8d 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -142,6 +142,8 @@ def __init__(
         if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
             self.base_model.config.pretraining_tp = 1
 
+        self._adapters_disabled = False
+
     @property
     def peft_config(self) -> dict[str, PeftConfig]:
         if self._is_prompt_learning:
@@ -167,6 +169,17 @@ def active_adapters(self) -> list[str]:
             adapters = [adapters]
         return adapters
 
+    @property
+    def has_active_enabled_adapter(self) -> bool:
+        """Indicates whether the model has any adapter that is both enabled and active. Adapters may be purposefully
+        disabled (via `disable_adapter`) or there may be no active adapter (enabled but inactive); these are two
+        separate mechanisms, but it is sometimes helpful to know whether any active/enabled adapter exists at all.
+        """
+        if self.peft_config[self.active_adapter].is_prompt_learning:
+            return not self._adapters_disabled
+
+        return not self._adapters_disabled or not self.active_adapters
+
     @peft_config.setter
     def peft_config(self, value: dict[str, PeftConfig]):
         if self._is_prompt_learning:
@@ -890,7 +903,7 @@ def __getattr__(self, name: str):
     def _enable_peft_forward_hooks(self, *args, **kwargs):
         # If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
         # runs without any changes
-        if hasattr(self.base_model, "_enable_peft_forward_hooks"):
+        if hasattr(self.base_model, "_enable_peft_forward_hooks") and self.has_active_enabled_adapter:
             with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
                 yield
             return
@@ -940,17 +953,21 @@ def disable_adapter(self):
                 self.forward = self.base_model.forward
                 old_prepare_inputs_for_generation = self.prepare_inputs_for_generation
                 self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
+                self._adapters_disabled = True
                 yield
             finally:
                 self.forward = old_forward
                 self.prepare_inputs_for_generation = old_prepare_inputs_for_generation
+                self._adapters_disabled = False
         elif self.peft_config[self.active_adapter].is_adaption_prompt:
             try:
                 self.base_model.disable_adapter_layers()
+                self._adapters_disabled = True
                 yield
             finally:
                 self.base_model.enable_adapter_layers()
+                self._adapters_disabled = False
         else:  # LoRA, LoHa, etc.
             model_status = self.get_model_status()
@@ -962,11 +979,13 @@ def disable_adapter(self):
                 )
             try:
                 self.base_model.disable_adapter_layers()
+                self._adapters_disabled = True
                 yield
             finally:
                 if model_status.enabled is not False:
                     # model_status.enabled is `True` or `"irregular"`
                     self.base_model.enable_adapter_layers()
+                    self._adapters_disabled = False
 
     def get_base_model(self) -> torch.nn.Module:
         """
diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py
index 934a3b7928..6c4dc08e51 100644
--- a/src/peft/tuners/cpt/model.py
+++ b/src/peft/tuners/cpt/model.py
@@ -53,6 +53,8 @@ def __init__(self, config, word_embeddings):
             word_embedding_weights = word_embedding_weights.to(torch.float32)
             self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
 
+        self.embedding.requires_grad_(False)
+
         # Initialize delta embedding with zero weights
         self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
         self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)
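
The `peft_model.py` and `cpt/model.py` changes above boil down to a small behavioural contract, sketched here (illustration only, not part of the patch; the toy `MLP` module is made up): inside the `disable_adapter()` context the new `_adapters_disabled` flag is set, `has_active_enabled_adapter` becomes `False`, and `_enable_peft_forward_hooks` therefore skips installing its hooks.

```python
# Sketch only -- assumes the new `has_active_enabled_adapter` property and
# `_adapters_disabled` flag from the peft_model.py hunks above.
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):  # hypothetical toy module, not part of the PR
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = get_peft_model(MLP(), LoraConfig(target_modules=["linear"]))

# An adapter is attached and active, so the PEFT forward hooks may be installed.
assert model.has_active_enabled_adapter

with model.disable_adapter():
    # _adapters_disabled is now True, so _enable_peft_forward_hooks is a no-op.
    assert not model.has_active_enabled_adapter
    _ = model(torch.randn(2, 8))

# Leaving the context restores the flag.
assert model.has_active_enabled_adapter
```
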
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 2e76e13ee1..ef94766c52 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -23,6 +23,7 @@
 
 import torch
 from torch import nn
+from transformers.modeling_layers import GradientCheckpointingLayer
 
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import (
@@ -351,13 +352,48 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
         # If adapter_names is passed as an argument, we inject it into the forward arguments.
         adapter_names = kwargs.pop("adapter_names", None)
         alora_offsets = kwargs.pop("alora_offsets", None)
+
         if adapter_names is None and alora_offsets is None:
             # nothing to do
             yield
             return
         hook_handles = []
+
         if alora_offsets is not None:
-            for layer in self.modules():
+            for n, layer in self.named_modules():
+                # Gradient checkpointing layers are executed concurrently with the 'normal' forward call (in the
+                # backward step, the gradient checkpointing layer's forward is executed again). This means that when
+                # the gradient checkpointing layer is called, the _enable_peft_forward_hooks context manager is long
+                # gone. To be consistent with the normal forward, we need to register the pre-hooks for this
+                # concurrent forward call as well.
+                #
+                # Note that this leads to double application of whatever the callbacks do in the normal forward.
+                # Make sure that whatever change is done can be applied more than once without harm (idempotency).
+                if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+
+                    def forward_pre_hook(name, module, inputs, **kwargs):
+                        for submodule in module.modules():
+                            if isinstance(submodule, LoraLayer):
+                                handle = submodule.register_forward_pre_hook(
+                                    partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
+                                    with_kwargs=True,
+                                )
+                                module._peft_gradient_checkpointing_forward_hooks.append(handle)
+
+                    def backward_hook(name, module, *grad_output, **kwargs):
+                        while module._peft_gradient_checkpointing_forward_hooks:
+                            module._peft_gradient_checkpointing_forward_hooks.pop().remove()
+
+                    if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
+                        raise ValueError(
+                            "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
+                            "checkpointing. Disable gradient checkpointing or only call forward once per backward."
+                        )
+                    layer._peft_gradient_checkpointing_forward_hooks = []
+                    handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
+                    handle = layer.register_full_backward_hook(partial(backward_hook, n))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
                 if isinstance(layer, LoraLayer):
                     pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
                     handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
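
For context, here is a sketch of the forward/backward pattern the new `GradientCheckpointingLayer` hooks are written for (it mirrors the tests added further below rather than prescribing an API; the tiny hub model id is the one those tests use, and the exact kwargs are illustrative): the pre-hooks that re-inject `alora_offsets` are registered on each checkpointing layer during the 'normal' forward and removed again by the full-backward hook, so exactly one forward per backward is expected.

```python
# Sketch only; mirrors the new aLoRA + gradient checkpointing tests below.
import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, TaskType, get_peft_model

model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"  # tiny model used by the tests
base = AutoModelForCausalLM.from_pretrained(model_id)
cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
model = get_peft_model(base, cfg)
model.train()

model.prepare_model_for_gradient_checkpointing(model)
model.gradient_checkpointing_enable({"use_reentrant": False})

input_ids = torch.tensor([[0, 1, 2, 3]])
# One forward per backward: the checkpointed re-execution during .backward() uses the
# pre-hooks registered in the forward above, and the full-backward hook removes them.
# Calling forward a second time before .backward() raises the ValueError added above.
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
```
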
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index e30d7fe108..4d59f65553 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -1802,8 +1802,11 @@ def test_training_custom_models_layer_indexing(self, test_name, model_id, config
         pass
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
-    def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_custom_models_gradient_checkpointing(
+        self, test_name, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 06402d637b..07611dd9cb 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -526,9 +526,12 @@ def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwa
 
     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         _skip_if_not_conv1d_supported(model_id, config_cls)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
+        self._test_training_gradient_checkpointing(
+            model_id, config_cls, config_kwargs.copy(), use_reentrant=use_reentrant
+        )
 
     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 1ec0aa0668..4f36ceb211 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -339,8 +339,11 @@ def test_training_encoder_decoders_layer_indexing(self, model_id, config_cls, co
 
     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_encoder_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_encoder_decoders_gradient_checkpointing(
+        self, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index a5377827f4..f11876df6b 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -330,9 +330,10 @@ def test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
 
     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         skip_deberta_lora_tests(config_cls, model_id)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
diff --git a/tests/test_lora_variants.py b/tests/test_lora_variants.py
index 41d6a89956..bb93b2988e 100644
--- a/tests/test_lora_variants.py
+++ b/tests/test_lora_variants.py
@@ -15,8 +15,9 @@
 import pytest
 import torch
 from torch import nn
+from transformers import AutoModelForCausalLM
 
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, TaskType, get_peft_model
 from peft.tuners.lora.layer import Conv1d as LoraConv1d
 from peft.tuners.lora.layer import Conv2d as LoraConv2d
 from peft.tuners.lora.layer import Embedding as LoraEmbedding
@@ -32,6 +33,8 @@
     get_alora_offsets_for_generate,
 )
 
+from .testing_common import hub_online_once
+
 
 # Custom model featuring embeddings and a 'visual stack'
 class CustomModel(nn.Module):
@@ -73,6 +76,9 @@ def __init__(self, vocab_size: int = 10, hidden_dim: int = 8):
         self.embed = nn.Embedding(vocab_size, hidden_dim)
         self.linear = nn.Linear(hidden_dim, vocab_size)
 
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return kwargs
+
     def forward(self, X=None, embeds=None, num_beams=None, alora_offsets=None):
         if X is not None:
             embeds = self.embed(X)
@@ -181,7 +187,7 @@ class TestActivatedLora:
     )
     # Verify alora_offsets are calculated correctly
    def test_calculate_alora_offsets(self, input_ids, alora_invocation_tokens, expected_offsets):
-        config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens)
+        config = LoraConfig(task_type=TaskType.CAUSAL_LM, alora_invocation_tokens=alora_invocation_tokens)
         peft_config = {"default": config}
 
         # compute offsets
@@ -233,7 +239,12 @@ def test_alora_activation_matches_base_until_invocation(self):
     def test_input_embeds_warning(self):
         transformers_class = MockTransformerWrapper
         base_model = transformers_class.from_pretrained()
-        cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
+        cfg = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["linear"],
+            alora_invocation_tokens=[2],
+            init_lora_weights=False,
+        )
         lora_model = get_peft_model(base_model, cfg)
         lora_model.eval()
 
@@ -265,3 +276,41 @@ def test_num_beams_error(self):
         with torch.no_grad():
             lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
         assert "Beam search not yet supported for aLoRA." in str(e.value)
+
+    def test_gradient_checkpointing_double_forward_raises(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
+            lora_model = get_peft_model(base_model, cfg)
+            lora_model.train()
+
+            lora_model.prepare_model_for_gradient_checkpointing(lora_model)
+            lora_model.gradient_checkpointing_enable()
+
+            inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}
+
+            lora_model.forward(**inputs)
+
+            with pytest.raises(ValueError, match="Multiple invocations of PEFT forward hooks.*"):
+                lora_model.forward(**inputs)
+
+    def test_gradient_checkpointing_dpo_doesnt_raise(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
+            lora_model = get_peft_model(base_model, cfg)
+            lora_model.train()
+
+            lora_model.prepare_model_for_gradient_checkpointing(lora_model)
+            lora_model.gradient_checkpointing_enable()
+
+            inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}
+
+            with lora_model.disable_adapter():
+                lora_model.forward(**inputs)
+
+            lora_model.forward(**inputs)
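
The updated `_test_training_gradient_checkpointing` helper in the following `testing_common.py` hunks verifies that the same set of parameters receives non-zero gradients with and without checkpointing. A standalone sketch of that check in plain PyTorch (toy model, not the helper itself):

```python
# Sketch only -- plain torch.utils.checkpoint on a toy model, not the PEFT test helper.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(4, 8)


def non_zero_grads(use_checkpointing: bool) -> set[str]:
    model.zero_grad()
    if use_checkpointing:
        # Recompute the first linear layer during backward instead of storing its activations.
        hidden = checkpoint(model[0], x, use_reentrant=False)
        out = model[2](model[1](hidden))
    else:
        out = model(x)
    # Square the output for bigger gradients, as the helper does.
    (out**2).sum().backward()
    return {n for n, p in model.named_parameters() if p.grad is not None and p.grad.abs().sum() > 0}


# The same parameters must receive non-zero gradients either way.
assert non_zero_grads(use_checkpointing=False) == non_zero_grads(use_checkpointing=True)
```
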
diff --git a/tests/testing_common.py b/tests/testing_common.py
index dab9ee6e45..f93c00569a 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -25,6 +25,7 @@
 
 import pytest
 import torch
+import transformers
 import yaml
 from diffusers import StableDiffusionPipeline
 from packaging import version
@@ -1315,7 +1316,11 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         # more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
         assert nb_trainable < nb_trainable_all
 
-    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
+        # Note that certain configurations, such as activated LoRA with 'alora_invocation_tokens': [1000], do not
+        # generate gradients because the adapter is never activated, so this test is effectively a no-op for them.
+        # It is still a valid test, but it may be confusing to see it pass even though nothing is really checked.
+
         if config_cls == PrefixTuningConfig:
             return pytest.skip(f"Test not applicable for {config_cls}")
 
@@ -1323,17 +1328,28 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("AdaLora with RoBERTa does not work correctly")
 
+        if "bart" in model_id.lower() and version.parse(transformers.__version__) <= version.parse("5.0"):
+            self.skipTest(
+                "Bart in transformers < 5.0 doesn't handle input sharing well enough. See transformers#41821"
+            )
+
         if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("OFT with Deberta does not work correctly")
 
+        if "gptbigcode" in model_id.lower():
+            self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")
+
         with hub_online_once(model_id):
             model = self.transformers_class.from_pretrained(model_id)
 
         if not getattr(model, "supports_gradient_checkpointing", False):
             return pytest.skip(f"Model {model_id} does not support gradient checkpointing")
 
-        model.gradient_checkpointing_enable()
+        # Disable lora_dropout and friends to remove non-determinism in gradient creation
+        for key in list(config_kwargs.keys()):
+            if key.endswith("dropout"):
+                del config_kwargs[key]
 
         config = config_cls(
             base_model_name_or_path=model_id,
@@ -1341,15 +1357,37 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
         )
         model = get_peft_model(model, config)
         model = model.to(self.torch_device)
+        params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+
+        # If we don't set this, gradient checkpointing is not activated.
+        model.train(True)
 
         inputs = self.prepare_inputs_for_testing()
 
-        # check if `training` works
-        output = model(**inputs)[0]
+        # Invocation to get the reference non-zero grads that should exist even without gradient checkpointing;
+        # note that we square the output to get bigger gradients
+        output = model(**inputs)[0] ** 2
 
         loss = output.sum()
         loss.backward()
 
+        non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}
+
+        for name, param in params:
+            param.grad = None
+
+        # Invocation with gradient checkpointing for comparison
+        model.prepare_model_for_gradient_checkpointing(model)
+        model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})
+
+        output = model(**inputs)[0] ** 2
+
+        loss = output.sum()
+        loss.backward()
+
+        non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
+        assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing
+
         for n, param in model.named_parameters():
             if "prompt_encoder." in n:  # prompt tuning methods
                 if not issubclass(config_cls, CPTConfig):