Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.30"
version = "0.6.31"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
12 changes: 8 additions & 4 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,15 +741,19 @@ end
return ((uplo=nothing, info=nothing, factors=nothing),)
end
end
@adjoint function literal_getproperty(C::Cholesky, ::Val{:U})
@adjoint function literal_getproperty(
C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:U}
)
return literal_getproperty(C, Val(:U)), function(Δ)
Δ_factors = C.uplo == 'U' ? UpperTriangular(Δ) : LowerTriangular(copy(Δ'))
Δ_factors = C.uplo == 'U' ? triu!(collect(Δ)) : tril!(collect(Δ'))
return ((uplo=nothing, info=nothing, factors=Δ_factors),)
end
end
@adjoint function literal_getproperty(C::Cholesky, ::Val{:L})
@adjoint function literal_getproperty(
C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:L}
)
return literal_getproperty(C, Val(:L)), function(Δ)
Δ_factors = C.uplo == 'L' ? LowerTriangular(Δ) : UpperTriangular(copy(Δ'))
Δ_factors = C.uplo == 'L' ? tril!(collect(Δ)) : triu!(collect(Δ'))
return ((uplo=nothing, info=nothing, factors=Δ_factors),)
end
end
Expand Down
11 changes: 11 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,17 @@ end
@test gradtest(A->logdet(cholesky(A' * A + I)), A)
@test gradtest(B->cholesky(Symmetric(B)).U, A * A' + I)
@test gradtest(B->logdet(cholesky(Symmetric(B))), A * A' + I)

@testset "inference" begin
out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :U)))
@inferred pb(out)
out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :L)))
@inferred pb(out)
out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :U)))
@inferred pb(out)
out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :L)))
@inferred pb(out)
end
end
@testset "cholesky - scalar" begin
rng = MersenneTwister(123456)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Zygote, Test
using Zygote: gradient, ZygoteRuleConfig
using Zygote: gradient, ZygoteRuleConfig, _pullback, Context
using CUDA
using CUDA: has_cuda

Expand Down