Skip to content

Commit 751e138

Browse files
committed
fix David's comments
1 parent 79e7e03 commit 751e138

File tree

6 files changed

+107
-71
lines changed

6 files changed

+107
-71
lines changed

src/DistributionsAD.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using PDMats,
1212

1313
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
1414
TrackedVecOrMat, track, @grad, data
15+
using SpecialFunctions: logabsgamma, digamma
1516
using ZygoteRules: ZygoteRules, @adjoint, pullback
1617
using LinearAlgebra: copytri!
1718
using Distributions: AbstractMvLogNormal,
@@ -39,15 +40,15 @@ export TuringScalMvNormal,
3940
TuringPoissonBinomial,
4041
TuringWishart,
4142
TuringInverseWishart,
42-
ArrayDist,
43-
FillDist
43+
arraydist,
44+
filldist
4445

4546
include("common.jl")
4647
include("univariate.jl")
4748
include("multivariate.jl")
4849
include("matrixvariate.jl")
4950
include("flatten.jl")
50-
include("array_dist.jl")
51-
include("multi.jl")
51+
include("arraydist.jl")
52+
include("filldist.jl")
5253

5354
end

src/array_dist.jl renamed to src/arraydist.jl

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,31 @@
11
# Univariate
22

3-
const VectorOfUnivariate{
4-
S <: ValueSupport,
5-
Tdist <: UnivariateDistribution{S},
6-
Tdists <: AbstractVector{Tdist},
7-
} = Distributions.Product{S, Tdist, Tdists}
8-
9-
function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T}
10-
if T <: TrackedReal
11-
init_m = vcat(dists[1].μ)
12-
means = mapreduce(vcat, drop(dists, 1); init = init_m) do d
13-
d.μ
14-
end
15-
init_v = vcat(dists[1].σ^2)
16-
vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d
17-
d.σ^2
18-
end
19-
else
20-
means = [d.μ for d in dists]
21-
vars = [d.σ^2 for d in dists]
22-
end
3+
const VectorOfUnivariate = Distributions.Product
234

5+
function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
6+
means = mean.(dists)
7+
vars = var.(dists)
248
return MvNormal(means, vars)
259
end
26-
function ArrayDist(dists::AbstractVector{<:UnivariateDistribution})
27-
return Distributions.Product(dists)
10+
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
11+
means = vcatmapreduce(mean, dists)
12+
vars = vcatmapreduce(var, dists)
13+
return MvNormal(means, vars)
14+
end
15+
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
16+
return product_distribution(dists)
2817
end
2918
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
30-
return sum(logpdf.(dist.v, x))
19+
return sum(vcatmapreduce(logpdf, dist.v, x))
3120
end
3221
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
33-
# Any other more efficient implementation breaks Zygote
34-
return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)]
22+
# eachcol breaks Zygote, so we need an adjoint
23+
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
3524
end
36-
function Distributions.logpdf(
37-
dist::VectorOfUnivariate,
38-
x::AbstractVector{<:AbstractMatrix{<:Real}},
39-
)
40-
return logpdf.(Ref(dist), x)
25+
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
26+
# Any other more efficient implementation breaks Zygote
27+
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
28+
return pullback(f, dist, x)
4129
end
4230

