| Method | Advantages | Disadvantages |
|---|---|---|
|[truegrad.nn](#nn)| * What you see is what you get: modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability: custom modules can't be used<br/>* Requires code modification |
|[truegrad.utils.patch_torch](#patch-torch)| * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Unclear whether a given model is compatible |
|[backpack](#backpack)| * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
|[truegrad.utils.patch_model](#patch-custom-models)| * Best compatibility | * Fails silently on fused functions<br/>* More costly than truegrad.nn |

Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing
partial application of TrueGrad.

### nn
The preferred way to use TrueGrad is to replace `torch.nn` modules with their performant `truegrad.nn` counterparts.
While the other methods add compute and memory overhead, `truegrad.nn` and `truegrad.nn.functional` ship hand-crafted
gradients. This is the most powerful method, although it requires code modifications.

```PYTHON
import torch
from truegrad import nn
from truegrad.optim import TGAdamW

# define model by mixing truegrad.nn and torch.nn
model = torch.nn.Sequential(nn.Linear(1, 10),
                            nn.LayerNorm([1, 10]),
                            torch.nn.ReLU(),
                            nn.Linear(10, 1))
optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW

# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()  # clear gradients so they don't accumulate across steps
```
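
Because `truegrad.nn` layers are drop-in replacements for their `torch.nn` counterparts, the same mixing also works
inside a hand-written module. Below is a minimal sketch using only the layers shown above; the `TinyRegressor` class,
its layer sizes, and the bounded loop are illustrative choices, not part of the library:

```PYTHON
import torch
from truegrad import nn
from truegrad.optim import TGAdamW

class TinyRegressor(torch.nn.Module):
    """Custom torch.nn.Module whose trainable layers come from truegrad.nn."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 10)     # truegrad.nn layer
        self.norm = nn.LayerNorm([1, 10])  # truegrad.nn layer, same shape as above
        self.out = nn.Linear(10, 1)        # truegrad.nn layer

    def forward(self, x):
        return self.out(torch.relu(self.norm(self.hidden(x))))

model = TinyRegressor()
optim = TGAdamW(model.parameters())  # works like in the Sequential example

# same training loop as above, bounded for illustration
for _ in range(10):
    model(torch.randn((16, 1))).mean().backward()
    optim.step()
    optim.zero_grad()
```

As in the Sequential example, the parameter-free ReLU stays plain torch.
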
### Patch Torch
In some cases, you can't modify the model's source, for example when importing models from `torchvision`. If that's the
case, or if you simply want to try out TrueGrad, you can use `truegrad.utils.patch_torch()` to replace `torch.nn`
modules with their `truegrad.nn` equivalents where possible. For example, the code below can be used to train a
ResNet-18:

```PYTHON
import torch
from torchvision.models import resnet18

from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch()  # call before model creation, otherwise complete freedom

model = resnet18()  # supported torch.nn modules inside are now truegrad-aware
optim = TGAdamW(model.parameters())

# one illustrative training step on dummy ImageNet-shaped data
torch.nn.functional.cross_entropy(model(torch.randn(2, 3, 224, 224)),
                                  torch.randint(0, 1000, (2,))).backward()
optim.step()
```