
Commit 6fae8fd

Merge pull request #2 from ClashLuke/functional
Functional
2 parents 0a14e38 + 508906e commit 6fae8fd

File tree

7 files changed (+1774 / -275 lines)


README.md

Lines changed: 86 additions & 39 deletions
@@ -12,11 +12,81 @@ python3 -m pip install truegrad
 
 ## Examples
 
+TrueGrad supports various backends, each with their own tradeoffs:
+
+| Name | Advantages | Disadvantages |
+|------|------------|---------------|
+| [truegrad.nn](#nn) | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
+| [truegrad.utils.patch_torch](#patch-torch) | * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty if model is compatible |
+| [backpack](#backpack) | * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
+| [truegrad.utils.patch_model](#patch-custom-models) | * Best compatibility | * Fails silently on fused functions<br/>* More costly than truegrad.nn |
+
+Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing
+partial application of TrueGrad.
+
+### nn
+
+The preferred method of using TrueGrad is by replacing `torch.nn` with performant `truegrad.nn` modules. While other
+methods add compute and memory overheads, `truegrad.nn` and `truegrad.nn.functional` have hand-crafted gradients. This
+is the most powerful method, although it requires code modifications.
+
+```PYTHON
+import torch
+from truegrad import nn
+from truegrad.optim import TGAdamW
+
+# define model by mixing truegrad.nn and torch.nn
+model = torch.nn.Sequential(nn.Linear(1, 10),
+                            nn.LayerNorm([1, 10]),
+                            torch.nn.ReLU(),
+                            nn.Linear(10, 1))
+optim = TGAdamW(model.parameters()) # truegrad.optim.TGAdamW instead of torch.optim.AdamW
+
+# standard training loop
+while True:
+    input = torch.randn((16, 1))
+    model(input).mean().backward()
+    optim.step()
+```
+
+### Patch Torch
+
+In some cases, you can't modify the model's source. For example, when importing models from `torchvision`. If that's the
+case, or if you simply want to try out TrueGrad, you can use `truegrad.utils.patch_torch()`, to
+replace `torch.nn.Module`'s with `truegrad.nn.Module`'s where possible. For example, the code below can be used to train
+a ResNet-18:
+
+```PYTHON
+import torch
+from torchvision.models import resnet18
+
+from truegrad.optim import TGAdamW
+from truegrad.utils import patch_torch
+
+patch_torch() # call before model creation, otherwise complete freedom
+model = resnet18().cuda()
+optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)
+
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224)).cuda()
+tgt = torch.randint(0, 1000, (2,)).cuda()
+
+# standard training loop
+i = 0
+while True:
+    loss = torch.nn.functional.cross_entropy(model(inp), tgt)
+    loss.backward()
+    optim.step()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
+```
+
 ### BackPack
 
-The preferred method to integrate TrueGrad is using [BackPack](https://github.com/f-dangel/backpack). BackPack is a
-third-party library that automatically computes the sum of gradient squares and works for most models by implementing
-custom backward rules for many `torch.nn.Module`'s.
+The most stable although also memory hungry method to compute TrueGrad statistics is to use
+[BackPack](https://github.com/f-dangel/backpack). BackPack is a third-party library that automatically computes the sum
+of gradient squares and works for most models by implementing custom backward rules for many `torch.nn.Module`'s.
 
 ```PYTHON
 import backpack
@@ -25,10 +95,10 @@ from torch.nn import CrossEntropyLoss
 from truegrad.optim import TGAdamW
 from torchvision.models import alexnet
 
-model = alexnet()
+model = alexnet() # BatchNorm and in-place ops (like ResNet's residual path) aren't supported
 optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)
 
-# backpack can't handle inplace ops like nn.ReLU(inplace=True) and `x += y`
+# replace inplace ops like nn.ReLU(inplace=True) where possible
 for mod in model.modules():
     if hasattr(mod, "inplace"):
         mod.inplace = False
@@ -62,12 +132,13 @@ your model has any layer called `.output` or you're using PyTorch >= 1.13, you w
 ### Patch Custom Models
 
 Another option to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
-`patch_model()` will go through all`torch.nn.Module`'s in PyTorch model and convert their `torch.nn.Parameter`'s to
+`patch_model()` will go through all `torch.nn.Module`'s in PyTorch model and convert their `torch.nn.Parameter`'s to
 `truegrad.nn.TrueGradParameter`'s. A `TrueGradParameter` acts largely the same as a `torch.nn.Parameter`, but adds
-required operations into the model's backward pass.\
+required operations into the model's backward pass. Note that this doesn't give the most effective computation graph,
+but works well for many custom models.\
 Importantly, be aware that this does not work for fused functions, such as `torch.nn.LayerNorm`
-and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as multiplication
-and work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected.
+and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as multiplication,
+work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected.
 
 ```PYTHON
 import transformers
@@ -87,35 +158,6 @@ for sample in ["Hello", "World", "!"]:
     optim.step()
 ```
 
-### nn
-
-Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation or even fail
-unexpectedly. That's why a pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the
-`truegrad.utils` module. Compared to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower
-memory usage, although it might require code alterations and doesn't support all models. You cannot (currently) use
-`truegrad.nn` with `truegrad.utils`, as both use different ways to arrive at the same value. However, you can
-combine `torch.nn.Modules` and `truegrad.nn.Modules` and use the truegrad information only where it is available (
-see [Partial TrueGrad](#Partial-TrueGrad)).
-
-```PYTHON
-import torch
-from truegrad import nn
-from truegrad.optim import TGAdamW
-
-# define model by mixing truegrad.nn and torch.nn
-model = torch.nn.Sequential(nn.Linear(1, 10),
-                            nn.LayerNorm([1, 10]),
-                            torch.nn.ReLU(),
-                            nn.Linear(10, 1))
-optim = TGAdamW(model.parameters()) # truegrad.optim.TGAdamW instead of torch.optim.AdamW
-
-# standard training loop
-while True:
-    input = torch.randn((16, 1))
-    model(input).mean().backward()
-    optim.step()
-```
-
 ### Partial TrueGrad
 
 Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow, and sometimes it's
@@ -138,8 +180,13 @@ model = torch.nn.Sequential(nn.Linear(1, 10), # Weights coming from truegrad.nn
 optim = TGAdamW(model.parameters(), default_to_adam=True)
 
 # standard training loop
+i = 0
 while True:
     input = torch.randn((16, 1))
-    model(input).mean().backward()
+    loss = model(input).mean()
+    loss.backward()
     optim.step()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
 ```
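
A note on the table's `patch_torch` caveat ("Uncertainty if model is compatible"): one quick way to probe coverage is to run a single forward/backward pass and count which parameters picked up TrueGrad statistics. The sketch below is illustrative rather than part of this commit; it assumes, as the `truegrad/functional.py` changes further down suggest, that supported weights expose a `sum_grad_squared` attribute after `backward()`.

```PYTHON
# Illustrative compatibility probe (not part of this commit).
# Assumption: parameters handled by TrueGrad expose `sum_grad_squared` after backward().
import torch
from torchvision.models import resnet18

from truegrad.utils import patch_torch

patch_torch()  # must run before the model is constructed
model = resnet18()

model(torch.randn((2, 3, 224, 224))).mean().backward()

tracked = [name for name, p in model.named_parameters()
           if getattr(p, "sum_grad_squared", None) is not None]
total = sum(1 for _ in model.named_parameters())
print(f"{len(tracked)} of {total} parameters carry TrueGrad statistics")
```

Parameters that never receive the statistic can still be trained by passing `default_to_adam=True` to `TGAdamW`, as in the Partial TrueGrad example above.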

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='0.1.0',
+    version='1.0.0',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),

truegrad/functional.py

Lines changed: 110 additions & 19 deletions
@@ -1,20 +1,61 @@
-from typing import List, Tuple
+from typing import Any, Callable, List, Tuple
 
 import torch
+from torch.utils._pytree import tree_map
+
+
+def _unpack(x: Any) -> Any:
+    if isinstance(x, TrueGradTensor):
+        return x.data
+    return x
+
+
+_base_torch_function = torch.Tensor.__torch_function__
+
+
+class TrueGradTensor(torch.Tensor):
+    sum_grad_squared: torch.Tensor
+    data: torch.Tensor
+    requires_grad: bool
+
+    __slots__ = ['sum_grad_squared', "data", "requires_grad"]
+
+    @staticmethod
+    def __new__(cls, data: torch.Tensor):
+        meta = data.new_empty((0,))
+        meta.set_(meta.storage(), 0, data.size(), data.stride())
+        r = torch.Tensor._make_subclass(cls, meta, data.requires_grad)
+        r.data = data
+        r.sum_grad_squared = None
+        r.activated = False
+        r.requires_grad = data.requires_grad
+        return r
+
+    def __repr__(self):
+        return f"TrueGradTensor({self.data})"
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        out = _base_torch_function(func, [], tree_map(_unpack, args), tree_map(_unpack, kwargs))
+        return out
 
 
 class MulFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, weight: torch.Tensor):
         if weight.requires_grad:
-            ctx.save_for_backward(inp, weight)
+            ctx.save_for_backward(inp)
+            ctx.weight = weight
         return inp * weight
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor):
         if not ctx.saved_tensors:
             return None, None
-        inp, weight = ctx.saved_tensors
+        inp, = ctx.saved_tensors
+        weight = ctx.weight
         diff = inp.ndim - weight.ndim
         summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
         weight_grad = dy * inp
@@ -33,14 +74,15 @@ def forward(ctx, inp: torch.Tensor, weight: torch.Tensor):
             diff = inp.ndim - weight.ndim
             ctx.summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
             ctx.batch_size = inp.size(0)
-            ctx.save_for_backward(weight)
+            ctx.weight = weight
+
         return inp + weight
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor):
-        if not ctx.saved_tensors:
+        if not hasattr(ctx, "weight"):
             return None, None
-        weight, = ctx.saved_tensors
+        weight = ctx.weight
         weight_grad = dy
         weight.sum_grad_squared = dy.square()
         if ctx.summed:
@@ -54,15 +96,17 @@ class EinsumFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, spec: str, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         if weight.requires_grad:
-            ctx.save_for_backward(inp, weight)
+            ctx.save_for_backward(inp)
+            ctx.weight = weight
         ctx.spec = spec
         return torch.einsum(spec, inp, weight)
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]:
         if not ctx.saved_tensors:
             return None, None, None
-        inp, wgt = ctx.saved_tensors
+        inp, = ctx.saved_tensors
+        wgt = ctx.weight
         inputs, output = ctx.spec.split('->')
         lhs, rhs = inputs.split(',')
 
@@ -76,14 +120,16 @@ class GatherFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         if weight.requires_grad:
-            ctx.save_for_backward(inp, weight)
+            ctx.save_for_backward(inp)
+            ctx.weight = weight
         return torch.gather(weight, 0, inp)
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         if not ctx.saved_tensors:
             return None, None
-        inp, wgt = ctx.saved_tensors
+        inp, = ctx.saved_tensors
+        wgt = ctx.weight
         wgt_grad = torch.zeros_like(wgt)
         wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square())
         wgt_grad.scatter_add_(0, inp, dy)
@@ -93,45 +139,90 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
 class ReshapeFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
+        out = TrueGradTensor(weight.reshape(new_shape).detach().requires_grad_(True))
         if weight.requires_grad:
             ctx.save_for_backward(weight)
+            ctx.out = out
             ctx.original_shape = weight.size()
-        return weight.reshape(new_shape)
+        return out
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         if not ctx.saved_tensors:
-            return None
+            return None, None
         wgt, = ctx.saved_tensors
-        if hasattr(wgt, "sum_grad_squared"):
-            wgt.sum_grad_squared = wgt.sum_grad_squared.reshape(ctx.original_shape)
-        return dy.reshape(ctx.original_shape)
+        if ctx.out.sum_grad_squared is not None:
+            wgt.sum_grad_squared = ctx.out.sum_grad_squared.reshape(ctx.original_shape)
+        return dy.reshape(ctx.original_shape), None
 
 
 class ExpandFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
+        out = TrueGradTensor(weight.expand(new_shape))
         if weight.requires_grad:
             ctx.save_for_backward(weight)
+            ctx.out = out
             ctx.summed = [i for i, d in enumerate(new_shape) if d != -1]
-        return weight.reshape(new_shape)
+        return out
 
     @staticmethod
     def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         if not ctx.saved_tensors:
-            return None
+            return None, None
         wgt, = ctx.saved_tensors
-        if hasattr(wgt, "sum_grad_squared") and ctx.summed:
-            wgt.sum_grad_squared = wgt.sum_grad_squared.sum(ctx.summed)
+        if ctx.out.sum_grad_squared is not None and ctx.summed:
+            wgt.sum_grad_squared = ctx.out.sum_grad_squared.sum(ctx.summed)
         return dy.sum(ctx.summed)
 
 
+class WrapFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, fn, args, kwargs) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.args = args
+        ctx.kwargs = kwargs
+        return fn(*args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, dy: torch.Tensor) -> Tuple[None, None, None, None]:
+        def _backward(fn: Callable[[torch.Tensor], torch.Tensor], attr: str):
+            def _fn(x: torch.Tensor):
+                if isinstance(x, torch.nn.Parameter):
+                    x = x.data
+                if not isinstance(x, torch.Tensor) or not torch.is_floating_point(x):
+                    return x
+                x = fn(x.detach())
+                x.requires_grad_(True)
+                return x
+
+            args = tree_map(_fn, ctx.args)
+            kwargs = tree_map(_fn, ctx.kwargs)
+
+            with torch.enable_grad():
+                out = ctx.fn(args, kwargs)
+                torch.autograd.backward(out, tree_map(_fn, dy))
+
+            for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())):
+                if not isinstance(p, torch.nn.Parameter):
+                    continue
+                if hasattr(p, attr) and getattr(p, attr) is not None:
+                    a.grad = getattr(p, attr) + a.grad
+                setattr(p, attr, a.grad)
+
+        _backward(torch.square, "sum_grad_squared")
+        _backward(lambda x: x, "grad")
+
+        return None, None, None, None
+
+
 mul = MulFn.apply
 add = AddFn.apply
 einsum = EinsumFn.apply
 gather = GatherFn.apply
 reshape = ReshapeFn.apply
 expand = ExpandFn.apply
+wrap = WrapFn.apply
 
 
 def matmul(inp: torch.Tensor, wgt: torch.Tensor):
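
The recurring `ctx.weight` / `sum_grad_squared` pattern in this file is the contract the rest of the library builds on: each autograd Function stashes a sum-of-squared-gradients statistic on the weight it touched, which is presumably what `TGAdamW` consumes. A minimal sketch of that contract through `truegrad.functional.mul` (i.e. `MulFn.apply`) is shown below; it is illustrative only and assumes the truncated part of `MulFn.backward` populates `sum_grad_squared` the same way `AddFn` and `GatherFn` visibly do.

```PYTHON
# Illustrative sketch (not part of this commit): after backward(), the weight passed to
# truegrad.functional.mul should carry a `sum_grad_squared` statistic next to `.grad`.
import torch
from truegrad.functional import mul

weight = torch.randn(3, requires_grad=True)
inp = torch.randn(8, 3)  # batch of 8 rows, broadcast against the 3-element weight

out = mul(inp, weight)   # MulFn.apply: saves inp, stashes the weight on ctx
out.sum().backward()

print(weight.grad.shape)  # ordinary gradient, shape (3,)
sgs = getattr(weight, "sum_grad_squared", None)
print(None if sgs is None else sgs.shape)  # expected: batch-summed squared gradient, shape (3,)
```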