From 7c808c7ffa151fecc34c40236bb59fe7abe93699 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Thu, 12 Sep 2019 13:58:11 +0800 Subject: [PATCH] WIP: add adjoint for eigen --- src/lib/array.jl | 18 ++++++++++++++++++ test/gradcheck.jl | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index f2431aa52..75a8bb7ff 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -403,6 +403,24 @@ end (Ā, ) end +@adjoint function LinearAlgebra.eigen(A::AbstractMatrix) + eV = eigen(A) + e,V = eV + n = size(A,1) + eV, function (Δ) + Δe, ΔV = Δ + if ΔV === nothing + (inv(V)'*Diagonal(Δe)*V', ) + elseif Δe === nothing + F = [i==j ? 0 : inv(e[j] - e[i]) for i=1:n, j=1:n] + (inv(V)'*(F .* (V'ΔV))*V', ) + else + F = [i==j ? 0 : inv(e[j] - e[i]) for i=1:n, j=1:n] + (inv(V)'*(Diagonal(Δe) + F .* (V'ΔV))*V', ) + end + end +end + Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) # x is a squre matrix checked by tr, # so we could just use Eye(size(x, 1)) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 8b460e2c2..70256fd77 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -352,6 +352,18 @@ end end end +@testset "eigen" begin + rng, N = MersenneTwister(6865931), 8 + for i = 1:5 + A = randn(rng, N, N) + @test gradtest(A->abs.(eigen(A).values), A) + @test gradcheck(A) do A + e = eigen(A) + sum(real.(e.values)) + sum(real.(e.vectors)) + end + end +end + using Distances Zygote.refresh()