Skip to content

Commit 01019af

Browse files
committed
fix some AD tests
1 parent 1ed898a commit 01019af

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

test/ad.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ChainRulesCore
22
using ChainRulesTestUtils
3-
using FiniteDifferences: FiniteDifferences
3+
using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm
44
using Random
55
using LinearAlgebra
66
using Zygote
@@ -302,12 +302,25 @@ for V in spacelist
302302
t1 = randn(T, V[1] V[1])
303303
t2 = randn(T, V[2] V[2])
304304
d = DiagonalTensorMap{T}(undef, V[1])
305-
(T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data)
306305
d2 = DiagonalTensorMap{T}(undef, V[1])
307-
(T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data)
306+
d3 = DiagonalTensorMap{T}(undef, V[1])
307+
if (T <: Real && f === sqrt)
308+
# ensuring no square root of negative numbers
309+
randexp!(d.data)
310+
d.data .+= 5
311+
randexp!(d2.data)
312+
d2.data .+= 5
313+
randexp!(d3.data)
314+
d3.data .+= 5
315+
else
316+
randn!(d.data)
317+
randn!(d2.data)
318+
randn!(d3.data)
319+
end
320+
308321
test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred)
309322
test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred)
310-
test_rrule(f, d; check_inferred, output_tangent=d2)
323+
test_rrule(f, d d2; check_inferred, output_tangent=d3)
311324
end
312325
end
313326

@@ -516,7 +529,7 @@ for V in spacelist
516529
test_ad_rrule(last eig_full, t; output_tangent=Δv, atol, rtol)
517530
test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol)
518531

519-
add!(t, t')
532+
t += t'
520533
d, v = eigh_full(t)
521534
Δv = rand_tangent(v)
522535
Δd = rand_tangent(d)

0 commit comments

Comments
 (0)