1 parent 2fe259f commit ddf81a7
pytorch_optimizer/optimizer/dadapt.py
@@ -196,7 +196,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

        if lr > 0.0:
            d_hat = (sk_sq_weighted - gsq_weighted) / sk_l1
-            d = group['d'] = max(d, min(d_hat, d * group['growth_rate']))
+            d = group['d'] = max(d, min(d_hat.item(), d * group['growth_rate']))

        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
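
A minimal sketch (not the library's code) of why the `.item()` call matters, assuming `d_hat` is a 0-dim `torch.Tensor` as the surrounding arithmetic in the diff suggests; the values of `d` and `growth_rate` below are hypothetical. Without `.item()`, Python's built-in `min`/`max` can return the tensor itself, so `group['d']` silently stores a `torch.Tensor` instead of a plain float:

```python
import torch

d = 1e-6                    # hypothetical current D-Adaptation estimate (float)
growth_rate = float("inf")  # hypothetical growth_rate value
d_hat = torch.tensor(2e-6)  # tensor arithmetic yields a 0-dim tensor

# builtin min/max propagate whichever operand wins the comparison
bad = max(d, min(d_hat, d * growth_rate))           # 0-dim torch.Tensor
good = max(d, min(d_hat.item(), d * growth_rate))   # plain Python float

print(type(bad), type(good))  # <class 'torch.Tensor'> <class 'float'>
```

Converting to a float before the comparison keeps the per-group state a scalar, so it behaves consistently in later arithmetic and when the optimizer state is serialized.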