Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
ChainRules = "1.35.3"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
Expand All @@ -38,7 +38,7 @@ NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
ZygoteRules = "0.2.1"
julia = "1.3"
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
29 changes: 0 additions & 29 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,35 +564,6 @@ end
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)

@adjoint function cholesky(Σ::Real)
C = cholesky(Σ)
return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),)
end

@adjoint function cholesky(Σ::Diagonal; check = true)
C = cholesky(Σ, check = check)
return C, Δ::NamedTuple -> begin
issuccess(C) || throw(PosDefException(C.info))
return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing
end
end

# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
@adjoint function cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true)
C = cholesky(Σ, check = check)
return C, function(Δ::NamedTuple)
issuccess(C) || throw(PosDefException(C.info))
U, Ū = C.U, Δ.factors
Σ̄ = similar(U.data)
Σ̄ = mul!(Σ̄, Ū, U')
Σ̄ = copytri!(Σ̄, 'U')
Σ̄ = ldiv!(U, Σ̄)
Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
Σ̄[diagind(Σ̄)] ./= 2
return (UpperTriangular(Σ̄),)
end
end

@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
return X, function (X̄)
Expand Down
13 changes: 12 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ end
g(X) = cholesky(X * X' + I)
@test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈
Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1]
@test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,))

# https://github.com/FluxML/Zygote.jl/issues/932
@test gradcheck(rand(5, 5), rand(5)) do A, x
Expand Down Expand Up @@ -820,6 +819,18 @@ end
@test back′(C̄)[1] isa Diagonal
@test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1])
end
@testset "cholesky - Hermitian" begin
rng, N = MersenneTwister(123456), 3
A = randn(rng, N, N) + im * randn(rng, N, N)
H = Hermitian(A * A' + I)
Hmat = Matrix(H)
y, back = Zygote.pullback(cholesky, Hmat)
y′, back′ = Zygote.pullback(cholesky, H)
C̄ = (factors=randn(rng, N, N),)
@test back′(C̄)[1] isa Hermitian
@test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I)
@test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I)
end
end

@testset "lyap" begin
Expand Down