Skip to content

Commit acd218b

Browse files
authored
[BUG Fix]Fix 3 bugs for Adam-Mini (#257)
* Update adam_mini.py * Update adam_mini.py * Update adam_mini.py
1 parent 23f32fa commit acd218b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pytorch_optimizer/optimizer/adam_mini.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,18 @@ def step_lefts(
246246
if state['reduced']:
247247
dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)
248248

249-
tmp_lr.div_(state['dim'])
249+
tmp_lr.div_(state['dimension'])
250250

251251
m, v = state['m'], state['v_mean']
252252

253253
m.lerp_(grad, weight=1.0 - beta1)
254-
v.mul_(beta2).add_(tmp_lr, value=1.0 - beta2)
254+
v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)
255255

256256
h = (v.sqrt() / bias_correction2_sq).add_(eps)
257257

258-
update = 1 / (bias_correction1 * h).mul_(m)
258+
stepsize = (1 / bias_correction1) / h
259+
260+
update = m * stepsize
259261

260262
p.add_(update, alpha=-lr)
261263

0 commit comments

Comments
 (0)