Skip to content

Commit b296d39

Browse files
authored
Merge pull request #19 from TuringLang/mt/array_dist_and_multi
Implement FillDist and ArrayDist
2 parents c93d8e7 + 105e5e8 commit b296d39

File tree

11 files changed

+440
-81
lines changed

11 files changed

+440
-81
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ version = "0.3.2"
44

55
[deps]
66
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
7+
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
810
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
911
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1012
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
@@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1820

1921
[compat]
2022
Combinatorics = "0.7"
23+
DiffRules = "0.1, 1.0"
2124
Distributions = "0.22"
25+
FillArrays = "0.8"
26+
FiniteDifferences = "0.9"
2227
ForwardDiff = "0.10.6"
2328
PDMats = "0.9"
2429
SpecialFunctions = "0.8, 0.9, 0.10"
30+
StatsBase = "0.32"
2531
StatsFuns = "0.8, 0.9"
2632
Tracker = "0.2.5"
2733
Zygote = "0.4.7"

src/DistributionsAD.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ using PDMats,
1111
StatsFuns
1212

1313
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
14-
TrackedVecOrMat, track, data
15-
using ZygoteRules: ZygoteRules, pullback
14+
TrackedVecOrMat, track, @grad, data
15+
using SpecialFunctions: logabsgamma, digamma
16+
using ZygoteRules: ZygoteRules, @adjoint, pullback
1617
using LinearAlgebra: copytri!
1718
using Distributions: AbstractMvLogNormal,
1819
ContinuousMultivariateDistribution
20+
using DiffRules, SpecialFunctions, FillArrays
21+
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
22+
using Base.Iterators: drop
1923

2024
import StatsFuns: logsumexp,
2125
binomlogpdf,
@@ -35,11 +39,16 @@ export TuringScalMvNormal,
3539
TuringMvLogNormal,
3640
TuringPoissonBinomial,
3741
TuringWishart,
38-
TuringInverseWishart
42+
TuringInverseWishart,
43+
arraydist,
44+
filldist
3945

4046
include("common.jl")
4147
include("univariate.jl")
4248
include("multivariate.jl")
4349
include("matrixvariate.jl")
50+
include("flatten.jl")
51+
include("arraydist.jl")
52+
include("filldist.jl")
4453

4554
end

src/arraydist.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Univariate
2+
3+
const VectorOfUnivariate = Distributions.Product
4+
5+
function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
6+
means = mean.(dists)
7+
vars = var.(dists)
8+
return MvNormal(means, vars)
9+
end
10+
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
11+
means = vcatmapreduce(mean, dists)
12+
vars = vcatmapreduce(var, dists)
13+
return MvNormal(means, vars)
14+
end
15+
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
16+
return product_distribution(dists)
17+
end
18+
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
19+
return sum(vcatmapreduce(logpdf, dist.v, x))
20+
end
21+
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
22+
# eachcol breaks Zygote, so we need an adjoint
23+
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
24+
end
25+
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
26+
# Any other more efficient implementation breaks Zygote
27+
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
28+
return pullback(f, dist, x)
29+
end
30+
31+
struct MatrixOfUnivariate{
32+
S <: ValueSupport,
33+
Tdist <: UnivariateDistribution{S},
34+
Tdists <: AbstractMatrix{Tdist},
35+
} <: MatrixDistribution{S}
36+
dists::Tdists
37+
end
38+
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
39+
function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
40+
return MatrixOfUnivariate(dists)
41+
end
42+
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
43+
# Broadcasting here breaks Tracker for some reason
44+
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
45+
return sum(vcatmapreduce(logpdf, dist.dists, x))
46+
end
47+
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
48+
return rand.(Ref(rng), dist.dists)
49+
end
50+
51+
# Multivariate
52+
53+
struct VectorOfMultivariate{
54+
S <: ValueSupport,
55+
Tdist <: MultivariateDistribution{S},
56+
Tdists <: AbstractVector{Tdist},
57+
} <: MatrixDistribution{S}
58+
dists::Tdists
59+
end
60+
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
61+
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
62+
function arraydist(dists::AbstractVector{<:MultivariateDistribution})
63+
return VectorOfMultivariate(dists)
64+
end
65+
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
66+
# eachcol breaks Zygote, so we define an adjoint
67+
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
68+
end
69+
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
70+
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
71+
return pullback(f, dist, x)
72+
end
73+
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
74+
init = reshape(rand(rng, dist.dists[1]), :, 1)
75+
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
76+
end

