Skip to content

Commit 4af5d8d

Browse files
authored
Merge pull request #495 from chentoast/dirichlet
feat: add dirichlet distribution
2 parents b08cae9 + 7d111d5 commit 4af5d8d

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
struct Dirichlet <: Distribution{Float64} end
2+
3+
"""
4+
Dirichlet(alpha::Vector{Float64})
5+
6+
Sample a simplex Vector{Float64} from a Dirichlet distribution.
7+
"""
8+
const dirichlet = Dirichlet()
9+
10+
function logpdf(::Dirichlet, x::AbstractVector{T}, alpha::AbstractVector{U}) where {T <: Real, U <: Real}
11+
if length(x) == length(alpha) && isapprox(sum(x), 1) && all(x .>= 0) && all(alpha .>= 0)
12+
ll = sum((a_i - 1) * log(x_i) for (a_i, x_i) in zip(alpha, x))
13+
ll -= sum(loggamma.(alpha)) - loggamma(sum(alpha))
14+
ll
15+
else
16+
-Inf
17+
end
18+
end
19+
20+
function logpdf_grad(::Dirichlet, x::AbstractVector{T}, alpha::AbstractVector{U}) where {T <: Real, U <: Real}
21+
if length(x) == length(alpha) && isapprox(sum(x), 1) && all(x .>= 0) && all(alpha .>= 0)
22+
deriv_x = (alpha .- 1) ./ x
23+
deriv_alpha = log.(x) .- digamma.(alpha) .+ digamma(sum(alpha))
24+
(deriv_x, deriv_alpha)
25+
else
26+
(zero(x), zero(alpha))
27+
end
28+
end
29+
30+
function random(::Dirichlet, alpha::AbstractVector{T}) where {T <: Real}
31+
rand(Distributions.Dirichlet(alpha))
32+
end
33+
34+
is_discrete(::Dirichlet) = false
35+
36+
(::Dirichlet)(alpha) = random(Dirichlet(), alpha)
37+
38+
has_output_grad(::Dirichlet) = true
39+
has_argument_grads(::Dirichlet) = (true,)
40+
41+
export dirichlet
42+

src/modeling_library/distributions/distributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ include("beta.jl")
44
include("binom.jl")
55
include("categorical.jl")
66
include("cauchy.jl")
7+
include("dirichlet.jl")
78
include("exponential.jl")
89
include("gamma.jl")
910
include("geometric.jl")

test/modeling_library/distributions.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import DataStructures: OrderedDict
2+
import LinearAlgebra: diagm
23

34
@testset "bernoulli" begin
45

@@ -233,6 +234,61 @@ end
233234
@test isapprox(actual[3][2, 2], finite_diff_mat_sym(f, args, 3, 2, 2, dx))
234235
end
235236

237+
@testset "dirichlet" begin
238+
x = dirichlet([1., 1., 1., 1.])
239+
@test length(x) == 4
240+
@test isapprox(sum(x), 1.)
241+
242+
# bounds checking
243+
@test logpdf(dirichlet, [0., 0], [1., 1.]) == -Inf
244+
@test logpdf(dirichlet, [1., 1.], [1., 1.]) == -Inf
245+
@test logpdf(dirichlet, [2., -1], [1., 1.]) == -Inf
246+
@test logpdf(dirichlet, [.5, .5], [-1., 1.]) == -Inf
247+
@test logpdf(dirichlet, [.5, .5], [-1., 1.]) == -Inf
248+
@test logpdf(dirichlet, [0., 1], [1., 1.]) != -Inf
249+
250+
@test isapprox(logpdf(dirichlet, [.01, .99], [2., 2.]),
251+
Distributions.logpdf(Distributions.Dirichlet([2., 2.]), [.01, .99]))
252+
@test isapprox(logpdf(dirichlet, [.01, .99], [1., 4.]),
253+
Distributions.logpdf(Distributions.Dirichlet([1., 4.]), [.01, .99]))
254+
@test isapprox(logpdf(dirichlet, [.01, .99], [.01, .01]),
255+
Distributions.logpdf(Distributions.Dirichlet([.01, .01]), [.01, .99]))
256+
257+
# for d > 2
258+
@test isapprox(logpdf(dirichlet, [.2, .2, .6], [2., 2., 4.]),
259+
Distributions.logpdf(Distributions.Dirichlet([2., 2., 4.]), [.2, .2, .6]))
260+
261+
function softmax(x)
262+
exp.(x) / sum(exp.(x))
263+
end
264+
265+
function softmax_grad(x)
266+
diagm(x) .- (x .* x')
267+
end
268+
269+
f = (x, alpha) -> logpdf(dirichlet, x, alpha)
270+
f_normalized = (x, alpha) -> logpdf(dirichlet, softmax(x), alpha)
271+
272+
args = ([0., 0., 0., 0.], [1., 2., 3., 3.])
273+
normalized_args = ([.25, .25, .25, .25], [1., 2., 3., 3.])
274+
275+
actual = logpdf_grad(dirichlet, normalized_args...)
276+
277+
# gradients with respect to x
278+
actual_x_grad = actual[1]' * softmax_grad(normalized_args[1])
279+
280+
@test isapprox(actual_x_grad[1], finite_diff_vec(f_normalized, args, 1, 1, dx))
281+
@test isapprox(actual_x_grad[2], finite_diff_vec(f_normalized, args, 1, 2, dx))
282+
@test isapprox(actual_x_grad[3], finite_diff_vec(f_normalized, args, 1, 3, dx))
283+
@test isapprox(actual_x_grad[4], finite_diff_vec(f_normalized, args, 1, 4, dx))
284+
285+
# gradients with respect to alpha
286+
@test isapprox(actual[2][1], finite_diff_vec(f, normalized_args, 2, 1, dx))
287+
@test isapprox(actual[2][2], finite_diff_vec(f, normalized_args, 2, 2, dx))
288+
@test isapprox(actual[2][3], finite_diff_vec(f, normalized_args, 2, 3, dx))
289+
@test isapprox(actual[2][4], finite_diff_vec(f, normalized_args, 2, 4, dx))
290+
end
291+
236292
@testset "uniform" begin
237293

238294
# random

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Gen
22
using Test
33
import Random
4+
import Distributions
45

56
"""
67
Compute a numerical partial derivative of `f` with respect to the `i`th

0 commit comments

Comments
 (0)