|
4 | 4 | T in (Float64, ComplexF64),
|
5 | 5 | uplo in (:U, :L)
|
6 | 6 |
|
7 |
| - N = 3 |
8 | 7 | @testset "frule" begin
|
9 |
| - x = randn(T, N, N) |
10 |
| - Δx = randn(T, N, N) |
11 |
| - # can't use frule_test here because it doesn't yet ignore nothing tangents |
12 |
| - Ω = SymHerm(x, uplo) |
13 |
| - Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo) |
14 |
| - @test Ω_ad == Ω |
15 |
| - ∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx)) |
16 |
| - @test ∂Ω_ad ≈ ∂Ω_fd |
| 8 | + test_frule(SymHerm, rand(T, 3, 3), uplo ⊢ nothing) |
17 | 9 | end
|
18 | 10 | @testset "rrule" begin
|
19 | 11 | # on old versions of julia this combination doesn't infer but we don't care as
|
20 | 12 | # it infers fine on modern versions.
|
21 | 13 | check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
|
22 | 14 |
|
23 |
| - x = randn(T, N, N) |
24 |
| - ∂x = randn(T, N, N) |
25 |
| - ΔΩ = randn(T, N, N) |
26 | 15 | @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
|
27 |
| - rrule_test( |
28 |
| - SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing); |
| 16 | + x = randn(T, 3, 3) |
| 17 | + ΔΩ = MT(randn(T, 3, 3)) |
| 18 | + test_rrule( |
| 19 | + SymHerm, x, uplo ⊢ nothing; |
| 20 | + output_tangent = ΔΩ, |
29 | 21 | # type stability here critically relies on uplo being constant propagated,
|
30 | 22 | # so we need to test this more carefully below
|
31 | 23 | check_inferred=false,
|
32 | 24 | )
|
33 | 25 | if check_inferred
|
34 |
| - @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} |
| 26 | + @inferred (function (SymHerm, x, ΔΩ, ::Val) |
35 | 27 | return rrule(SymHerm, x, uplo)[2](ΔΩ)
|
36 |
| - end)(SymHerm, x, MT(ΔΩ), Val(uplo)) |
| 28 | + end)(SymHerm, x, ΔΩ, Val(uplo)) |
37 | 29 | end
|
38 | 30 | end
|
39 | 31 | @testset "back(::Diagonal)" begin
|
40 |
| - rrule_test( |
| 32 | + x = randn(T, 3, 3) |
| 33 | + ΔΩ = Diagonal(randn(T, 3, 3)) |
| 34 | + test_rrule( |
41 | 35 | SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
|
42 | 36 | check_inferred=false,
|
43 | 37 | )
|
44 | 38 | if check_inferred
|
45 |
| - @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} |
| 39 | + @inferred (function (SymHerm, x, ΔΩ, ::Val) |
46 | 40 | return rrule(SymHerm, x, uplo)[2](ΔΩ)
|
47 | 41 | end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo))
|
48 | 42 | end
|
|
53 | 47 | SymHerm in (Symmetric, Hermitian),
|
54 | 48 | T in (Float64, ComplexF64),
|
55 | 49 | uplo in (:U, :L)
|
56 |
| - |
57 |
| - N = 3 |
58 |
| - x = SymHerm(randn(T, N, N), uplo) |
59 |
| - Δx = randn(T, N, N) |
60 |
| - ∂x = SymHerm(randn(T, N, N), uplo) |
61 |
| - ΔΩ = f(SymHerm(randn(T, N, N), uplo)) |
62 |
| - frule_test(f, (x, Δx)) |
63 |
| - frule_test(f, (x, SymHerm(Δx, uplo))) |
64 |
| - rrule_test(f, ΔΩ, (x, ∂x)) |
| 50 | + x = SymHerm(randn(T, 3, 3), uplo) |
| 51 | + Δx = randn(T, 3, 3) |
| 52 | + ∂x = SymHerm(randn(T, 3, 3), uplo) |
| 53 | + ΔΩ = f(SymHerm(randn(T, 3, 3), uplo)) |
| 54 | + test_frule(f, (x, Δx)) |
| 55 | + test_frule(f, (x, SymHerm(Δx, uplo))) |
| 56 | + test_rrule(f, ΔΩ, (x, ∂x)) |
65 | 57 | end
|
66 | 58 |
|
67 | 59 | # symmetric/hermitian eigendecomposition follows the sign convention
|
|
0 commit comments