Skip to content

Commit 8f6eccc

Browse files
committed
finish adding autotangents
1 parent 0023450 commit 8f6eccc

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

test/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,37 @@
3232
x = randn(T, 3, 3)
3333
ΔΩ = Diagonal(randn(T, 3, 3))
3434
test_rrule(
35-
SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing);
35+
SymHerm, x Diagonal(randn(T, 3)), uplo nothing;
3636
check_inferred=false,
37+
output_tangent = ΔΩ,
3738
)
3839
if check_inferred
3940
@inferred (function (SymHerm, x, ΔΩ, ::Val)
4041
return rrule(SymHerm, x, uplo)[2](ΔΩ)
41-
end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo))
42+
end)(SymHerm, x, ΔΩ, Val(uplo))
4243
end
4344
end
4445
end
4546
end
47+
# constructing a `Matrix`/`Array` from `SymHerm`
4648
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
4749
SymHerm in (Symmetric, Hermitian),
4850
T in (Float64, ComplexF64),
4951
uplo in (:U, :L)
52+
5053
x = SymHerm(randn(T, 3, 3), uplo)
51-
Δx = randn(T, 3, 3)
52-
∂x = SymHerm(randn(T, 3, 3), uplo)
53-
ΔΩ = f(SymHerm(randn(T, 3, 3), uplo))
54-
test_frule(f, (x, Δx))
55-
test_frule(f, (x, SymHerm(Δx, uplo)))
56-
test_rrule(f, ΔΩ, (x, ∂x))
54+
test_rrule(f, x)
55+
56+
# intentionally specifying tangents here to test both Matrix and SymHerm tangents
57+
test_frule(f, x randn(T, 3, 3))
58+
test_frule(f, x SymHerm(randn(T, 3, 3), uplo))
5759
end
5860

5961
# symmetric/hermitian eigendecomposition follows the sign convention
6062
# v = v * sign(real(vₖ)) * sign(vₖ)', where vₖ is the first or last coordinate
6163
# in the eigenvector. This is unstable for finite differences, but using the convention
6264
# v = v * sign(vₖ)' seems to be more stable, the (co)tangents are related as
6365
# ∂v_ad = sign(real(vₖ)) * ∂v_fd
64-
6566
function _eigvecs_stabilize_mat(vectors, uplo)
6667
Ui = Symbol(uplo) === :U ? @view(vectors[end, :]) : @view(vectors[1, :])
6768
return Diagonal(conj.(sign.(Ui)))

0 commit comments

Comments
 (0)