Open
Labels: bug (Something isn't working)
Description
Describe the bug
```python
# clip step size to max_step_size, based on a 2nd order expansion.
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
# If tr_RRQ >= 0, the quadratic approximation is not concave; fall back to max_step_size.
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
# Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)  # type: ignore[call-overload]
```

`torch.add` happens to accept an unshaped (0-dim) single-value tensor as `alpha`, even though `alpha` is supposed to be a float. `torch.where` returns a tensor, which should cause `torch.add` to fail; only a sequence of coincidences makes the code actually work.
The bug was discovered while imposing stricter type checking.
Steps/Code to reproduce bug
mypy fails on this line when the explicit-override error code is enabled.
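A minimal standalone sketch of the coincidence described above (the tensor values here are placeholders, not the real `tr_RRQ` / step sizes): `torch.where` produces a 0-dim tensor, and `torch.add` silently accepts it as `alpha` even though the parameter is typed as a Number.

```python
import torch

a = torch.ones(3)
b = torch.ones(3)

# torch.where on 0-dim inputs returns a 0-dim tensor, not a float.
alpha = torch.where(torch.tensor(-1.0) < 0, torch.tensor(2.0), torch.tensor(5.0))
assert isinstance(alpha, torch.Tensor)  # a tensor, despite alpha's float annotation

# Works by coincidence: the 0-dim tensor is implicitly converted to a scalar.
out = torch.add(a, b, alpha=alpha)  # 1 + 2 * 1 = 3 elementwise
```

The runtime accepts this, but a strict type checker correctly rejects it.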
Expected behavior
step_size should be explicitly a float, not a tensor, before calling torch.add.
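One possible shape of the fix, sketched with placeholder values (the real code would use the actual `tr_RRQ`, `_step_size`, and `max_step_size`): branch in Python instead of `torch.where`, and extract a plain float with `.item()` so `alpha` receives the type `torch.add` is annotated with.

```python
import torch

# placeholder scalars standing in for the real quantities
tr_RRQ = torch.tensor(-0.5)
_step_size = torch.tensor(0.1)
max_step_size = 1.0

# Branch in Python and force a plain float, so alpha is a float, not a tensor.
if tr_RRQ.item() < 0:
    step_size = float(_step_size.item())
else:
    step_size = float(max_step_size)
```

With `step_size` a genuine `float`, the `# type: ignore[call-overload]` on the `torch.add` call can be dropped.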