Skip to content

Commit e8cfb9b

Browse files
committed
fix: d_numerator device
1 parent 7811dd2 commit e8cfb9b

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytorch_optimizer/optimizer/prodigy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
110110

111111
if 'd_numerator' not in group:
112112
group['d_numerator'] = torch.tensor([0.0], device=device)
113+
elif group['d_numerator'].device != device:
114+
group['d_numerator'] = group['d_numerator'].to(device)
113115

114116
d_numerator = group['d_numerator']
115117
d_numerator.mul_(beta3)

0 commit comments

Comments
 (0)