Commit 30a19a0
Gradient checkpointing differs from the normal forward/backward process in that it may execute forward
steps during the backward pass. It is therefore decoupled from the normal forward process. Since we have
methods, such as activated LoRA, that depend on `peft_forward_hooks` during the normal forward, we have to make
sure that these hooks are also applied in the forward passes that gradient checkpointing executes independently.
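To illustrate, here is a minimal standalone example (plain `torch.utils.checkpoint`, not PEFT code) showing that the checkpointed forward runs a second time during backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

def fn(x):
    calls["n"] += 1
    return x * x

x = torch.ones(3, requires_grad=True)
y = checkpoint(fn, x, use_reentrant=True)  # use_reentrant=True is the transformers default
y.sum().backward()
print(calls["n"])  # 2: the forward step was re-executed during backward
```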
To this end we had several options:
- Don't use hooks to communicate the alora offsets; write them to the module directly. We could do that, but then
we'd need a mechanism to clean up these values afterwards (probably involving hooks), and we would be introducing
a non-general approach that might bite us in the future as more methods need parameter injection.
- Drop support for `use_reentrant=True` (the default for transformers) and use the `context_fn` parameter to inject parameters
when using `use_reentrant=False`. `torch.utils.checkpoint` supports a `context_fn` parameter which returns
two context managers (one for the normal forward, one for the checkpointed forward). In theory this could be a way to
inject the variables into the module (see the sketch after this list). In practice we would need to modify the `keywords` argument of every
`._gradient_checkpointing_func` attribute of every module to inject the `context_fn` callback and update those
accordingly on every forward call. This is far less reliable than forward hooks.
- Register forward hooks on the gradient checkpointing layers that apply the same hooks that `enable_peft_forward_hooks`
does - but decoupled from `enable_peft_forward_hooks`. These hooks are removed once a full backward hook on the
gradient checkpointing layer is called. We'd still need shared storage for the hook handles so that
we don't rack up forward and backward hooks, but this storage is a general way of implementing forward hooks
in gradient checkpointing. It also lets us control the flow without using private methods/variables. Since this
adds forward hooks that are only removed when backward is called, we disallow multiple forward calls
in succession before a backward call (the user can still do that with gradient checkpointing disabled).
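For reference, here is a rough sketch of what option 2 would have looked like; the `context_fn` parameter is real `torch.utils.checkpoint` API (non-reentrant only), while `alora_offsets` is just an illustrative attribute name:

```python
import contextlib
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
offsets = [1, 2, 3]  # stand-in for the alora offsets

@contextlib.contextmanager
def inject_offsets():
    # Set the value on the module for the duration of the (re-)executed
    # forward, then clean it up again.
    layer.alora_offsets = offsets
    try:
        yield
    finally:
        del layer.alora_offsets

x = torch.randn(2, 4, requires_grad=True)
# context_fn must return two context managers: one wrapping the original
# forward, one wrapping the recomputation during backward.
out = checkpoint(
    layer, x, use_reentrant=False,
    context_fn=lambda: (inject_offsets(), inject_offsets()),
)
out.sum().backward()
```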
In this change I'm implementing the last option, forward hooks on gradient checkpointing layers.
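Simplified, the idea looks like this (illustrative names such as `storage` and `apply_peft_hook`; not the actual PEFT implementation):

```python
import torch

storage = {}  # shared storage for hook handles, keyed by module

def register_gc_hooks(layer, apply_peft_hook):
    if layer in storage:
        # A second forward before backward would register a second pair of
        # hooks, hence (forward, forward) sequences are disallowed.
        raise RuntimeError("multiple forward calls before backward are not supported")

    def pre_hook(module, args, kwargs):
        # Inject the same state that enable_peft_forward_hooks would inject,
        # so the forward re-executed by gradient checkpointing sees it too.
        apply_peft_hook(module)

    def backward_hook(module, grad_input, grad_output):
        # Backward was reached: remove both hooks and free the storage slot.
        for handle in storage.pop(module):
            handle.remove()

    storage[layer] = [
        layer.register_forward_pre_hook(pre_hook, with_kwargs=True),
        layer.register_full_backward_hook(backward_hook),
    ]

layer = torch.nn.Linear(4, 4)
register_gc_hooks(layer, lambda m: setattr(m, "alora_offsets", [0]))
out = layer(torch.randn(2, 4, requires_grad=True))
out.sum().backward()  # triggers the backward hook, which removes both hooks
```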
We already had tests, but there were some issues that are improved in this PR:
- parametrize `use_reentrant` so that we check both values, even though `use_reentrant=True` is the default;
since `use_reentrant=False` has a consistency checker, it might detect errors that we don't cover
- check consistency between normal and checkpointed model runs: both runs must have the same set
of non-zero gradients. It is not sufficient to check that a gradient exists, it must be non-zero
as well (gradient checkpointing can fail with zero grads)
- call `model.train()` to enable gradient checkpointing
- disable `lora_dropout` if set, to make gradients (a bit more) deterministic
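Condensed, the consistency check looks roughly like this (a hypothetical test using a toy model and plain `torch.utils.checkpoint` instead of transformers' wrapper):

```python
import copy
import pytest
import torch
from torch.utils.checkpoint import checkpoint

@pytest.mark.parametrize("use_reentrant", [True, False])
def test_checkpointing_grad_consistency(use_reentrant):
    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
    model_ckpt = copy.deepcopy(model)
    model.train()
    model_ckpt.train()  # mirrors the real tests: checkpointing needs train mode
    x = torch.randn(4, 8, requires_grad=True)

    model(x).sum().backward()
    out = x
    for layer in model_ckpt:
        out = checkpoint(layer, out, use_reentrant=use_reentrant)
    out.sum().backward()

    # Both runs must produce the same set of *non-zero* gradients; merely
    # checking that a gradient exists would miss zero grads.
    for p, p_ckpt in zip(model.parameters(), model_ckpt.parameters()):
        assert (p.grad != 0).equal(p_ckpt.grad != 0)
```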
While testing I found that GPTBigCode doesn't support gradient checkpointing even though it claims to;
it is skipped for now until fixed.

BART had an issue which was caused by the model not sharing the embed tokens module but just the weights.
This in turn leads to `get_input_embeddings` returning one specific module which may not be invoked in the
forward call at all (`model.shared` in this case - it has the same weights but it is a different `nn.Module`).
Since `enable_input_require_grads` depends on the module returned by `get_input_embeddings` having working
forward hooks, the preparation for gradient checkpointing fails. This needs to be fixed in transformers,
either by targeting all tied weights or by sharing the embedding *modules* instead of just the weights
(like in T5).
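To illustrate the failure mode, here is a sketch of the mechanism (roughly following what `enable_input_require_grads` does in transformers, with toy modules):

```python
import torch

# Two embedding modules sharing the same weight tensor, as in BART:
shared = torch.nn.Embedding(10, 4)
embed_tokens = torch.nn.Embedding(10, 4)
embed_tokens.weight = shared.weight  # tied *weights*, distinct *modules*
shared.weight.requires_grad_(False)  # base weights are frozen under PEFT

# enable_input_require_grads registers a forward hook on the module returned
# by get_input_embeddings(), which here is `shared` ...
def make_inputs_require_grads(module, inputs, output):
    output.requires_grad_(True)

shared.register_forward_hook(make_inputs_require_grads)

# ... but the forward pass only ever calls `embed_tokens`, so the hook never
# fires and the checkpointed segment sees inputs that don't require grad.
out = embed_tokens(torch.tensor([1, 2, 3]))
print(out.requires_grad)  # False
```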
* Remove gradient requirement in CPT's embedding
This module is never called in a gradient path, so it is safe to set it
to `requires_grad=False`. This helps uphold the assumption that
all parameters that require a gradient will actually receive one.
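That assumption can be checked with a loop like the following (a sketch, assuming `model` has just gone through a backward pass):

```python
# Invariant: every parameter with requires_grad=True must actually have
# received a gradient after backward.
for name, param in model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"{name} received no gradient"
```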
* Improve gradient checkpointing tests
- parametrize `use_reentrant` so that we check both values, even though `use_reentrant=True` is the default;
since `use_reentrant=False` has a consistency checker, it might detect errors that we don't cover
- check consistency between normal and checkpointed model runs: both runs must have the same set
of non-zero gradients. It is not sufficient to check that a gradient exists, it must be non-zero
as well (gradient checkpointing can fail with zero grads)
- call `model.train()` to enable gradient checkpointing
- disable `lora_dropout` if set, to make gradients (a bit more) deterministic
Also while testing I found that GPTBigCode doesn't support gradient checkpointing even though it claims to;
it is skipped for now until fixed. BART doesn't work with the newest changes and fails at `loss.backward()`
with `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`.
This is still open.
* Add support for DPO training
In DPO training we have the case that we (potentially) first do a forward call with `disable_adapter`,
then do a forward pass with adapters enabled, and afterwards compute the DPO loss. This is problematic since
we add gradient checkpointing forward hooks as long as we detect alora offsets, and the forward
hook code enforces that we only do (forward, backward) sequences; (forward, forward) is forbidden since
it would mean registering a second pair of hooks for gradient checkpointing.
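The call pattern in question looks roughly like this (`model` is assumed to be a PEFT model with gradient checkpointing enabled; `dpo_loss` is a hypothetical stand-in for the actual DPO loss computation):

```python
# The reference forward with adapters disabled must not register the gradient
# checkpointing hooks, otherwise the subsequent policy forward would count
# as a forbidden (forward, forward) sequence.
with model.disable_adapter():
    ref_logits = model(input_ids).logits
policy_logits = model(input_ids).logits  # registers hooks (alora offsets detected)
loss = dpo_loss(policy_logits, ref_logits)
loss.backward()  # removes the hooks again
```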
1 parent d43b315 commit 30a19a0
File tree: src/peft/tuners/cpt, src/peft/tuners/lora, tests (9 files changed, +171, -17 lines)