Skip to content

Commit c01566e

Browse files
authored
Merge pull request #42 from TuringLang/mt/fix_linalg
Some linear algebra fixes
2 parents f32c534 + 7dbba19 commit c01566e

File tree

5 files changed

+41
-45
lines changed

5 files changed

+41
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ SpecialFunctions = "0.8, 0.9, 0.10"
3232
StatsBase = "0.32"
3333
StatsFuns = "0.8, 0.9"
3434
Tracker = "0.2.5"
35-
Zygote = "0.4.7"
35+
Zygote = "0.4.10"
3636
ZygoteRules = "0.2"
3737
julia = "1"
3838

src/DistributionsAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
1515
TrackedVecOrMat, track, @grad, data
1616
using SpecialFunctions: logabsgamma, digamma
1717
using ZygoteRules: ZygoteRules, @adjoint, pullback
18-
using LinearAlgebra: copytri!
18+
using LinearAlgebra: copytri!, AbstractTriangular
1919
using Distributions: AbstractMvLogNormal,
2020
ContinuousMultivariateDistribution
2121
using DiffRules, SpecialFunctions, FillArrays

src/common.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,42 @@ end
4040

4141
## Linear algebra ##
4242

43-
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
44-
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
45-
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
43+
# Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767
44+
45+
upper(A::AbstractMatrix) = UpperTriangular(A)
46+
lower(A::AbstractMatrix) = LowerTriangular(A)
47+
function upper(C::Cholesky)
48+
if C.uplo == 'U'
49+
return upper(C.factors)
50+
else
51+
return copy(lower(C.factors)')
52+
end
53+
end
54+
function lower(C::Cholesky)
55+
if C.uplo == 'U'
56+
return copy(upper(C.factors)')
57+
else
58+
return lower(C.factors)
59+
end
60+
end
61+
62+
LinearAlgebra.LowerTriangular(A::TrackedMatrix) = lower(A)
63+
lower(A::TrackedMatrix) = track(lower, A)
64+
@grad lower(A) = lower(Tracker.data(A)), ∇ -> (lower(∇),)
65+
66+
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = upper(A)
67+
upper(A::TrackedMatrix) = track(upper, A)
68+
@grad upper(A) = upper(Tracker.data(A)), ∇ -> (upper(∇),)
69+
70+
function Base.copy(
71+
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
72+
) where {T <: Real}
73+
return track(copy, A)
74+
end
75+
@grad function Base.copy(
76+
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
77+
) where {T <: Real}
78+
return copy(data(A)), ∇ -> (copy(∇),)
4679
end
4780

4881
function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
@@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check)
5790
end
5891
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
5992
@grad function turing_chol(A::AbstractMatrix, check)
60-
C, back = pullback(unsafe_cholesky, data(A), data(check))
93+
C, back = pullback(_turing_chol, data(A), data(check))
6194
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
6295
end
63-
64-
unsafe_cholesky(x, check) = cholesky(x, check=check)
65-
@adjoint function unsafe_cholesky(Σ::Real, check)
66-
C = cholesky(Σ; check=check)
67-
return C, function(Δ::NamedTuple)
68-
issuccess(C) || return (zero(Σ), nothing)
69-
(Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
70-
end
71-
end
72-
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
73-
C = cholesky(Σ; check=check)
74-
return C, function(Δ::NamedTuple)
75-
issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing)
76-
(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
77-
end
78-
end
79-
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
80-
C = cholesky(Σ; check=check)
81-
return C, function(Δ::NamedTuple)
82-
issuccess(C) || return (zero(Δ.factors), nothing)
83-
U, Ū = C.U, Δ.factors
84-
Σ̄ = Ū * U'
85-
Σ̄ = copytri!(Σ̄, 'U')
86-
Σ̄ = ldiv!(U, Σ̄)
87-
BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
88-
@inbounds for n in diagind(Σ̄)
89-
Σ̄[n] /= 2
90-
end
91-
return (UpperTriangular(Σ̄), nothing)
92-
end
93-
end
96+
_turing_chol(x, check) = cholesky(x, check=check)
9497

9598
# Specialised logdet for cholesky to target the triangle directly.
9699
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])

test/others.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
using StatsBase: entropy
22

33
if get_stage() in ("Others", "all")
4-
@testset "unsafe_cholesky" begin
5-
A = rand(3, 3); A = A + A' + 3I
6-
@test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A))
7-
@test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false))
8-
@test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true)
9-
end
10-
114
@testset "TuringWishart" begin
125
dim = 3
136
A = Matrix{Float64}(I, dim, dim)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DistributionsAD, Test, LinearAlgebra, Combinatorics
44
using ForwardDiff: Dual
55
using StatsFuns: binomlogpdf, logsumexp
66
const FDM = FiniteDifferences
7-
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform, unsafe_cholesky
7+
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform
88
using Distributions: meanlogdet
99

1010
include("test_utils.jl")

0 commit comments

Comments
 (0)