|
6 | 6 | nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
|
7 | 7 |
|
8 | 8 | A = randn(T, n, n)
|
9 |
| - ΔA = randn(T, n, n) |
10 | 9 | A *= nrm / opnorm(A, 1)
|
11 | 10 | tols = nrm == 0.1 ? (atol=1e-8, rtol=1e-8) : NamedTuple()
|
12 |
| - frule_test(LinearAlgebra.exp!, (A, ΔA); tols...) |
| 11 | + test_frule(LinearAlgebra.exp!, A; tols...) |
13 | 12 | end
|
14 | 13 | @testset "imbalanced A" begin
|
15 | 14 | A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
|
16 |
| - ΔA = rand_tangent(A) |
17 |
| - frule_test(LinearAlgebra.exp!, (A, ΔA)) |
| 15 | + test_frule(LinearAlgebra.exp!, A) |
18 | 16 | end
|
19 | 17 | @testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
|
20 | 18 | A = Matrix(Hermitian(randn(T, n, n)))
|
21 |
| - ΔA = randn(T, n, n) |
22 |
| - frule_test(LinearAlgebra.exp!, (A, Matrix(Hermitian(ΔA)))) |
23 |
| - frule_test(LinearAlgebra.exp!, (A, ΔA)) |
| 19 | + test_frule(LinearAlgebra.exp!, A) |
| 20 | + test_frule(LinearAlgebra.exp!, A ⊢ Matrix(Hermitian(randn(T, n, n)))) |
24 | 21 | end
|
25 | 22 | end
|
26 | 23 |
|
|
31 | 28 | nrm in (0.01, 0.1, 0.5, 1.5, 3.0, 6.0, 12.0)
|
32 | 29 |
|
33 | 30 | A = randn(T, n, n)
|
34 |
| - ΔA = randn(T, n, n) |
35 |
| - ΔY = randn(T, n, n) |
36 | 31 | A *= nrm / opnorm(A, 1)
|
37 | 32 | # rrule is not inferrable, but pullback should be
|
38 | 33 | tols = nrm == 0.1 ? (atol=1e-8, rtol=1e-8) : NamedTuple()
|
39 |
| - rrule_test(exp, ΔY, (A, ΔA); check_inferred=false, tols...) |
| 34 | + test_rrule(exp, A; check_inferred=false, tols...) |
40 | 35 | Y, back = rrule(exp, A)
|
41 |
| - @inferred back(ΔY) |
| 36 | + @inferred back(rand_tangent(Y)) |
42 | 37 | end
|
43 | 38 | @testset "imbalanced A" begin
|
44 | 39 | A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
|
45 |
| - ΔA = rand_tangent(A) |
46 |
| - ΔY = rand_tangent(exp(A)) |
47 |
| - rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) |
| 40 | + test_rrule(exp, A; check_inferred=false) |
48 | 41 | end
|
49 | 42 | @testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
|
50 | 43 | A = Matrix(Hermitian(randn(T, n, n)))
|
51 |
| - ΔA = randn(T, n, n) |
52 |
| - ΔY = randn(T, n, n) |
53 |
| - rrule_test(exp, Matrix(Hermitian(ΔY)), (A, ΔA); check_inferred=false) |
54 |
| - rrule_test(exp, ΔY, (A, ΔA); check_inferred=false) |
| 44 | + test_rrule(exp, A; check_inferred=false) |
| 45 | + test_rrule( |
| 46 | + exp, A; |
| 47 | + check_inferred=false, output_tangent=Matrix(Hermitian(randn(T, n, n))) |
| 48 | + ) |
55 | 49 | end
|
56 | 50 | end
|
57 | 51 | end
|
0 commit comments