Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.32.1"
version = "1.33.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
28 changes: 28 additions & 0 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,31 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
end
return getproperty(F, x), getproperty_cholesky_pullback
end

# `det` and `logdet` for `Cholesky`
function rrule(::typeof(det), C::Cholesky)
y = det(C)
diagF = _diag_view(C.factors)
function det_Cholesky_pullback(ȳ)
ΔF = Diagonal(_x_divide_conj_y.(2 * ȳ * conj(y), diagF))
ΔC = Tangent{typeof(C)}(; factors=ΔF)
return NoTangent(), ΔC
end
return y, det_Cholesky_pullback
end
# compute `x / conj(y)`, handling `x = y = 0`
function _x_divide_conj_y(x, y)
z = x / conj(y)
# in our case `iszero(x)` implies `iszero(y)`
return iszero(x) ? zero(z) : z
end

function rrule(::typeof(logdet), C::Cholesky)
y = logdet(C)
diagF = _diag_view(C.factors)
function logdet_Cholesky_pullback(ȳ)
ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * ȳ) ./ conj.(diagF)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect there's something that can be done here as well to make it more NaN-safe, but I think this should not block this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering that as well (there are some - now hidden - comments above) but it felt like usually we don't handle such things in a special way if it can only be triggered by specific cotangents but is not an immediate consequence of the inputs. Or do we?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we test this for all the rules we should, but a principle that we seem to be agreed on is that zero (co)tangents should be strong zeros (see e.g. JuliaDiff/ChainRulesCore.jl#551 (comment)).

So in this case if ȳ==0, then the cotangent of factors should be a zero matrix. Otherwise you end up with cases like zero(logdet(cholesky(A; check=false))), which pulls back a zero cotangent through this rule, injecting NaN's into all downstream cotangents, even though the output is unrelated to the value of A.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I was not aware of this principle. Would maybe good to add it to the docs and possibly CRTestUtils 🙂

I'll update the PR accordingly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed in d831cd4

return NoTangent(), ΔC
end
return y, logdet_Cholesky_pullback
end
30 changes: 30 additions & 0 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,5 +432,35 @@ end
ΔX_symmetric = chol_back_sym(Δ)[2]
@test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2]
end

@testset "det and logdet (uplo=$p)" for p in (:U, :L)
@testset "$op" for op in (det, logdet)
@testset "$T" for T in (Float64, ComplexF64)
n = 5
# rand (not randn) so det will be postive, so logdet will be defined
A = 3 * rand(T, (n, n))
X = Cholesky(A * A' + I, p, 0)
X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal
test_rrule(op, X ⊢ X̄_acc)

# return type
_, op_pullback = rrule(op, X)
X̄ = op_pullback(2.7)[2]
@test X̄ isa Tangent{<:Cholesky}
@test X̄.factors isa Diagonal
end
end

@testset "singular ($T)" for T in (Float64, ComplexF64)
n = 5
L = LowerTriangular(randn(T, (n, n)))
L[1, 1] = zero(T)
X = cholesky(L * L'; check=false)
detX, det_pullback = rrule(det, X)
ΔX = det_pullback(rand())[2]
@test iszero(detX)
@test ΔX.factors isa Diagonal && all(iszero, ΔX.factors)
end
end
end
end