AdamOptimizer.py
import torch


class SimpleAdamOptimizer:
    """A minimal from-scratch implementation of the Adam optimizer."""

    def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.t = 0  # step counter used for bias correction
        # First (m) and second (v) moment estimates, one buffer per parameter
        self.ms = [torch.zeros_like(param) for param in self.params]
        self.vs = [torch.zeros_like(param) for param in self.params]
    def step(self):
        self.t += 1
        for i, param in enumerate(self.params):
            grad = param.grad
            if grad is None:
                continue
            m, v = self.ms[i], self.vs[i]
            # Update biased first and second moment estimates:
            # m = beta1 * m + (1 - beta1) * grad
            # v = beta2 * v + (1 - beta2) * grad^2
            m.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
            v.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
            # Compute bias-corrected first and second moment estimates
            m_hat = m.div(1 - self.beta1 ** self.t)
            v_hat = v.div(1 - self.beta2 ** self.t)
            # Apply the update in place, outside of autograd tracking
            param_update = m_hat.div(v_hat.sqrt().add_(self.eps)).mul(-self.lr)
            with torch.no_grad():
                param += param_update
            # Clear the gradient so the next backward() starts fresh
            param.grad = None
    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()
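

# Minimal usage sketch (not part of the original file): fits a linear model
# y = 2x with SimpleAdamOptimizer. The model, data, and hyperparameters below
# are illustrative assumptions, not taken from the original repository.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 1)
    y = 2.0 * x
    w = torch.zeros(1, 1, requires_grad=True)
    optimizer = SimpleAdamOptimizer([w], lr=0.01)
    for _ in range(200):
        loss = ((x @ w - y) ** 2).mean()
        loss.backward()
        # step() applies the Adam update and clears the gradient,
        # so no explicit zero_grad() call is needed between iterations.
        optimizer.step()
    print("learned weight:", w.detach().view(-1))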