Skip to content

Commit c4b4fa9

Browse files
authored
Merge pull request #1114 from st--/st/remove_cholesky_adjoint
remove `@adjoint function cholesky`
2 parents a4d0ad4 + f7203ff commit c4b4fa9

File tree

3 files changed

+27
-31
lines changed

3 files changed

+27
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2525

2626
[compat]
2727
AbstractFFTs = "0.5, 1.0"
28-
ChainRules = "1.33"
28+
ChainRules = "1.35.3"
2929
ChainRulesCore = "1.9"
3030
ChainRulesTestUtils = "1"
3131
DiffRules = "1.4"

src/lib/array.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -564,35 +564,6 @@ end
564564
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
565565
Δ -> (convert(S, Δ),)
566566

567-
@adjoint function cholesky::Real)
568-
C = cholesky(Σ)
569-
return C, Δ::NamedTuple->.factors[1, 1] / (2 * C.U[1, 1]),)
570-
end
571-
572-
@adjoint function cholesky::Diagonal; check = true)
573-
C = cholesky(Σ, check = check)
574-
return C, Δ::NamedTuple -> begin
575-
issuccess(C) || throw(PosDefException(C.info))
576-
return Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing
577-
end
578-
end
579-
580-
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
581-
@adjoint function cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true)
582-
C = cholesky(Σ, check = check)
583-
return C, function::NamedTuple)
584-
issuccess(C) || throw(PosDefException(C.info))
585-
U, Ū = C.U, Δ.factors
586-
Σ̄ = similar(U.data)
587-
Σ̄ = mul!(Σ̄, Ū, U')
588-
Σ̄ = copytri!(Σ̄, 'U')
589-
Σ̄ = ldiv!(U, Σ̄)
590-
Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
591-
Σ̄[diagind(Σ̄)] ./= 2
592-
return (UpperTriangular(Σ̄),)
593-
end
594-
end
595-
596567
@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
597568
X = lyap(A, C)
598569
return X, function (X̄)

test/gradcheck.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,6 @@ end
654654
g(X) = cholesky(X * X' + I)
655655
@test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1]
656656
Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1]
657-
@test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,))
658657

659658
# https://github.com/FluxML/Zygote.jl/issues/932
660659
@test gradcheck(rand(5, 5), rand(5)) do A, x
@@ -820,6 +819,32 @@ end
820819
@test back′(C̄)[1] isa Diagonal
821820
@test diag(back′(C̄)[1]) diag(back(C̄)[1])
822821
end
822+
@testset "cholesky - Hermitian{Complex}" begin
823+
rng, N = MersenneTwister(123456), 3
824+
A = randn(rng, Complex{Float64}, N, N)
825+
H = Hermitian(A * A' + I)
826+
Hmat = Matrix(H)
827+
y, back = Zygote.pullback(cholesky, Hmat)
828+
y′, back′ = Zygote.pullback(cholesky, H)
829+
= (factors=randn(rng, N, N),)
830+
@test only(back′(C̄)) isa Hermitian
831+
# gradtest does not support complex gradients, even though the pullback exists
832+
d = only(back(C̄))
833+
d′ = only(back′(C̄))
834+
@test (d + d')/2 d′
835+
end
836+
@testset "cholesky - Hermitian{Real}" begin
837+
rng, N = MersenneTwister(123456), 3
838+
A = randn(rng, N, N)
839+
H = Hermitian(A * A' + I)
840+
Hmat = Matrix(H)
841+
y, back = Zygote.pullback(cholesky, Hmat)
842+
y′, back′ = Zygote.pullback(cholesky, H)
843+
= (factors=randn(rng, N, N),)
844+
@test back′(C̄)[1] isa Hermitian
845+
@test gradtest(B->cholesky(Hermitian(B)).U, Hmat)
846+
@test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat)
847+
end
823848
end
824849

825850
@testset "lyap" begin

0 commit comments

Comments
 (0)