Skip to content

Commit 830c82c

Browse files
committed
NaN horrorfix
1 parent 24ecce6 commit 830c82c

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

src/common/safemethods.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
1313
# Inverse
1414

1515
"""
16-
function inv_safe(a::Number, tol = defaulttol(a))
16+
inv_safe(a::Number, tol = defaulttol(a))
1717
1818
Compute the inverse of a number `a`, but return zero if `a` is smaller than `tol`.
1919
"""
2020
inv_safe(a::Number, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a)
21+
function inv_safe(a::ComplexF32, tol = defaulttol(a))
22+
str = string(a) # WHY does this fix the NaN issues??????
23+
return abs(a) < tol ? zero(a) : inv(a)
24+
end

src/pullbacks/eig.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ function eig_pullback!(
4646
Δgauge ≤ gauge_atol ||
4747
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
4848
49-
VᴴΔV ./= conj.(transpose(D) .- D)
50-
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))
49+
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5150
5251
if !iszerotangent(ΔDmat)
5352
ΔDvec = diagview(ΔDmat)

test/enzyme.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ end
164164
rng = StableRNG(12345)
165165
m = 19
166166
atol = rtol = m * m * precision(T)
167-
A = randn(rng, T, m, m)
167+
A = make_eig_matrix(rng, T, m)
168168
D, V = eig_full(A)
169169
Ddiag = diagview(D)
170170
ΔV = randn(rng, complex(T), m, m)
@@ -274,8 +274,9 @@ end
274274
rng = StableRNG(12345)
275275
m = 19
276276
atol = rtol = m * m * precision(T)
277-
A = randn(rng, T, m, m)
278-
A = A + A'
277+
A = make_eigh_matrix(rng, T, m)
278+
Ac = copy(A)
279+
A = (A + A') / 2
279280
D, V = eigh_full(A)
280281
D2 = Diagonal(D)
281282
ΔV = randn(rng, T, m, m)
@@ -289,11 +290,11 @@ end
289290
#LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
290291
)
291292
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
292-
test_reverse(copy_eigh_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
293-
test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294-
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
295-
test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
296-
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
293+
test_reverse(copy_eigh_full, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294+
test_reverse(copy_eigh_full!, RT, (copy(Ac), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
295+
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, Ac, (D, V), (ΔD2, ΔV), alg)
296+
test_reverse(copy_eigh_vals, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
297+
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, Ac, D.diag, ΔD2.diag, alg)
297298
end
298299
@testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
299300
for r in 1:4:m
@@ -304,8 +305,8 @@ end
304305
Vtrunc = V[:, ind]
305306
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
306307
ΔVtrunc = ΔV[:, ind]
307-
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
308-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
308+
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
309310
end
310311
Ddiag = diagview(D)
311312
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -314,8 +315,8 @@ end
314315
Vtrunc = V[:, ind]
315316
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
316317
ΔVtrunc = ΔV[:, ind]
317-
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
318-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
318+
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
319320
end
320321
end
321322
end

0 commit comments

Comments
 (0)