src/common.jl

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,30 @@
11
## Generic ##
22

3+
if VERSION < v"1.1"
4+
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
5+
end
6+
7+
Base.one(::Irrational) = true
8+
9+
function vcatmapreduce(f, args...)
10+
init = vcat(f(first.(args)...,))
11+
zipped_args = zip(args...,)
12+
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
13+
f(zarg...,)
14+
end
15+
end
16+
@adjoint function vcatmapreduce(f, args...)
17+
g(f, args...) = f.(args...,)
18+
return pullback(g, f, args...)
19+
end
20+
321
function Base.fill(
422
value::TrackedReal,
523
dims::Vararg{Union{Integer, AbstractUnitRange}},
624
)
725
return track(fill, value, dims...)
826
end
9-
Tracker.@grad function Base.fill(value::Real, dims...)
27+
@grad function Base.fill(value::Real, dims...)
1028
return fill(data(value), dims...), function(Δ)
1129
size(Δ) dims && error("Dimension mismatch")
1230
return (sum(Δ), map(_->nothing, dims)...)
@@ -16,15 +34,15 @@ end
1634
## StatsFuns ##
1735

1836
logsumexp(x::TrackedArray) = track(logsumexp, x)
19-
Tracker.@grad function logsumexp(x::TrackedArray)
37+
@grad function logsumexp(x::TrackedArray)
2038
lse = logsumexp(data(x))
2139
return lse, Δ ->.* exp.(x .- lse),)
2240
end
2341

2442
## Linear algebra ##
2543

2644
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
27-
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
45+
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
2846
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
2947
end
3048

@@ -39,27 +57,27 @@ function turing_chol(A::AbstractMatrix, check)
3957
(chol.factors, chol.info)
4058
end
4159
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
42-
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
60+
@grad function turing_chol(A::AbstractMatrix, check)
4361
C, back = pullback(unsafe_cholesky, data(A), data(check))
4462
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
4563
end
4664

4765
unsafe_cholesky(x, check) = cholesky(x, check=check)
48-
ZygoteRules.@adjoint function unsafe_cholesky::Real, check)
66+
@adjoint function unsafe_cholesky::Real, check)
4967
C = cholesky(Σ; check=check)
5068
return C, function::NamedTuple)
5169
issuccess(C) || return (zero(Σ), nothing)
5270
.factors[1, 1] / (2 * C.U[1, 1]), nothing)
5371
end
5472
end
55-
ZygoteRules.@adjoint function unsafe_cholesky::Diagonal, check)
73+
@adjoint function unsafe_cholesky::Diagonal, check)
5674
C = cholesky(Σ; check=check)
5775
return C, function::NamedTuple)
5876
issuccess(C) || (Diagonal(zero(diag.factors))), nothing)
5977
(Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing)
6078
end
6179
end
62-
ZygoteRules.@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
80+
@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
6381
C = cholesky(Σ; check=check)
6482
return C, function::NamedTuple)
6583
issuccess(C) || return (zero.factors), nothing)
@@ -78,7 +96,7 @@ end
7896
# Specialised logdet for cholesky to target the triangle directly.
7997
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
8098
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
81-
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
99+
@grad function logdet_chol_tri(U::AbstractMatrix)
82100
U_data = data(U)
83101
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
84102
end
@@ -88,6 +106,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
88106
end
89107

90108
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
109+
91110
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
92111
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
93112
return track(zygote_ldiv, A, B)
@@ -96,11 +115,49 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
96115
return track(zygote_ldiv, A, B)
97116
end
98117
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
99-
Tracker.@grad function zygote_ldiv(A, B)
118+
@grad function zygote_ldiv(A, B)
100119
Y, back = pullback(\, data(A), data(B))
101120
return Y, Δ->back(data(Δ))
102121
end
103122

104123
function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
105124
return (a.U \ (a.U' \ b))
106125
end
126+
127+
# SpecialFunctions
128+
129+
SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x)
130+
@grad function SpecialFunctions.logabsgamma(x::Real)
131+
return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],)
132+
end
133+
@adjoint function SpecialFunctions.logabsgamma(x::Real)
134+
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
135+
end
136+
137+
# Some Tracker fixes
138+
139+
for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
140+
if :TrackedReal in c
141+
cnames = map(_ -> gensym(), c)
142+
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
143+
track($f, $(cnames...), x, xs...)
144+
end
145+
end
146+
@grad function vcat(x::Real)
147+
vcat(data(x)), (Δ) -> (Δ[1],)
148+
end
149+
@grad function vcat(x1::Real, x2::Real)
150+
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
151+
end
152+
@grad function vcat(x1::AbstractVector, x2::Real)
153+
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
154+
end
155+
156+
# Zygote fill has issues with non-numbers
157+
158+
@adjoint function fill(x::T, dims...) where {T}
159+
function zfill(x, dims...,)
160+
return reshape([x for i in 1:prod(dims)], dims)
161+
end
162+
pullback(zfill, x, dims...)
163+
end

