Skip to content

Commit 3e9bdb2

Browse files
authored
Merge pull request #1714 from wjmaddox/stable_qr_fix
Force add to zero sign in stable_qr
2 parents 17ed99a + ca0a8d7 commit 3e9bdb2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

gpytorch/utils/qr.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def stable_qr(mat):
2626
if torch.any(zeroish):
2727
# can't use in-place operation here b/c it would mess up backward pass
2828
# haven't found a more elegant way to add a jitter diagonal yet...
29-
jitter_diag = 1e-6 * torch.sign(Rdiag) * zeroish.to(Rdiag)
29+
Rdiag_sign = torch.sign(Rdiag)
30+
# force zero diagonals to have jitter added to them.
31+
Rdiag_sign[Rdiag_sign == 0] = 1.0
32+
jitter_diag = 1e-6 * Rdiag_sign * zeroish.to(Rdiag)
3033
R = R + torch.diag_embed(jitter_diag)
3134
return Q, R

0 commit comments

Comments
 (0)