Skip to content

Commit b0eec92

Browse files
committed
add zygote tests
1 parent b38a916 commit b0eec92

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

test/chainrules.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ end
287287
config, last eig_full, A;
288288
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
289289
)
290+
test_rrule(
291+
config, eig_vals, A;
292+
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
293+
)
290294
end
291295

292296
@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32)
@@ -365,6 +369,10 @@ end
365369
config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A;
366370
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
367371
)
372+
test_rrule(
373+
config, eigh_vals ∘ Matrix ∘ Hermitian, A;
374+
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
375+
)
368376
eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...)
369377
for r in 1:4:m
370378
trunc = truncrank(r; by = real)
@@ -459,6 +467,10 @@ end
459467
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol,
460468
rrule_f = rrule_via_ad, check_inferred = false
461469
)
470+
test_rrule(
471+
config, svd_vals, A;
472+
output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
473+
)
462474
for r in 1:4:minmn
463475
trunc = truncrank(r)
464476
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)

0 commit comments

Comments
 (0)