
Commit 30a19a0

Fix #2826: implement gradient checkpoint callbacks (#2860)
Gradient checkpointing differs from the normal forward/backward process in that it may execute forward steps during the backward pass. It is therefore decoupled from the normal forward process. Since we have methods, such as activated LoRA, that depend on `peft_forward_hooks` in the normal forward, we have to make sure that these hooks are also applied in the forward pass that gradient checkpointing executes independently. To this end we had several options:

- Don't use hooks to communicate the alora offsets, write them to the module directly. We can do that, but then we'd need a mechanism to clean up these values afterwards (probably involving hooks), and we would be introducing a non-general approach which might bite us in the future when more methods need parameter injection.
- Don't support `use_reentrant=True` (default for transformers) and use the `context_fn` parameter to inject parameters when using `use_reentrant=False`. `torch.utils.checkpoint` supports a `context_fn` parameter which returns two context managers (one for the normal, one for the checkpoint forward). In theory this could be a way to inject the variables into the module. In practice we would need to modify the `keywords` argument of every `._gradient_checkpointing_func` attribute of every module to inject the `context_fn` callback and update those accordingly on every forward call. This is far less reliable than forward hooks.
- Register forward hooks on the gradient checkpointing layers that apply the same hooks that `enable_peft_forward_hooks` does - but decoupled from `enable_peft_forward_hooks`. These hooks are removed once a full backward hook on the gradient checkpointing layer is called. We'd still need shared storage for the hook handles so that we don't rack up forward and backward hooks, but this storage is a general way of implementing forward hooks in gradient checkpointing. It also lets us control the flow without using private methods/variables. Since this adds forward hooks that are only removed when backward is called, we disallow multiple forward calls in succession before a backward call (the user can do that with gradient checkpointing disabled).

In this change I'm implementing the last option, forward hooks on gradient checkpointing layers.

We already had tests, but there were some issues that are improved in this PR:

- parametrize `use_reentrant` so we check both, even though `use_reentrant=True` is the default; since `use_reentrant=False` has a consistency checker, it might detect errors that we don't cover
- check consistency between normal and checkpointed model runs: both runs must have the same set of non-zero gradients. It is not sufficient to check that there is a gradient, it must be non-zero as well (gradient checkpointing can fail with zero grads)
- set `model.train` to enable gradient checkpointing
- disable `lora_dropout` if set to make gradients (a bit more) deterministic

While testing I found that GPTBigCode doesn't support gradient checkpointing even though it says so, skipping for now until fixed.

BART had an issue which was caused by the model not sharing the embed tokens module but just the weights. This in turn leads to `get_input_embeddings` returning one specific module which may not be invoked in the forward call at all (`model.shared` in this case - it has the same weights but it is a different `nn.Module`). Since `enable_input_require_grads` depends on the module returned by `get_input_embeddings` having working forward hooks, the preparation for gradient checkpointing fails. This needs to be fixed in transformers, either by targeting all tied weights or by sharing the embedding *modules* instead of just the weights (like in T5).

* Remove gradient requirement in CPT's embedding

This module is never called in a gradient path, so it is safe to set it to `requires_grad=False`. This helps in upholding the assumption that all parameters that require a gradient need to receive one.

* Improve gradient checkpointing tests

- parametrize `use_reentrant` so we check both, even though `use_reentrant=True` is the default; since `use_reentrant=False` has a consistency checker, it might detect errors that we don't cover
- check consistency between normal and checkpointed model runs: both runs must have the same set of non-zero gradients. It is not sufficient to check that there is a gradient, it must be non-zero as well (gradient checkpointing can fail with zero grads)
- set `model.train` to enable gradient checkpointing
- disable `lora_dropout` if set to make gradients (a bit more) deterministic

Also while testing I found that GPTBigCode doesn't support gradient checkpointing even though it says so, skipping for now until fixed. BART doesn't work with the newest changes and fails at `loss.backward()` with `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`. This is still open.

* Add support for DPO training

In DPO training we have the case that we (potentially) do a forward call first with `disable_adapter`, then do a forward run with adapters enabled, and afterwards do DPO stuff. This is problematic since we add gradient checkpointing forward hooks as long as we detect alora offsets, and the forward hook code enforces that we only do (forward, backward) call pairs; (forward, forward) is forbidden since it would mean we'd register a second pair of hooks for gradient checkpointing.
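To make the chosen option concrete, here is a minimal, self-contained sketch of the hook pattern in plain PyTorch (names like `ToyBlock` and `_hook_handles` are illustrative placeholders, not the PEFT implementation shown in the diff below):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # stand-in for a transformers GradientCheckpointingLayer
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


block = ToyBlock()
block._hook_handles = []  # shared storage for handles, mirroring the approach described above


def forward_pre_hook(module, args):
    # (re-)register per-call hooks on submodules; this fires in the normal forward
    # and again when checkpointing re-runs the forward during backward
    handle = module.linear.register_forward_pre_hook(lambda mod, inp: None)
    module._hook_handles.append(handle)


def backward_hook(module, grad_input, grad_output):
    # once backward for this layer has run, remove everything that was registered
    while module._hook_handles:
        module._hook_handles.pop().remove()


block.register_forward_pre_hook(forward_pre_hook)
block.register_full_backward_hook(backward_hook)

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()  # triggers the re-forward, then the cleanup hook
assert not block._hook_handles  # all handles were removed after backward

If forward were called twice without a backward in between, the handles would accumulate; that is exactly the situation the error described above guards against.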
1 parent d43b315 commit 30a19a0

9 files changed, +171 −17 lines changed

src/peft/peft_model.py

Lines changed: 20 additions & 1 deletion
@@ -142,6 +142,8 @@ def __init__(
         if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
             self.base_model.config.pretraining_tp = 1

+        self._adapters_disabled = False
+
     @property
     def peft_config(self) -> dict[str, PeftConfig]:
         if self._is_prompt_learning:
@@ -167,6 +169,17 @@ def active_adapters(self) -> list[str]:
             adapters = [adapters]
         return adapters

+    @property
+    def has_active_enabled_adapter(self) -> bool:
+        """Reflects whether the adapters are purposefully disabled (via disable_adapter) or if there
+        are no active adapters (enabled but inactive). They are two separate mechanisms but sometimes it is helpful to
+        know whether the model has any active/enabled adapter at all.
+        """
+        if self.peft_config[self.active_adapter].is_prompt_learning:
+            return not self._adapters_disabled
+
+        return not self._adapters_disabled or not self.active_adapters
+
     @peft_config.setter
     def peft_config(self, value: dict[str, PeftConfig]):
         if self._is_prompt_learning:
@@ -890,7 +903,7 @@ def __getattr__(self, name: str):
     def _enable_peft_forward_hooks(self, *args, **kwargs):
         # If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. Otherwise, this
         # runs without any changes
-        if hasattr(self.base_model, "_enable_peft_forward_hooks"):
+        if hasattr(self.base_model, "_enable_peft_forward_hooks") and self.has_active_enabled_adapter:
             with self.base_model._enable_peft_forward_hooks(*args, **kwargs):
                 yield
             return
@@ -940,17 +953,21 @@ def disable_adapter(self):
                 self.forward = self.base_model.forward
                 old_prepare_inputs_for_generation = self.prepare_inputs_for_generation
                 self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
+                self._adapters_disabled = True
                 yield
             finally:
                 self.forward = old_forward
                 self.prepare_inputs_for_generation = old_prepare_inputs_for_generation
+                self._adapters_disabled = False

         elif self.peft_config[self.active_adapter].is_adaption_prompt:
             try:
                 self.base_model.disable_adapter_layers()
+                self._adapters_disabled = True
                 yield
             finally:
                 self.base_model.enable_adapter_layers()
+                self._adapters_disabled = False

         else:  # LoRA, LoHa, etc.
             model_status = self.get_model_status()
@@ -962,11 +979,13 @@ def disable_adapter(self):
             )
             try:
                 self.base_model.disable_adapter_layers()
+                self._adapters_disabled = True
                 yield
             finally:
                 if model_status.enabled is not False:
                     # model_status.enabled is `True` or `"irregular"`
                     self.base_model.enable_adapter_layers()
+                self._adapters_disabled = False

     def get_base_model(self) -> torch.nn.Module:
         """

src/peft/tuners/cpt/model.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ def __init__(self, config, word_embeddings):
         word_embedding_weights = word_embedding_weights.to(torch.float32)
         self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

+        self.embedding.requires_grad_(False)
+
         # Initialize delta embedding with zero weights
         self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
         self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)

src/peft/tuners/lora/model.py

Lines changed: 37 additions & 1 deletion
@@ -23,6 +23,7 @@

 import torch
 from torch import nn
+from transformers.modeling_layers import GradientCheckpointingLayer

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import (
@@ -351,13 +352,48 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
         # If adapter_names is passed as an argument, we inject it into the forward arguments.
         adapter_names = kwargs.pop("adapter_names", None)
         alora_offsets = kwargs.pop("alora_offsets", None)
+
         if adapter_names is None and alora_offsets is None:
             # nothing to do
             yield
             return
         hook_handles = []
+
         if alora_offsets is not None:
-            for layer in self.modules():
+            for n, layer in self.named_modules():
+                # gradient checkpointing layer are executed concurrently to the 'normal' forward call
+                # (in the backward step the gradient checkpointing layer's forward will be executed again).
+                # this means that when the gradient checkpointing layer is called, the _enable_peft_forward_hooks
+                # context manager is long gone. to be consistent with the normal forward we need to register the pre
+                # hooks for this concurrent forward call as well.
+                #
+                # Note that this will lead to double application of whatever the callbacks do in normal forward.
+                # Make sure that whatever change is done, can be applied more than once without harm (idempotency).
+                if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+
+                    def forward_pre_hook(name, module, inputs, **kwargs):
+                        for submodule in module.modules():
+                            if isinstance(submodule, LoraLayer):
+                                handle = submodule.register_forward_pre_hook(
+                                    partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
+                                    with_kwargs=True,
+                                )
+                                module._peft_gradient_checkpointing_forward_hooks.append(handle)
+
+                    def backward_hook(name, module, *grad_output, **kwargs):
+                        while module._peft_gradient_checkpointing_forward_hooks:
+                            module._peft_gradient_checkpointing_forward_hooks.pop().remove()
+
+                    if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
+                        raise ValueError(
+                            "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
+                            "checkpointing. Disable gradient checkpointing or only call forward once per backward."
+                        )
+                    layer._peft_gradient_checkpointing_forward_hooks = []
+                    handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
+                    handle = layer.register_full_backward_hook(partial(backward_hook, n))
+                    layer._peft_gradient_checkpointing_forward_hooks.append(handle)
                 if isinstance(layer, LoraLayer):
                     pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
                     handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
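
The idempotency note in the comment above matters because the same pre-hook can fire in the regular forward and again in the checkpointed re-forward. A toy illustration of an idempotent pre-hook (hypothetical `FakeLoraLayer` and hook names, not the actual `_alora_offsets_pre_forward_hook`):

from functools import partial

import torch
from torch import nn


def set_offsets_pre_hook(module, args, kwargs, alora_offsets=None):
    # writing the same value twice leaves kwargs unchanged, so firing once
    # (normal forward) or twice (plus the checkpointed re-forward) is harmless
    kwargs["alora_offsets"] = alora_offsets
    return args, kwargs


class FakeLoraLayer(nn.Module):
    # placeholder standing in for a real LoraLayer
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)

    def forward(self, x, alora_offsets=None):
        # a real LoraLayer would use alora_offsets to restrict where the adapter applies
        return self.base(x)


layer = FakeLoraLayer()
layer.register_forward_pre_hook(partial(set_offsets_pre_hook, alora_offsets=[3]), with_kwargs=True)
out = layer(torch.randn(2, 4))  # the hook injects alora_offsets=[3] into the kwargs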

tests/test_custom_models.py

Lines changed: 5 additions & 2 deletions
@@ -1793,8 +1793,11 @@ def test_training_custom_models_layer_indexing(self, test_name, model_id, config
         pass

     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
-    def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_custom_models_gradient_checkpointing(
+        self, test_name, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):

tests/test_decoder_models.py

Lines changed: 5 additions & 2 deletions
@@ -541,9 +541,12 @@ def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwa

     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         _skip_if_not_conv1d_supported(model_id, config_cls)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
+        self._test_training_gradient_checkpointing(
+            model_id, config_cls, config_kwargs.copy(), use_reentrant=use_reentrant
+        )

     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

tests/test_encoder_decoder_models.py

Lines changed: 5 additions & 2 deletions
@@ -353,8 +353,11 @@ def test_training_encoder_decoders_layer_indexing(self, model_id, config_cls, co

     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_encoder_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_encoder_decoders_gradient_checkpointing(
+        self, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

tests/test_feature_extraction_models.py

Lines changed: 3 additions & 2 deletions
@@ -330,9 +330,10 @@ def test_training_layer_indexing(self, model_id, config_cls, config_kwargs):

     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         skip_deberta_lora_tests(config_cls, model_id)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

tests/test_lora_variants.py

Lines changed: 52 additions & 3 deletions
@@ -15,8 +15,9 @@
 import pytest
 import torch
 from torch import nn
+from transformers import AutoModelForCausalLM

-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, TaskType, get_peft_model
 from peft.tuners.lora.layer import Conv1d as LoraConv1d
 from peft.tuners.lora.layer import Conv2d as LoraConv2d
 from peft.tuners.lora.layer import Embedding as LoraEmbedding
@@ -32,6 +33,8 @@
     get_alora_offsets_for_generate,
 )

+from .testing_common import hub_online_once
+

 # Custom model featuring embeddings and a 'visual stack'
 class CustomModel(nn.Module):
@@ -73,6 +76,9 @@ def __init__(self, vocab_size: int = 10, hidden_dim: int = 8):
         self.embed = nn.Embedding(vocab_size, hidden_dim)
         self.linear = nn.Linear(hidden_dim, vocab_size)

+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return kwargs
+
     def forward(self, X=None, embeds=None, num_beams=None, alora_offsets=None):
         if X is not None:
             embeds = self.embed(X)
@@ -181,7 +187,7 @@ class TestActivatedLora:
     )
     # Verify alora_offsets are calculated correctly
     def test_calculate_alora_offsets(self, input_ids, alora_invocation_tokens, expected_offsets):
-        config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens)
+        config = LoraConfig(task_type=TaskType.CAUSAL_LM, alora_invocation_tokens=alora_invocation_tokens)
         peft_config = {"default": config}

         # compute offsets
@@ -233,7 +239,12 @@ def test_alora_activation_matches_base_until_invocation(self):
     def test_input_embeds_warning(self):
         transformers_class = MockTransformerWrapper
         base_model = transformers_class.from_pretrained()
-        cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
+        cfg = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["linear"],
+            alora_invocation_tokens=[2],
+            init_lora_weights=False,
+        )
         lora_model = get_peft_model(base_model, cfg)
         lora_model.eval()

@@ -265,3 +276,41 @@ def test_num_beams_error(self):
             with torch.no_grad():
                 lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
         assert "Beam search not yet supported for aLoRA." in str(e.value)
+
+    def test_gradient_checkpointing_double_forward_raises(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
+            lora_model = get_peft_model(base_model, cfg)
+            lora_model.train()
+
+            lora_model.prepare_model_for_gradient_checkpointing(lora_model)
+            lora_model.gradient_checkpointing_enable()
+
+            inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}
+
+            lora_model.forward(**inputs)
+
+            with pytest.raises(ValueError, match="Multiple invocations of PEFT forward hooks.*"):
+                lora_model.forward(**inputs)
+
+    def test_gradient_checkpointing_dpo_doesnt_raise(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear", alora_invocation_tokens=[0])
+            lora_model = get_peft_model(base_model, cfg)
+            lora_model.train()
+
+            lora_model.prepare_model_for_gradient_checkpointing(lora_model)
+            lora_model.gradient_checkpointing_enable()
+
+            inputs = {"input_ids": torch.tensor([[0, 1, 2, 3]])}
+
+            with lora_model.disable_adapter():
+                lora_model.forward(**inputs)
+
+            lora_model.forward(**inputs)

tests/testing_common.py

Lines changed: 42 additions & 4 deletions
@@ -25,6 +25,7 @@

 import pytest
 import torch
+import transformers
 import yaml
 from diffusers import StableDiffusionPipeline
 from packaging import version
@@ -1343,41 +1344,78 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         # more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
         assert nb_trainable < nb_trainable_all

-    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
+        # Note that certain configurations, such as activated lora with 'alora_invocation_tokens': [1000], do not
+        # generate gradients since the adapter is never activated so this will be a no-op for this test. It is still
+        # a valid test but it might be confusing to see a test pass if it is not supposed to.
+
         if config_cls == PrefixTuningConfig:
             return pytest.skip(f"Test not applicable for {config_cls}")

         if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("AdaLora with RoBERTa does not work correctly")

+        if "bart" in model_id.lower() and version.parse(transformers.__version__) <= version.parse("5.0"):
+            self.skipTest(
+                "Bart in transformers < 5.0 doesn't handle input sharing well enough. See transformers#41821"
+            )
+
         if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("OFT with Deberta does not work correctly")

+        if "gptbigcode" in model_id.lower():
+            self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")
+
         with hub_online_once(model_id):
             model = self.transformers_class.from_pretrained(model_id)

             if not getattr(model, "supports_gradient_checkpointing", False):
                 return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

-            model.gradient_checkpointing_enable()
+            # Disable lora_dropout and friends to remove non-determinism in gradient creation
+            for key in list(config_kwargs.keys()):
+                if key.endswith("dropout"):
+                    del config_kwargs[key]

             config = config_cls(
                 base_model_name_or_path=model_id,
                 **config_kwargs,
             )
             model = get_peft_model(model, config)
             model = model.to(self.torch_device)
+            params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+
+            # if we don't set this, gradient checkpointing is not activated.
+            model.train(True)

             inputs = self.prepare_inputs_for_testing()

-            # check if `training` works
-            output = model(**inputs)[0]
+            # invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing;
+            # note we're squaring the output for bigger gradients
+            output = model(**inputs)[0] ** 2

             loss = output.sum()
             loss.backward()

+            non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}
+
+            for name, param in params:
+                param.grad = None
+
+            # invocation with gradient checkpointing for comparison
+            model.prepare_model_for_gradient_checkpointing(model)
+            model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})
+
+            output = model(**inputs)[0] ** 2
+
+            loss = output.sum()
+            loss.backward()
+
+            non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
+            assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing
+
             for n, param in model.named_parameters():
                 if "prompt_encoder." in n:  # prompt tuning methods
                     if not issubclass(config_cls, CPTConfig):
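
The essence of the new consistency check, condensed into a stand-alone sketch with a toy model and plain torch.utils.checkpoint (illustrative only; the real test exercises the PEFT model via prepare_model_for_gradient_checkpointing as shown above):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def nonzero_grad_names(model):
    # names of all parameters that received a non-zero gradient
    return {n for n, p in model.named_parameters() if p.grad is not None and p.grad.abs().sum() > 0}


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(4, 8)

# reference run without checkpointing (square the output for bigger gradients)
(model(x) ** 2).sum().backward()
reference = nonzero_grad_names(model)

for p in model.parameters():
    p.grad = None

# run again with non-reentrant activation checkpointing around the first block
h = checkpoint(model[0], x, use_reentrant=False)
(model[2](model[1](h)) ** 2).sum().backward()

# both runs must produce non-zero gradients for exactly the same parameters
assert reference == nonzero_grad_names(model)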
