Skip to content

Commit 9806ec3

Browse files
authored
Fix and test TuringDirichlet constructors (#152)
* Fix and test `TuringDirichlet` constructors * Fix typo * Import `TuringDirichlet`
1 parent d6aaa64 commit 9806ec3

File tree

6 files changed

+83
-44
lines changed

6 files changed

+83
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.18"
3+
version = "0.6.19"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/multivariate.jl

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,54 @@
11
## Dirichlet ##
22

3-
struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
3+
struct TuringDirichlet{T<:Real,TV<:AbstractVector,S<:Real} <: ContinuousMultivariateDistribution
44
alpha::TV
55
alpha0::T
6-
lmnB::T
7-
end
8-
Base.length(d::TuringDirichlet) = length(d.alpha)
9-
function check(alpha)
10-
all(ai -> ai > 0, alpha) ||
11-
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
12-
end
13-
14-
function Distributions._rand!(rng::Random.AbstractRNG,
15-
d::TuringDirichlet,
16-
x::AbstractVector{<:Real})
17-
s = 0.0
18-
n = length(x)
19-
α = d.alpha
20-
for i in 1:n
21-
@inbounds s += (x[i] = rand(rng, Gamma(α[i])))
22-
end
23-
Distributions.multiply!(x, inv(s)) # this returns x
6+
lmnB::S
247
end
258

269
function TuringDirichlet(alpha::AbstractVector)
27-
check(alpha)
10+
all(ai -> ai > 0, alpha) ||
11+
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
12+
2813
alpha0 = sum(alpha)
2914
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
30-
T = promote_type(typeof(alpha0), typeof(lmnB))
31-
TV = typeof(alpha)
32-
TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
33-
end
3415

35-
function TuringDirichlet(d::Integer, alpha::Real)
36-
alpha0 = alpha * d
37-
_alpha = fill(alpha, d)
38-
lmnB = loggamma(alpha) * d - loggamma(alpha0)
39-
T = promote_type(typeof(alpha0), typeof(lmnB))
40-
TV = typeof(_alpha)
41-
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
42-
end
43-
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
44-
TuringDirichlet(float.(alpha))
16+
return TuringDirichlet(alpha, alpha0, lmnB)
4517
end
46-
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))
18+
TuringDirichlet(d::Integer, alpha::Real) = TuringDirichlet(Fill(alpha, d))
4719

20+
# TODO: remove?
21+
TuringDirichlet(alpha::AbstractVector{<:Integer}) = TuringDirichlet(float.(alpha))
22+
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha))
23+
24+
# TODO: remove and use `Dirichlet` only for `Tracker.TrackedVector`
4825
Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha)
4926

27+
TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB)
28+
29+
Base.length(d::TuringDirichlet) = length(d.alpha)
30+
31+
# copied from Distributions
32+
# TODO: remove and use `Dirichlet`?
33+
function Distributions._rand!(
34+
rng::Random.AbstractRNG,
35+
d::TuringDirichlet,
36+
x::AbstractVector{<:Real},
37+
)
38+
@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
39+
x[i] = rand(rng, Gamma(αi))
40+
end
41+
Distributions.multiply!(x, inv(sum(x))) # this returns x
42+
end
43+
function Distributions._rand!(
44+
rng::AbstractRNG,
45+
d::TuringDirichlet{<:Real,<:FillArrays.AbstractFill},
46+
x::AbstractVector{<:Real}
47+
)
48+
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
49+
Distributions.multiply!(x, inv(sum(x))) # this returns x
50+
end
51+
5052
function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real})
5153
return simplex_logpdf(d.alpha, d.lmnB, x)
5254
end

src/reversediff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,13 @@ Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha)
260260
Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)
261261

262262
function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal})
263-
return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
263+
return _logpdf(TuringDirichlet(d), x)
264264
end
265265
function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
266-
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
266+
return logpdf(TuringDirichlet(d), x)
267267
end
268268
function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
269-
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
269+
return loglikelihood(TuringDirichlet(d), x)
270270
end
271271

272272
# default definition of `loglikelihood` yields gradients of zero?!

src/tracker.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,13 @@ Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
371371
Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)
372372

373373
function Distributions._logpdf(d::Dirichlet, x::TrackedVector{<:Real})
374-
return Distributions._logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
374+
return Distributions._logpdf(TuringDirichlet(d), x)
375375
end
376376
function Distributions.logpdf(d::Dirichlet, x::TrackedMatrix{<:Real})
377-
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
377+
return logpdf(TuringDirichlet(d), x)
378378
end
379379
function Distributions.loglikelihood(d::Dirichlet, x::TrackedMatrix{<:Real})
380-
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
380+
return loglikelihood(TuringDirichlet(d), x)
381381
end
382382

383383
# Fix ambiguities
@@ -615,4 +615,3 @@ Distributions.InverseWishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = Turin
615615
Distributions.InverseWishart(df::Real, S::TrackedMatrix) = TuringInverseWishart(df, S)
616616
Distributions.InverseWishart(df::TrackedReal, S::TrackedMatrix) = TuringInverseWishart(df, S)
617617
Distributions.InverseWishart(df::TrackedReal, S::AbstractPDMat{<:TrackedReal}) = TuringInverseWishart(df, S)
618-

test/others.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,4 +298,42 @@
298298
end
299299
end
300300
end
301+
302+
@testset "TuringDirichlet" begin
303+
dim = 3
304+
n = 4
305+
for alpha in (2, rand())
306+
d1 = TuringDirichlet(dim, alpha)
307+
d2 = Dirichlet(dim, alpha)
308+
d3 = TuringDirichlet(d2)
309+
@test d1.alpha == d2.alpha == d3.alpha
310+
@test d1.alpha0 == d2.alpha0 == d3.alpha0
311+
@test d1.lmnB == d2.lmnB == d3.lmnB
312+
313+
s1 = rand(d1)
314+
@test s1 isa Vector{Float64}
315+
@test length(s1) == dim
316+
317+
s2 = rand(d1, n)
318+
@test s2 isa Matrix{Float64}
319+
@test size(s2) == (dim, n)
320+
end
321+
322+
for alpha in (ones(Int, dim), rand(dim))
323+
d1 = TuringDirichlet(alpha)
324+
d2 = Dirichlet(alpha)
325+
d3 = TuringDirichlet(d2)
326+
@test d1.alpha == d2.alpha == d3.alpha
327+
@test d1.alpha0 == d2.alpha0 == d3.alpha0
328+
@test d1.lmnB == d2.lmnB == d3.lmnB
329+
330+
s1 = rand(d1)
331+
@test s1 isa Vector{Float64}
332+
@test length(s1) == dim
333+
334+
s2 = rand(d1, n)
335+
@test s2 isa Matrix{Float64}
336+
@test size(s2) == (dim, n)
337+
end
338+
end
301339
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Random, LinearAlgebra, Test
1010

1111
using Distributions: meanlogdet
1212
using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal,
13-
TuringPoissonBinomial
13+
TuringPoissonBinomial, TuringDirichlet
1414
using StatsBase: entropy
1515
using StatsFuns: binomlogpdf, logsumexp, logistic
1616

0 commit comments

Comments
 (0)