|
4 | 4 | T in (Float64, ComplexF64),
|
5 | 5 | uplo in (:U, :L)
|
6 | 6 |
|
7 |
| - N = 3 |
8 | 7 | @testset "frule" begin
|
9 |
| - x = randn(T, N, N) |
10 |
| - Δx = randn(T, N, N) |
11 |
| - # can't use frule_test here because it doesn't yet ignore nothing tangents |
12 |
| - Ω = SymHerm(x, uplo) |
13 |
| - Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo) |
14 |
| - @test Ω_ad == Ω |
15 |
| - ∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx)) |
16 |
| - @test ∂Ω_ad ≈ ∂Ω_fd |
| 8 | + test_frule(SymHerm, rand(T, 3, 3), uplo ⊢ nothing) |
17 | 9 | end
|
18 | 10 | @testset "rrule" begin
|
19 | 11 | # on old versions of julia this combination doesn't infer but we don't care as
|
20 | 12 | # it infers fine on modern versions.
|
21 | 13 | check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
|
22 | 14 |
|
23 |
| - x = randn(T, N, N) |
24 |
| - ∂x = randn(T, N, N) |
25 |
| - ΔΩ = randn(T, N, N) |
26 | 15 | @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
|
27 |
| - rrule_test( |
28 |
| - SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing); |
| 16 | + x = randn(T, 3, 3) |
| 17 | + ΔΩ = MT(randn(T, 3, 3)) |
| 18 | + test_rrule( |
| 19 | + SymHerm, x, uplo ⊢ nothing; |
| 20 | + output_tangent = ΔΩ, |
29 | 21 | # type stability here critically relies on uplo being constant propagated,
|
30 | 22 | # so we need to test this more carefully below
|
31 | 23 | check_inferred=false,
|
32 | 24 | )
|
33 | 25 | if check_inferred
|
34 |
| - @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} |
| 26 | + @inferred (function (SymHerm, x, ΔΩ, ::Val) |
35 | 27 | return rrule(SymHerm, x, uplo)[2](ΔΩ)
|
36 |
| - end)(SymHerm, x, MT(ΔΩ), Val(uplo)) |
| 28 | + end)(SymHerm, x, ΔΩ, Val(uplo)) |
37 | 29 | end
|
38 | 30 | end
|
39 | 31 | @testset "back(::Diagonal)" begin
|
40 |
| - rrule_test( |
41 |
| - SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing); |
| 32 | + x = randn(T, 3, 3) |
| 33 | + ΔΩ = Diagonal(randn(T, 3, 3)) |
| 34 | + test_rrule( |
| 35 | + SymHerm, x ⊢ Diagonal(randn(T, 3)), uplo ⊢ nothing; |
42 | 36 | check_inferred=false,
|
| 37 | + output_tangent = ΔΩ, |
43 | 38 | )
|
44 | 39 | if check_inferred
|
45 |
| - @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} |
| 40 | + @inferred (function (SymHerm, x, ΔΩ, ::Val) |
46 | 41 | return rrule(SymHerm, x, uplo)[2](ΔΩ)
|
47 |
| - end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo)) |
| 42 | + end)(SymHerm, x, ΔΩ, Val(uplo)) |
48 | 43 | end
|
49 | 44 | end
|
50 | 45 | end
|
51 | 46 | end
|
| 47 | + # constructing a `Matrix`/`Array` from `SymHerm` |
52 | 48 | @testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
|
53 | 49 | SymHerm in (Symmetric, Hermitian),
|
54 | 50 | T in (Float64, ComplexF64),
|
55 | 51 | uplo in (:U, :L)
|
56 | 52 |
|
57 |
| - N = 3 |
58 |
| - x = SymHerm(randn(T, N, N), uplo) |
59 |
| - Δx = randn(T, N, N) |
60 |
| - ∂x = SymHerm(randn(T, N, N), uplo) |
61 |
| - ΔΩ = f(SymHerm(randn(T, N, N), uplo)) |
62 |
| - frule_test(f, (x, Δx)) |
63 |
| - frule_test(f, (x, SymHerm(Δx, uplo))) |
64 |
| - rrule_test(f, ΔΩ, (x, ∂x)) |
| 53 | + x = SymHerm(randn(T, 3, 3), uplo) |
| 54 | + test_rrule(f, x) |
| 55 | + |
| 56 | + # intentionally specifying tangents here to test both Matrix and SymHerm tangents |
| 57 | + test_frule(f, x ⊢ randn(T, 3, 3)) |
| 58 | + test_frule(f, x ⊢ SymHerm(randn(T, 3, 3), uplo)) |
65 | 59 | end
|
66 | 60 |
|
67 | 61 | # symmetric/hermitian eigendecomposition follows the sign convention
|
68 | 62 | # v = v * sign(real(vₖ)) * sign(vₖ)', where vₖ is the first or last coordinate
|
69 | 63 | # in the eigenvector. This is unstable for finite differences, but using the convention
|
70 | 64 | # v = v * sign(vₖ)' seems to be more stable, the (co)tangents are related as
|
71 | 65 | # ∂v_ad = sign(real(vₖ)) * ∂v_fd
|
72 |
| - |
73 | 66 | function _eigvecs_stabilize_mat(vectors, uplo)
|
74 | 67 | Ui = Symbol(uplo) === :U ? @view(vectors[end, :]) : @view(vectors[1, :])
|
75 | 68 | return Diagonal(conj.(sign.(Ui)))
|
|
0 commit comments