4331
struct MatrixOfUnivariate{
@@ -48,14 +36,13 @@ struct MatrixOfUnivariate{
4836
dists::Tdists
4937
end
5038
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
51-
function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution})
39+
function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
5240
return MatrixOfUnivariate(dists)
5341
end
5442
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
5543
# Broadcasting here breaks Tracker for some reason
56-
return sum(zip(dist.dists, x)) do (dist, x)
57-
logpdf(dist, x)
58-
end
44+
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
45+
return sum(vcatmapreduce(logpdf, dist.dists, x))
5946
end
6047
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
6148
return rand.(Ref(rng), dist.dists)
@@ -72,17 +59,16 @@ struct VectorOfMultivariate{
7259
end
7360
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
7461
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
75-
function ArrayDist(dists::AbstractVector{<:MultivariateDistribution})
62+
function arraydist(dists::AbstractVector{<:MultivariateDistribution})
7663
return VectorOfMultivariate(dists)
7764
end
7865
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
79-
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
66+
# eachcol breaks Zygote, so we define an adjoint
67+
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
8068
end
81-
function Distributions.logpdf(
82-
dist::VectorOfMultivariate,
83-
x::AbstractVector{<:AbstractVector{<:Real}},
84-
)
85-
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
69+
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
70+
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
71+
return pullback(f, dist, x)
8672
end
8773
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
8874
init = reshape(rand(rng, dist.dists[1]), :, 1)

src/common.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
Base.one(::Irrational) = 1
44

5+
function vcatmapreduce(f, args...)
6+
init = vcat(f(first.(args)...,))
7+
zipped_args = zip(args...,)
8+
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
9+
f(zarg...,)
10+
end
11+
end
12+
@adjoint function vcatmapreduce(f, args...)
13+
g(f, args...) = f.(args...,)
14+
return pullback(g, f, args...)
15+
end
16+
517
function Base.fill(
618
value::TrackedReal,
719
dims::Vararg{Union{Integer, AbstractUnitRange}},
@@ -110,9 +122,12 @@ end
110122

111123
# SpecialFunctions
112124

113-
function SpecialFunctions.logabsgamma(x::TrackedReal)
114-
v = loggamma(x)
115-
return v, sign(data(v))
125+
SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x)
126+
@grad function SpecialFunctions.logabsgamma(x::Real)
127+
return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],)
128+
end
129+
@adjoint function SpecialFunctions.logabsgamma(x::Real)
130+
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
116131
end
117132

118133
# Some Tracker fixes

src/multi.jl renamed to src/filldist.jl

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,58 @@ const FillVectorOfUnivariate{
66
Tdists <: Fill{T, 1},
77
} = VectorOfUnivariate{S, T, Tdists}
88

9-
function FillDist(dist::UnivariateDistribution, N::Int)
10-
return Product(Fill(dist, N))
9+
function filldist(dist::UnivariateDistribution, N::Int)
10+
return product_distribution(Fill(dist, N))
1111
end
12-
FillDist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ)
12+
filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ)
13+
1314
function Distributions.logpdf(
1415
dist::FillVectorOfUnivariate,
1516
x::AbstractVector{<:Real},
1617
)
17-
return _flat_logpdf(dist.v.value, x)
18+
return _logpdf(dist, x)
1819
end
1920
function Distributions.logpdf(
2021
dist::FillVectorOfUnivariate,
2122
x::AbstractMatrix{<:Real},
23+
)
24+
return _logpdf(dist, x)
25+
end
26+
@adjoint function Distributions.logpdf(
27+
dist::FillVectorOfUnivariate,
28+
x::AbstractMatrix{<:Real},
29+
)
30+
return pullback(_logpdf, dist, x)
31+
end
32+
33+
function _logpdf(
34+
dist::FillVectorOfUnivariate,
35+
x::AbstractVector{<:Real},
36+
)
37+
return _flat_logpdf(dist.v.value, x)
38+
end
39+
function _logpdf(
40+
dist::FillVectorOfUnivariate,
41+
x::AbstractMatrix{<:Real},
2242
)
2343
return _flat_logpdf_mat(dist.v.value, x)
2444
end
45+
2546
function _flat_logpdf(dist, x)
2647
if toflatten(dist)
2748
f, args = flatten(dist)
2849
return sum(f.(args..., x))
2950
else
30-
return sum(logpdf.(dist, x))
51+
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
3152
end
3253
end
3354
function _flat_logpdf_mat(dist, x)
3455
if toflatten(dist)
3556
f, args = flatten(dist)
3657
return vec(sum(f.(args..., x), dims = 1))
3758
else
38-
return vec(sum(logpdf.(dist, x), dims = 1))
59+
temp = vcatmapreduce(x -> logpdf(dist, x), x)
60+
return vec(sum(reshape(temp, size(x)), dims = 1))
3961
end
4062
end
4163

@@ -45,7 +67,7 @@ const FillMatrixOfUnivariate{
4567
Tdists <: Fill{T, 2},
4668
} = MatrixOfUnivariate{S, T, Tdists}
4769

