Commit bb868cc

Merge branch 'functional'
2 parents 6fae8fd + 98bd65b commit bb868cc

File tree

  setup.py
  truegrad/functional.py
  truegrad/nn/__init__.py
  truegrad/nn/functional.py
  truegrad/utils.py

5 files changed: +91, -16 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='1.0.0',
+    version='2.0.0',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),

truegrad/functional.py

Lines changed: 62 additions & 0 deletions
@@ -1,3 +1,4 @@
+import typing
 from typing import Any, Callable, List, Tuple
 
 import torch

@@ -156,6 +157,64 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
         return dy.reshape(ctx.original_shape), None
 
 
+class TransposeFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, weight: torch.Tensor, dims: typing.List[int]) -> torch.Tensor:
+        out = TrueGradTensor(weight.transpose(*dims).detach().requires_grad_(True))
+        if weight.requires_grad:
+            ctx.save_for_backward(weight)
+            ctx.out = out
+            ctx.dims = dims
+        return out
+
+    @staticmethod
+    def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
+        if not ctx.saved_tensors:
+            return None, None
+        wgt, = ctx.saved_tensors
+        if ctx.out.sum_grad_squared is not None:
+            wgt.sum_grad_squared = ctx.out.sum_grad_squared.transpose(*ctx.dims)
+        return dy.transpose(*ctx.dims), None
+
+
+class ChunkFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, weight: torch.Tensor, chunks: int, dim: int):
+        out = tuple(TrueGradTensor(c) for c in weight.chunk(chunks, dim))
+        if weight.requires_grad:
+            ctx.save_for_backward(weight)
+            ctx.out = out
+            ctx.dim = dim
+        return out
+
+    @staticmethod
+    def backward(ctx, *dy: torch.Tensor):
+        if not ctx.saved_tensors:
+            return None, None, None
+        wgt, = ctx.saved_tensors
+        wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim)
+        return torch.cat(dy, dim=ctx.dim), None, None
+
+
+class SplitFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, weight: torch.Tensor, split_size: int, dim: int):
+        out = tuple(TrueGradTensor(c) for c in weight.split(split_size, dim))
+        if weight.requires_grad:
+            ctx.save_for_backward(weight)
+            ctx.out = out
+            ctx.dim = dim
+        return out
+
+    @staticmethod
+    def backward(ctx, *dy: torch.Tensor):
+        if not ctx.saved_tensors:
+            return None, None, None
+        wgt, = ctx.saved_tensors
+        wgt.sum_grad_squared = torch.cat([o.sum_grad_squared for o in ctx.out], dim=ctx.dim)
+        return torch.cat(dy, dim=ctx.dim), None, None
+
+
 class ExpandFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, weight: torch.Tensor, new_shape: List[int]) -> torch.Tensor:

@@ -221,6 +280,9 @@ def _fn(x: torch.Tensor):
 einsum = EinsumFn.apply
 gather = GatherFn.apply
 reshape = ReshapeFn.apply
+transpose = TransposeFn.apply
+chunk = ChunkFn.apply
+split = SplitFn.apply
 expand = ExpandFn.apply
 wrap = WrapFn.apply
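
Note: the new TransposeFn, ChunkFn and SplitFn mirror torch's transpose, chunk and split, but wrap their outputs in TrueGradTensor and, on the backward pass, copy the sum_grad_squared statistics gathered on those outputs back onto the parent weight (transposed, or concatenated along the same dim). A minimal usage sketch of the exported wrappers; the toy parameter and shapes below are made up for illustration, and all arguments are positional because torch.autograd.Function.apply accepts no keywords:

import torch
from truegrad.functional import chunk, split, transpose

weight = torch.nn.Parameter(torch.randn(6, 4))  # hypothetical toy parameter

w_t = transpose(weight, (0, 1))      # TrueGradTensor of shape (4, 6)
w_a, w_b, w_c = chunk(weight, 3, 0)  # three TrueGradTensor chunks of shape (2, 4) along dim 0
w_x, w_y = split(weight, [2, 4], 0)  # TrueGradTensor pieces with 2 and 4 rows along dim 0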

truegrad/nn/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,9 @@
 from torch.utils._pytree import tree_map
 
 from truegrad.functional import add, gather, mul, wrap
-from truegrad.nn import functional as F
+from truegrad.nn import functional
+
+F = functional
 
 
 class Normalization(nn.Module):

@@ -150,7 +152,7 @@ def __init__(self, num_features: int, eps=1e-05, elementwise_affine=True, device
 class Linear(nn.Module):
     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
         super(Linear, self).__init__()
-        self.weight = nn.Parameter(torch.randn((in_features, out_features)) / in_features ** 0.5)
+        self.weight = nn.Parameter(torch.randn((out_features, in_features)) / in_features ** 0.5)
         self.bias = nn.Parameter(torch.zeros((out_features,))) if bias else None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
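
Note: two changes here — F becomes a plain module-level alias for truegrad.nn.functional, and Linear now stores its weight as (out_features, in_features), the same layout as torch.nn.Linear. A hedged sanity check of the new layout, with made-up sizes, assuming the forward pass behaves like torch.nn.Linear:

import torch
from truegrad.nn import Linear

layer = Linear(in_features=4, out_features=8)
assert tuple(layer.weight.shape) == (8, 4)  # (out_features, in_features), as in torch.nn.Linear

y = layer(torch.randn(2, 4))                # linear() transposes the weight internally
assert y.shape == (2, 8)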

truegrad/nn/functional.py

Lines changed: 9 additions & 9 deletions
@@ -8,7 +8,7 @@
 from torch import Tensor, nn
 from torch.nn import functional as F, grad
 
-from truegrad.functional import add, einsum, matmul, mul, reshape
+from truegrad.functional import add, chunk, einsum, matmul, mul, reshape, split, transpose
 
 _torch_functional = {k: getattr(F, k) for k in dir(F)}
 _torch = {k: getattr(torch, k) for k in dir(torch)}

@@ -551,7 +551,7 @@ def leaky_relu_(input: Tensor, negative_slope: float = 0.01):
 
 @call_torch
 def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
-    input = matmul(input, weight)
+    input = matmul(input, transpose(weight, (0, 1)))
     if bias is None:
         return input
     return add(input, bias)

@@ -641,21 +641,21 @@ def _in_projection_packed(
     if k is v:
         if q is k:
             # self-attention
-            return linear(q, w, b).chunk(3, dim=-1)
+            return linear(q, w, b).chunk(3, -1)
         else:
             # encoder-decoder attention
-            w_q, w_kv = w.split([E, E * 2])
+            w_q, w_kv = split(w, [E, E * 2], 0)
             if b is None:
                 b_q = b_kv = None
             else:
-                b_q, b_kv = b.split([E, E * 2])
-            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
+                b_q, b_kv = split(b, [E, E * 2], 0)
+            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, -1)
     else:
-        w_q, w_k, w_v = w.chunk(3)
+        w_q, w_k, w_v = chunk(w, 3, 0)
         if b is None:
             b_q = b_k = b_v = None
         else:
-            b_q, b_k, b_v = b.chunk(3)
+            b_q, b_k, b_v = chunk(b, 3, 0)
         return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
 

@@ -965,7 +965,7 @@ def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embe
     if in_proj_bias is None:
         b_q = b_k = b_v = None
     else:
-        b_q, b_k, b_v = in_proj_bias.chunk(3)
+        b_q, b_k, b_v = chunk(in_proj_bias, 3, 0)
     q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
 
     # prep attention mask
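
Note: the rewritten linear() computes matmul(input, transpose(weight, (0, 1))), i.e. x @ w.T, which matches torch.nn.functional.linear for the (out_features, in_features) weight layout. The parameter-side .chunk/.split calls are rerouted through the truegrad wrappers so the packed projection weights keep their sum_grad_squared statistics, while the activation-side .chunk calls merely drop the dim= keyword. A small illustrative check with made-up shapes and an arbitrary tolerance:

import torch
from truegrad.functional import matmul, transpose

x = torch.randn(2, 4)
w = torch.randn(8, 4, requires_grad=True)  # (out_features, in_features)

y = matmul(x, transpose(w, (0, 1)))        # x @ w.T via the truegrad wrappers
assert torch.allclose(y, torch.nn.functional.linear(x, w), atol=1e-6)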

truegrad/utils.py

Lines changed: 15 additions & 4 deletions
@@ -16,9 +16,20 @@ def _apply_fn(module: torch.nn.Module):
         _apply_fn(mod)
 
 
-def patch_torch():
-    tg_dir = dir(truegrad.nn)
-    for name in dir(torch.nn):
+def _patch(tg, th):
+    tg_dir = dir(tg)
+    for name in dir(th):
         if name not in tg_dir:
             continue
-        setattr(torch.nn, name, getattr(truegrad.nn, name))
+        item = getattr(tg, name)
+        if not hasattr(item, "__module__"):
+            continue
+        if item.__module__ != tg.__name__:
+            continue
+        setattr(th, name, item)
+
+
+def patch_torch():
+    _patch(truegrad.nn.functional, torch.nn.functional)
+    _patch(truegrad.nn.functional, torch)
+    _patch(truegrad.nn, torch.nn)
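
Note: patch_torch() now delegates to _patch, which only overwrites a torch attribute when the truegrad counterpart is genuinely defined in the truegrad module being patched (the __module__ check), and it covers torch.nn.functional and torch in addition to torch.nn. A hedged sketch of the intended effect, assuming the names below pass that filter:

import torch
import truegrad.nn
from truegrad.utils import patch_torch

patch_torch()

# names defined by truegrad now shadow their torch counterparts
assert torch.nn.Linear is truegrad.nn.Linear
model = torch.nn.Linear(4, 8)  # constructs truegrad.nn.Linear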
