We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 17ed99a + ca0a8d7 commit 3e9bdb2Copy full SHA for 3e9bdb2
gpytorch/utils/qr.py
@@ -26,6 +26,9 @@ def stable_qr(mat):
26
if torch.any(zeroish):
27
# can't use in-place operation here b/c it would mess up backward pass
28
# haven't found a more elegant way to add a jitter diagonal yet...
29
- jitter_diag = 1e-6 * torch.sign(Rdiag) * zeroish.to(Rdiag)
+ Rdiag_sign = torch.sign(Rdiag)
30
+ # force zero diagonals to have jitter added to them.
31
+ Rdiag_sign[Rdiag_sign == 0] = 1.0
32
+ jitter_diag = 1e-6 * Rdiag_sign * zeroish.to(Rdiag)
33
R = R + torch.diag_embed(jitter_diag)
34
return Q, R
0 commit comments