src/filldist.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Univariate
2+
3+
const FillVectorOfUnivariate{
4+
S <: ValueSupport,
5+
T <: UnivariateDistribution{S},
6+
Tdists <: Fill{T, 1},
7+
} = VectorOfUnivariate{S, T, Tdists}
8+
9+
function filldist(dist::UnivariateDistribution, N::Int)
10+
return product_distribution(Fill(dist, N))
11+
end
12+
filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ)
13+
14+
function Distributions.logpdf(
15+
dist::FillVectorOfUnivariate,
16+
x::AbstractVector{<:Real},
17+
)
18+
return _logpdf(dist, x)
19+
end
20+
function Distributions.logpdf(
21+
dist::FillVectorOfUnivariate,
22+
x::AbstractMatrix{<:Real},
23+
)
24+
return _logpdf(dist, x)
25+
end
26+
@adjoint function Distributions.logpdf(
27+
dist::FillVectorOfUnivariate,
28+
x::AbstractMatrix{<:Real},
29+
)
30+
return pullback(_logpdf, dist, x)
31+
end
32+
33+
function _logpdf(
34+
dist::FillVectorOfUnivariate,
35+
x::AbstractVector{<:Real},
36+
)
37+
return _flat_logpdf(dist.v.value, x)
38+
end
39+
function _logpdf(
40+
dist::FillVectorOfUnivariate,
41+
x::AbstractMatrix{<:Real},
42+
)
43+
return _flat_logpdf_mat(dist.v.value, x)
44+
end
45+
46+
function _flat_logpdf(dist, x)
47+
if toflatten(dist)
48+
f, args = flatten(dist)
49+
return sum(f.(args..., x))
50+
else
51+
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
52+
end
53+
end
54+
function _flat_logpdf_mat(dist, x)
55+
if toflatten(dist)
56+
f, args = flatten(dist)
57+
return vec(sum(f.(args..., x), dims = 1))
58+
else
59+
temp = vcatmapreduce(x -> logpdf(dist, x), x)
60+
return vec(sum(reshape(temp, size(x)), dims = 1))
61+
end
62+
end
63+
64+
const FillMatrixOfUnivariate{
65+
S <: ValueSupport,
66+
T <: UnivariateDistribution{S},
67+
Tdists <: Fill{T, 2},
68+
} = MatrixOfUnivariate{S, T, Tdists}
69+
70+
function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
71+
return MatrixOfUnivariate(Fill(dist, N1, N2))
72+
end
73+
function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})
74+
return _flat_logpdf(dist.dists.value, x)
75+
end
76+
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate)
77+
return rand(rng, dist.dists.value, length.(dist.dists.axes))
78+
end
79+
80+
# Multivariate
81+
82+
const FillVectorOfMultivariate{
83+
S <: ValueSupport,
84+
T <: MultivariateDistribution{S},
85+
Tdists <: Fill{T, 1},
86+
} = VectorOfMultivariate{S, T, Tdists}
87+
88+
function filldist(dist::MultivariateDistribution, N::Int)
89+
return VectorOfMultivariate(Fill(dist, N))
90+
end
91+
function Distributions.logpdf(
92+
dist::FillVectorOfMultivariate,
93+
x::AbstractMatrix{<:Real},
94+
)
95+
return _logpdf(dist, x)
96+
end
97+
@adjoint function Distributions.logpdf(
98+
dist::FillVectorOfMultivariate,
99+
x::AbstractMatrix{<:Real},
100+
)
101+
return pullback(_logpdf, dist, x)
102+
end
103+
function _logpdf(
104+
dist::FillVectorOfMultivariate,
105+
x::AbstractMatrix{<:Real},
106+
)
107+
return sum(logpdf(dist.dists.value, x))
108+
end
109+
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
110+
return rand(rng, dist.dists.value, length.(dist.dists.axes))
111+
end

0 commit comments

Comments
 (0)