Skip to content

Commit 24318b0

Browse files
authored
Merge pull request #378 from JuliaDiff/ox/tansym
symmetric.jl autotangent
2 parents d6e9806 + 8f6eccc commit 24318b0

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

test/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,72 +4,65 @@
44
T in (Float64, ComplexF64),
55
uplo in (:U, :L)
66

7-
N = 3
87
@testset "frule" begin
9-
x = randn(T, N, N)
10-
Δx = randn(T, N, N)
11-
# can't use frule_test here because it doesn't yet ignore nothing tangents
12-
Ω = SymHerm(x, uplo)
13-
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
14-
@test Ω_ad == Ω
15-
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
16-
@test ∂Ω_ad ∂Ω_fd
8+
test_frule(SymHerm, rand(T, 3, 3), uplo nothing)
179
end
1810
@testset "rrule" begin
1911
# on old versions of julia this combination doesn't infer but we don't care as
2012
# it infers fine on modern versions.
2113
check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian)
2214

23-
x = randn(T, N, N)
24-
∂x = randn(T, N, N)
25-
ΔΩ = randn(T, N, N)
2615
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
27-
rrule_test(
28-
SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing);
16+
x = randn(T, 3, 3)
17+
ΔΩ = MT(randn(T, 3, 3))
18+
test_rrule(
19+
SymHerm, x, uplo nothing;
20+
output_tangent = ΔΩ,
2921
# type stability here critically relies on uplo being constant propagated,
3022
# so we need to test this more carefully below
3123
check_inferred=false,
3224
)
3325
if check_inferred
34-
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
26+
@inferred (function (SymHerm, x, ΔΩ, ::Val)
3527
return rrule(SymHerm, x, uplo)[2](ΔΩ)
36-
end)(SymHerm, x, MT(ΔΩ), Val(uplo))
28+
end)(SymHerm, x, ΔΩ, Val(uplo))
3729
end
3830
end
3931
@testset "back(::Diagonal)" begin
40-
rrule_test(
41-
SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
32+
x = randn(T, 3, 3)
33+
ΔΩ = Diagonal(randn(T, 3, 3))
34+
test_rrule(
35+
SymHerm, x Diagonal(randn(T, 3)), uplo nothing;
4236
check_inferred=false,
37+
output_tangent = ΔΩ,
4338
)
4439
if check_inferred
45-
@inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo}
40+
@inferred (function (SymHerm, x, ΔΩ, ::Val)
4641
return rrule(SymHerm, x, uplo)[2](ΔΩ)
47-
end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo))
42+
end)(SymHerm, x, ΔΩ, Val(uplo))
4843
end
4944
end
5045
end
5146
end
47+
# constructing a `Matrix`/`Array` from `SymHerm`
5248
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
5349
SymHerm in (Symmetric, Hermitian),
5450
T in (Float64, ComplexF64),
5551
uplo in (:U, :L)
5652

57-
N = 3
58-
x = SymHerm(randn(T, N, N), uplo)
59-
Δx = randn(T, N, N)
60-
∂x = SymHerm(randn(T, N, N), uplo)
61-
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
62-
frule_test(f, (x, Δx))
63-
frule_test(f, (x, SymHerm(Δx, uplo)))
64-
rrule_test(f, ΔΩ, (x, ∂x))
53+
x = SymHerm(randn(T, 3, 3), uplo)
54+
test_rrule(f, x)
55+
56+
# intentionally specifying tangents here to test both Matrix and SymHerm tangents
57+
test_frule(f, x randn(T, 3, 3))
58+
test_frule(f, x SymHerm(randn(T, 3, 3), uplo))
6559
end
6660

6761
# symmetric/hermitian eigendecomposition follows the sign convention
6862
# v = v * sign(real(vₖ)) * sign(vₖ)', where vₖ is the first or last coordinate
6963
# in the eigenvector. This is unstable for finite differences, but using the convention
7064
# v = v * sign(vₖ)' seems to be more stable, the (co)tangents are related as
7165
# ∂v_ad = sign(real(vₖ)) * ∂v_fd
72-
7366
function _eigvecs_stabilize_mat(vectors, uplo)
7467
Ui = Symbol(uplo) === :U ? @view(vectors[end, :]) : @view(vectors[1, :])
7568
return Diagonal(conj.(sign.(Ui)))

0 commit comments

Comments
 (0)