|
1 | 1 | using ChainRulesCore |
2 | 2 | using ChainRulesTestUtils |
3 | | -using FiniteDifferences: FiniteDifferences |
| 3 | +using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm |
4 | 4 | using Random |
5 | 5 | using LinearAlgebra |
6 | 6 | using Zygote |
@@ -305,9 +305,11 @@ for V in spacelist |
305 | 305 | (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) |
306 | 306 | d2 = DiagonalTensorMap{T}(undef, V[1]) |
307 | 307 | (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) |
308 | | - test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) |
309 | | - test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) |
310 | | - test_rrule(f, d; check_inferred, output_tangent=d2) |
| 308 | + |
| 309 | + fdm = (T <: Real && f === sqrt) ? forward_fdm(5, 1) : central_fdm(5, 1) |
| 310 | + test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred, fdm) |
| 311 | + test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred, fdm) |
| 312 | + test_rrule(f, d; check_inferred, output_tangent=d2, fdm) |
311 | 313 | end |
312 | 314 | end |
313 | 315 |
|
@@ -516,7 +518,7 @@ for V in spacelist |
516 | 518 | test_ad_rrule(last ∘ eig_full, t; output_tangent=Δv, atol, rtol) |
517 | 519 | test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol) |
518 | 520 |
|
519 | | - add!(t, t') |
| 521 | + t += t' |
520 | 522 | d, v = eigh_full(t) |
521 | 523 | Δv = rand_tangent(v) |
522 | 524 | Δd = rand_tangent(d) |
|
0 commit comments