|
1 | 1 | using ChainRulesCore |
2 | 2 | using ChainRulesTestUtils |
3 | | -using FiniteDifferences: FiniteDifferences |
| 3 | +using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm |
4 | 4 | using Random |
5 | 5 | using LinearAlgebra |
6 | 6 | using Zygote |
@@ -302,12 +302,25 @@ for V in spacelist |
302 | 302 | t1 = randn(T, V[1] ← V[1]) |
303 | 303 | t2 = randn(T, V[2] ← V[2]) |
304 | 304 | d = DiagonalTensorMap{T}(undef, V[1]) |
305 | | - (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) |
306 | 305 | d2 = DiagonalTensorMap{T}(undef, V[1]) |
307 | | - (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) |
| 306 | + d3 = DiagonalTensorMap{T}(undef, V[1]) |
| 307 | + if (T <: Real && f === sqrt) |
| 308 | + # ensuring no square root of negative numbers |
| 309 | + randexp!(d.data) |
| 310 | + d.data .+= 5 |
| 311 | + randexp!(d2.data) |
| 312 | + d2.data .+= 5 |
| 313 | + randexp!(d3.data) |
| 314 | + d3.data .+= 5 |
| 315 | + else |
| 316 | + randn!(d.data) |
| 317 | + randn!(d2.data) |
| 318 | + randn!(d3.data) |
| 319 | + end |
| 320 | + |
308 | 321 | test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) |
309 | 322 | test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) |
310 | | - test_rrule(f, d; check_inferred, output_tangent=d2) |
| 323 | + test_rrule(f, d ⊢ d2; check_inferred, output_tangent=d3) |
311 | 324 | end |
312 | 325 | end |
313 | 326 |
|
@@ -516,7 +529,7 @@ for V in spacelist |
516 | 529 | test_ad_rrule(last ∘ eig_full, t; output_tangent=Δv, atol, rtol) |
517 | 530 | test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol) |
518 | 531 |
|
519 | | - add!(t, t') |
| 532 | + t += t' |
520 | 533 | d, v = eigh_full(t) |
521 | 534 | Δv = rand_tangent(v) |
522 | 535 | Δd = rand_tangent(d) |
|
0 commit comments