Skip to content

Commit 91f0ea3

Browse files
committed
Need to init momentum with correct dtype
1 parent 30142b6 commit 91f0ea3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/optim/adafactor_bv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def step(self, closure=None):
146146
state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
147147

148148
if self.defaults['momentum'] is not None:
149-
state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
149+
state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
150150

151151
state_steps.append(state['step'])
152152
exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))

0 commit comments

Comments
 (0)