164164 rng = StableRNG(12345 )
165165 m = 19
166166 atol = rtol = m * m * precision(T)
167- A = randn (rng, T, m , m)
167+ A = make_eig_matrix (rng, T, m)
168168 D, V = eig_full(A)
169169 Ddiag = diagview(D)
170170 ΔV = randn(rng, complex(T), m, m)
274274 rng = StableRNG(12345 )
275275 m = 19
276276 atol = rtol = m * m * precision(T)
277- A = randn(rng, T, m, m)
278- A = A + A'
277+ A = make_eigh_matrix(rng, T, m)
278+ Ac = copy(A)
279+ A = (A + A' ) / 2
279280 D, V = eigh_full(A)
280281 D2 = Diagonal(D)
281282 ΔV = randn(rng, T, m, m)
@@ -289,11 +290,11 @@ end
289290 #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
290291 )
291292 @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
292- test_reverse(copy_eigh_full, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
293- test_reverse(copy_eigh_full!, RT, (copy(A ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294- test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A , (D, V), (ΔD2, ΔV), alg)
295- test_reverse(copy_eigh_vals, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
296- test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A , D.diag, ΔD2.diag, alg)
293+ test_reverse(copy_eigh_full, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294+ test_reverse(copy_eigh_full!, RT, (copy(Ac ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
295+ test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, Ac , (D, V), (ΔD2, ΔV), alg)
296+ test_reverse(copy_eigh_vals, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
297+ test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, Ac , D.diag, ΔD2.diag, alg)
297298 end
298299 @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
299300 for r in 1:4:m
304305 Vtrunc = V[:, ind]
305306 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
306307 ΔVtrunc = ΔV[:, ind]
307- test_reverse(copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
308- test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
308+ test_reverse(copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309+ test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
309310 end
310311 Ddiag = diagview(D)
311312 truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
314315 Vtrunc = V[:, ind]
315316 ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
316317 ΔVtrunc = ΔV[:, ind]
317- test_reverse(copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
318- test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
318+ test_reverse(copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319+ test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
319320 end
320321 end
321322end
0 commit comments