Commit fa62fbf

fix(optim): flip grad
1 parent 1ce7788 commit fa62fbf

2 files changed: +2 −2 lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='3.1.0',
+    version='3.1.1',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),

truegrad/optim.py

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ def step(self, closure=None):
         if "old_update" in state:
             dims = ''.join(chr(ord('a') + i) for i in range(update.ndim))
             lr_grad = torch.einsum(f"{dims},{dims}->", update, state["old_update"].double())
-            state["lr"] = group["lr"] = group["lr"] - lr_grad.item() * self.learning_rate_learning_rate
+            state["lr"] = group["lr"] = group["lr"] + lr_grad.item() * self.learning_rate_learning_rate
             state["old_update"] = torch.clone(update.to(torch.bfloat16).detach())
             state["param"] = None
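The sign flip matters because the `einsum` contraction is a full dot product between the current update and the previous one: a positive value means successive steps agree, which suggests the learning rate is too timid and should grow. Subtracting the term (the old code) moved the learning rate the wrong way. A minimal, standalone sketch of this hypergradient-style adaptation, using NumPy instead of torch and hypothetical names (`adapt_lr`, `lr_lr` standing in for `self.learning_rate_learning_rate`):

```python
import numpy as np

def adapt_lr(lr, update, old_update, lr_lr=1e-3):
    """Hypergradient-style learning-rate adaptation (illustrative sketch).

    If the current update points the same way as the previous one
    (positive dot product), grow the learning rate; if they oppose
    each other, shrink it.
    """
    # Build an einsum spec that contracts over every axis, e.g. "ab,ab->"
    # for a 2-D update, mirroring the patched optimizer code.
    dims = ''.join(chr(ord('a') + i) for i in range(update.ndim))
    lr_grad = np.einsum(f"{dims},{dims}->", update, old_update)
    return lr + lr_grad * lr_lr  # '+' is the corrected sign from this commit

u_prev = np.array([0.1, 0.2])
u_curr = np.array([0.1, 0.1])          # roughly agrees with the previous step
print(adapt_lr(0.01, u_curr, u_prev))  # lr grows slightly, ≈ 0.01003
```

With the pre-fix minus sign, agreeing updates would have shrunk the learning rate and opposing updates grown it, i.e. the adaptation would have run backwards.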
7878