48-
function FillDist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
70+
function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
4971
return MatrixOfUnivariate(Fill(dist, N1, N2))
5072
end
5173
function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})
@@ -63,12 +85,24 @@ const FillVectorOfMultivariate{
6385
Tdists <: Fill{T, 1},
6486
} = VectorOfMultivariate{S, T, Tdists}
6587

66-
function FillDist(dist::MultivariateDistribution, N::Int)
88+
function filldist(dist::MultivariateDistribution, N::Int)
6789
return VectorOfMultivariate(Fill(dist, N))
6890
end
6991
function Distributions.logpdf(
7092
dist::FillVectorOfMultivariate,
7193
x::AbstractMatrix{<:Real},
94+
)
95+
return _logpdf(dist, x)
96+
end
97+
@adjoint function Distributions.logpdf(
98+
dist::FillVectorOfMultivariate,
99+
x::AbstractMatrix{<:Real},
100+
)
101+
return pullback(_logpdf, dist, x)
102+
end
103+
function _logpdf(
104+
dist::FillVectorOfMultivariate,
105+
x::AbstractMatrix{<:Real},
72106
)
73107
return sum(logpdf(dist.dists.value, x))
74108
end

src/multivariate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function simplex_logpdf(alpha, lmnB, x::AbstractVector)
5959
sum((alpha .- 1) .* log.(x)) - lmnB
6060
end
6161
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
62-
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])))
62+
init = vcat(sum((alpha .- 1) .* log.(view(x, :, 1))))
6363
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
6464
sum((alpha .- 1) .* log.(c)) - lmnB
6565
end

test/distributions.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ separator()
159159
@testset "Multivariate discrete distributions" begin
160160
test_head("Testing: Multivariate discrete distributions")
161161
mult_disc_dists = [
162-
DistSpec(:((p) -> FillDist(Bernoulli(p), dim)), (0.45,), fill(1, dim)),
163-
DistSpec(:((p) -> ArrayDist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)),
162+
DistSpec(:((p) -> filldist(Bernoulli(p), dim)), (0.45,), fill(1, dim)),
163+
DistSpec(:((p) -> arraydist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)),
164164
DistSpec(:((p) -> Multinomial(2, p / sum(p))), (fill(0.5, 2),), [2, 0]),
165165
]
166166
for d in mult_disc_dists
@@ -176,10 +176,10 @@ separator()
176176
test_head("Testing: Multivariate continuous distributions")
177177
mult_cont_dists = [
178178
# Vector case
179-
DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim)),
180-
DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim)),
181-
DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec),
182-
DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec),
179+
DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim)),
180+
DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim)),
181+
DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec),
182+
DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec),
183183
DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec),
184184
DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec),
185185
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
@@ -198,10 +198,10 @@ separator()
198198
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec),
199199
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec),
200200
# Matrix case
201-
DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim, dim)),
202-
DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim, dim)),
203-
DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat),
204-
DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat),
201+
DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim, dim)),
202+
DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim, dim)),
203+
DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat),
204+
DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat),
205205
DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat),
206206
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
207207
DistSpec(:MvNormal, (mean, cov_num), norm_val_mat),
@@ -254,10 +254,10 @@ separator()
254254
@testset "Matrix-variate continuous distributions" begin
255255
test_head("Testing: Matrix-variate continuous distributions")
256256
matrix_cont_dists = [
257-
DistSpec(:(() -> FillDist(Beta(), dim, dim)), (), fill(0.5, dim, dim)),
258-
DistSpec(:(() -> ArrayDist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)),
259-
DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat),
260-
DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat),
257+
DistSpec(:(() -> filldist(Beta(), dim, dim)), (), fill(0.5, dim, dim)),
258+
DistSpec(:(() -> arraydist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)),
259+
DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat),
260+
DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat),
261261
DistSpec(:((n1, n2)->MatrixBeta(dim, n1, n2)), (dim, dim), beta_mat),
262262
DistSpec(:Wishart, (dim, cov_mat), cov_mat),
263263
DistSpec(:InverseWishart, (dim, cov_mat), cov_mat),

0 commit comments

Comments
 (0)