Dangerous implicit casting in procrustes_step #85

@skyw

Description

Describe the bug

            # clip step size to max_step_size, based on a 2nd order expansion.
            _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
            # If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
            step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
            # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
            # for 2nd order expansion, only expand exp(a R) to its 2nd term.
            # Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
            Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)  # type: ignore[call-overload]

torch.add accepts a single-value (0-dimensional) tensor as alpha, even though alpha is supposed to be a float. In general, passing a tensor as alpha makes torch.add fail. Here, torch.where happens to return a single-value tensor, so the code works only through a sequence of coincidences.
The bug was discovered while imposing stricter type checking.
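The sketch below reproduces the typing situation outside the optimizer. The trace values, matrix shapes, and max_step_size are arbitrary stand-ins, not values from the real code. It assumes a PyTorch version whose torch.where accepts a Python scalar as its third argument, which the issue's code already relies on.

```python
import torch

# Arbitrary stand-ins (assumptions) for the quantities computed in procrustes_step.
tr_RQ = torch.tensor(0.06)    # 0-dim tensor, like a trace
tr_RRQ = torch.tensor(-1.2)   # negative, so the quadratic model is concave
max_step_size = 0.125         # plain Python float

_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)

# torch.where returns a tensor, not a float: here a 0-dim tensor.
print(type(step_size).__name__, step_size.dim())  # prints: Tensor 0

Q = torch.eye(2)
RQ = torch.ones(2, 2)
RRQ = torch.ones(2, 2)

# alpha is annotated as a number, yet the 0-dim tensor is accepted at runtime
# because PyTorch unwraps it into a scalar before calling the kernel.
Q_new = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
```

The call succeeds only because step_size is a single-value tensor; static type checkers still see a Tensor where a number is expected.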

Steps/Code to reproduce bug

mypy fails when the explicit-override error code is enabled.

Expected behavior

step_size should be explicitly converted to a float, rather than left as a tensor, before torch.add is called.
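One way to meet this expectation is to keep the step-size arithmetic in plain Python floats. The helper below is a hypothetical sketch, not code from the project; it assumes the two traces have already been converted to floats (for example with Tensor.item()) and mirrors the clamp-then-fallback logic of the snippet in the bug description.

```python
def clipped_step_size(tr_RQ: float, tr_RRQ: float, max_step_size: float) -> float:
    """Hypothetical helper: line-search step size as a plain Python float.

    Mirrors the clamp/where logic in procrustes_step: when the 2nd-order model
    is concave (tr_RRQ < 0), clip -tr_RQ / tr_RRQ to [0, max_step_size];
    otherwise fall back to max_step_size.
    """
    if tr_RRQ < 0:
        return float(min(max(-tr_RQ / tr_RRQ, 0.0), max_step_size))
    return float(max_step_size)


# Example with arbitrary trace values (assumptions, not real optimizer data).
step_size = clipped_step_size(0.3, -1.2, 0.125)
print(type(step_size).__name__, step_size)  # prints: float 0.125
```

A caller would then pass step_size and 0.5 * step_size to torch.add as genuine floats, matching the alpha annotation, so the type: ignore[call-overload] comment should no longer be needed. The trade-off is that .item() forces a device-to-host synchronization when the traces live on a GPU. Unlike torch.where, which evaluates both branches, the if statement also skips the division when tr_RRQ is zero.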

Metadata

Labels: bug (Something isn't working)

Development: No branches or pull requests