We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 30142b6, commit 91f0ea3 (Copy full SHA for 91f0ea3)
timm/optim/adafactor_bv.py
@@ -146,7 +146,7 @@ def step(self, closure=None):
 state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
 
 if self.defaults['momentum'] is not None:
-    state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
+    state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
 
 state_steps.append(state['step'])
 exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
0 commit comments