Skip to content

Commit ff82129

Browse files
committed
fix some AD tests
1 parent 3357335 commit ff82129

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

test/ad.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ChainRulesCore
22
using ChainRulesTestUtils
3-
using FiniteDifferences: FiniteDifferences
3+
using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm
44
using Random
55
using LinearAlgebra
66
using Zygote
@@ -305,9 +305,11 @@ for V in spacelist
305305
(T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data)
306306
d2 = DiagonalTensorMap{T}(undef, V[1])
307307
(T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data)
308-
test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred)
309-
test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred)
310-
test_rrule(f, d; check_inferred, output_tangent=d2)
308+
309+
fdm = (T <: Real && f === sqrt) ? forward_fdm(5, 1) : central_fdm(5, 1)
310+
test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred, fdm)
311+
test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred, fdm)
312+
test_rrule(f, d; check_inferred, output_tangent=d2, fdm)
311313
end
312314
end
313315

@@ -516,7 +518,7 @@ for V in spacelist
516518
test_ad_rrule(last eig_full, t; output_tangent=Δv, atol, rtol)
517519
test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol)
518520

519-
add!(t, t')
521+
t += t'
520522
d, v = eigh_full(t)
521523
Δv = rand_tangent(v)
522524
Δd = rand_tangent(d)

0 commit comments

Comments
 (0)