Skip to content

Commit 8a4f3eb

Browse files
authored
Merge pull request #341 from probcomp/20201209-marcoct-mixture
Add a constructor for mixture distributions
2 parents 38e6571 + b5137b3 commit 8a4f3eb

File tree

8 files changed

+513
-8
lines changed

8 files changed

+513
-8
lines changed

docs/src/ref/distributions.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# Probability Distributions
22

3-
Gen provides a library of built-in probability distributions, and two ways of
4-
writing custom distributions, both of which are explained below:
3+
Gen provides a library of built-in probability distributions, and three ways of
4+
defining custom distributions, each of which is explained below:
55

6-
1. The `@dist` constructor, for a distribution that can be expressed as a
6+
1. The [`@dist` constructor](@ref dist_dsl), for a distribution that can be expressed as a
77
simple deterministic transformation (technically, a
88
[pushforward](https://en.wikipedia.org/wiki/Pushforward_measure)) of an
99
existing distribution.
1010

11-
2. An API for defining arbitrary [custom distributions](@ref
11+
2. The [`HeterogeneousMixture`](@ref) and [`HomogeneousMixture`](@ref) constructors
12+
for distributions that are mixtures of other distributions.
13+
14+
3. An API for defining arbitrary [custom distributions](@ref
1215
custom_distributions) in plain Julia code.
1316

1417
## Built-In Distributions
@@ -208,6 +211,14 @@ log(normal(exp(x), exp(x))) :: RND (by rule 6)
208211
log(normal(exp(x), exp(x))) + (x * (2 + 3)) :: RND (by rule 6)
209212
```
210213

214+
## Mixture Distribution Constructors
215+
216+
There are two built-in constructors for defining mixture distributions:
217+
```@docs
218+
HomogeneousMixture
219+
HeterogeneousMixture
220+
```
221+
211222
## Defining New Distributions From Scratch
212223

213224
For distributions that cannot be expressed in the `@dist` DSL, users can define

src/modeling_library/distributions/binom.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ end
2323

2424
has_output_grad(::Binomial) = false
2525
has_argument_grads(::Binomial) = (false, true)
26+
is_discrete(::Binomial) = true
2627

2728
export binom

src/modeling_library/distributions/mvnormal.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import LinearAlgebra
2+
13
struct MultivariateNormal <: Distribution{Vector{Float64}} end
24

35
"""
@@ -9,13 +11,13 @@ const mvnormal = MultivariateNormal()
911

1012
function logpdf(::MultivariateNormal, x::AbstractVector{T}, mu::AbstractVector{U},
1113
cov::AbstractMatrix{V}) where {T <: Real, U <: Real, V <: Real}
12-
dist = Distributions.MvNormal(mu, cov)
14+
dist = Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))
1315
Distributions.logpdf(dist, x)
1416
end
1517

1618
function logpdf_grad(::MultivariateNormal, x::AbstractVector{T}, mu::AbstractVector{U},
1719
cov::AbstractMatrix{V}) where {T <: Real,U <: Real, V <: Real}
18-
dist = Distributions.MvNormal(mu, cov)
20+
dist = Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov))
1921
inv_cov = Distributions.invcov(dist)
2022

2123
x_deriv = Distributions.gradlogpdf(dist, x)
@@ -27,7 +29,7 @@ end
2729

2830
function random(::MultivariateNormal, mu::AbstractVector{U},
2931
cov::AbstractMatrix{V}) where {U <: Real, V <: Real}
30-
rand(Distributions.MvNormal(mu, cov))
32+
rand(Distributions.MvNormal(mu, LinearAlgebra.Symmetric(cov)))
3133
end
3234

3335
(::MultivariateNormal)(mu, cov) = random(MultivariateNormal(), mu, cov)

src/modeling_library/mixture.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
##################################################################
2+
# homogeneous mixture: arbitrary number of the same distribution #
3+
##################################################################
4+
5+
"""
6+
HomogeneousMixture(distribution::Distribution, dims::Vector{Int})
7+
8+
Define a new distribution that is a mixture of some number of instances of a single base distribution.
9+
10+
The first argument defines the base distribution of each component in the mixture.
11+
12+
The second argument must have length equal
13+
to the number of arguments taken by the base distribution. A value of 0
14+
at a position in the vector indicates that the corresponding argument to the
15+
base distribution is a scalar, and integer values of i for i >= 1 indicate that
16+
the corresponding argument is an i-dimensional array.
17+
18+
Example:
19+
20+
```julia
21+
mixture_of_normals = HomogeneousMixture(normal, [0, 0])
22+
```
23+
24+
The resulting distribution (e.g. `mixture_of_normals` above) can then be used like built-in distribution values such as `normal`.
25+
The distribution takes `n+1` arguments where `n` is the number of arguments taken by the base distribution.
26+
The first argument to the distribution is a vector of non-negative mixture weights, which must sum to 1.0.
27+
The remaining arguments to the distribution correspond to the arguments of the base distribution, but have a different type:
28+
If an argument to the base distribution is a scalar of type `T`, then the corresponding argument to the mixture distribution is a `Vector{T}`, where each element of this vector is the argument to the corresponding mixture component.
29+
If an argument to the base distribution is an `Array{T,N}` for some `N`, then the corresponding argument to the mixture distribution is of the form `arr::Array{T,N+1}`, where each slice of the array of the form `arr[:,:,...,i]` is the argument for the `i`th mixture component.
30+
31+
Example:
32+
33+
```julia
34+
mixture_of_normals = HomogeneousMixture(normal, [0, 0])
35+
mixture_of_mvnormals = HomogeneousMixture(mvnormal, [1, 2])
36+
37+
@gen function foo()
38+
# mixture of two normal distributions
39+
# with means -1.0 and 1.0
40+
# and standard deviations 0.1 and 10.0
41+
# the first normal distribution has weight 0.4; the second has weight 0.6
42+
x ~ mixture_of_normals([0.4, 0.6], [-1.0, 1.0], [0.1, 10.0])
43+
44+
# mixture of two multivariate normal distributions
45+
# with means: [0.0, 0.0] and [1.0, 1.0]
46+
# and covariance matrices: [1.0 0.0; 0.0 1.0] and [10.0 0.0; 0.0 10.0]
47+
# the first multivariate normal distribution has weight 0.4;
48+
# the second has weight 0.6
49+
means = [0.0 1.0; 0.0 1.0] # or, cat([0.0, 0.0], [1.0, 1.0], dims=2)
50+
covs = cat([1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0], dims=3)
51+
y ~ mixture_of_mvnormals([0.4, 0.6], means, covs)
52+
end
53+
```
54+
"""
55+
struct HomogeneousMixture{T} <: Distribution{T}
    # distribution that every mixture component is an instance of
    base_dist::Distribution{T}

    # one entry per argument of `base_dist`: 0 for a scalar argument,
    # i >= 1 for an i-dimensional array argument
    dims::Vector{Int}
end
59+
60+
# Calling the distribution object samples from it.
function (dist::HomogeneousMixture)(args...)
    return random(dist, args...)
end

# The mixture has an output gradient exactly when its base distribution does.
function Gen.has_output_grad(dist::HomogeneousMixture)
    return has_output_grad(dist.base_dist)
end

# The weights (first argument) always have a gradient; the remaining flags
# are those of the base distribution, in order.
function Gen.has_argument_grads(dist::HomogeneousMixture)
    base_flags = has_argument_grads(dist.base_dist)
    return (true, base_flags...)
end

# Discreteness is inherited from the base distribution.
function Gen.is_discrete(dist::HomogeneousMixture)
    return is_discrete(dist.base_dist)
end
65+
66+
# Lazily produce the base-distribution arguments for mixture component `k`.
#
# Each stacked argument is indexed with `dim` leading colons followed by `k`,
# so a scalar argument (dim == 0) yields `arg[k]` and a matrix argument
# (dim == 2) yields `arg[:, :, k]`. The result is a generator, not a tuple.
function args_for_component(dist::HomogeneousMixture, k::Int, args)
    stacked = zip(args, dist.dims)
    return Base.Generator(stacked) do (stacked_arg, ndim)
        leading = fill(Colon(), ndim)
        stacked_arg[leading..., k]
    end
end
71+
72+
# Sample from the mixture: draw a component index from `weights`, then sample
# the base distribution with that component's slice of each argument.
function Gen.random(dist::HomogeneousMixture, weights, args...)
    component = categorical(weights)
    component_args = args_for_component(dist, component, args)
    return random(dist.base_dist, component_args...)
end
76+
77+
# Log-density of `x` under the mixture: logsumexp over components of
# log(weight_k) + logpdf of the base distribution with component k's arguments.
function Gen.logpdf(dist::HomogeneousMixture, x, weights, args...)
    component_logps = map(1:length(weights)) do k
        component_args = args_for_component(dist, k, args)
        Gen.logpdf(dist.base_dist, x, component_args...)
    end
    return logsumexp(component_logps .+ log.(weights))
end
83+
84+
# Gradient of the mixture log-density.
#
# Returns the tuple `(x_grad, weights_grad, arg_grads...)`:
# - `x_grad` is `nothing` when the base distribution has no output gradient;
# - `weights_grad[k]` is p_k(x) / sum_j w_j p_j(x);
# - each entry of `arg_grads` is `nothing` when the base distribution reports
#   no gradient for that argument, otherwise the per-component gradients
#   stacked along a trailing component axis (a plain vector for scalar
#   arguments) and scaled by the component responsibilities.
function Gen.logpdf_grad(dist::HomogeneousMixture, x, weights, args...)
    K = length(weights)
    log_densities = [
        Gen.logpdf(dist.base_dist, x, args_for_component(dist, k, args)...)
        for k in 1:K]
    log_weighted_densities = log_densities .+ log.(weights)
    log_total = logsumexp(log_weighted_densities)

    # responsibilities: w_k p_k(x) / sum_j w_j p_j(x)
    relative_weighted_densities = exp.(log_weighted_densities .- log_total)

    # log_grads[k] contains the gradients for the k'th component
    log_grads = [
        Gen.logpdf_grad(dist.base_dist, x, args_for_component(dist, k, args)...)
        for k in 1:K]

    # compute gradient with respect to x
    x_grad = if has_output_grad(dist.base_dist)
        log_grads_x = [log_grad[1] for log_grad in log_grads]
        sum(log_grads_x .* relative_weighted_densities)
    else
        nothing
    end

    # compute gradients with respect to the weights
    weights_grad = exp.(log_densities .- log_total)

    # compute gradients with respect to each argument
    arg_grads = Any[]
    base_flags = has_argument_grads(dist.base_dist)
    for (i, (has_grad, _, dim)) in enumerate(zip(base_flags, args, dist.dims))
        if !has_grad
            push!(arg_grads, nothing)
            continue
        end
        component_grads = [log_grad[i+1] for log_grad in log_grads]
        if dim == 0
            grads = component_grads
            grad_weights = relative_weighted_densities
        else
            # stack the K component gradients along a new trailing axis
            grads = cat(component_grads...; dims=dim+1)
            # The trailing axis indexes mixture components, so it must have
            # length K (not the number of base-distribution arguments).
            grad_weights = reshape(
                relative_weighted_densities,
                ntuple(_ -> 1, dim)..., K)
        end
        push!(arg_grads, grads .* grad_weights)
    end

    return (x_grad, weights_grad, arg_grads...)
end
127+
128+
export HomogeneousMixture
129+
130+
131+
##############################################################################
132+
# heterogeneous mixture: fixed number of potentially different distributions #
133+
##############################################################################
134+
135+
"""
136+
HeterogeneousMixture(distributions::Vector{Distribution{T}}) where {T}
137+
138+
Define a new distribution that is a mixture of a given list of base distributions.
139+
140+
The argument is the vector of base distributions, one for each mixture component.
141+
142+
Note that the base distributions must have the same output type.
143+
144+
Example:
145+
```julia
146+
uniform_beta_mixture = HeterogeneousMixture([uniform, beta])
147+
```
148+
149+
The resulting mixture distribution takes `n+1` arguments, where `n` is the sum of the number of arguments taken by each distribution in the list.
150+
The first argument to the mixture distribution is a vector of non-negative mixture weights, which must sum to 1.0.
151+
The remaining arguments are the arguments to each mixture component distribution, in the order in which the distributions are passed into the constructor.
152+
153+
Example:
154+
```julia
155+
@gen function foo()
156+
# mixture of a uniform distribution on the interval [`lower`, `upper`]
157+
# and a beta distribution with alpha parameter `a` and beta parameter `b`
158+
# the uniform has weight 0.4 and the beta has weight 0.6
159+
x ~ uniform_beta_mixture([0.4, 0.6], lower, upper, a, b)
160+
end
161+
```
162+
"""
163+
struct HeterogeneousMixture{T} <: Distribution{T}
    # number of mixture components
    K::Int

    # the component distributions, in constructor order
    distributions::Vector{Distribution{T}}

    # cached trait values returned by the `Gen.has_output_grad`,
    # `Gen.has_argument_grads` and `Gen.is_discrete` methods
    has_output_grad::Bool
    has_argument_grads::Tuple
    is_discrete::Bool

    # num_args[k]: number of arguments taken by component k
    num_args::Vector{Int}

    # starting_args[k]: 1-based position of component k's first argument
    # within the flattened component arguments (weights excluded)
    starting_args::Vector{Int}
end
172+
173+
# Calling the distribution object samples from it.
function (dist::HeterogeneousMixture)(args...)
    return random(dist, args...)
end

# Trait queries return the values stored on the struct.
function Gen.has_output_grad(dist::HeterogeneousMixture)
    return dist.has_output_grad
end

function Gen.has_argument_grads(dist::HeterogeneousMixture)
    return dist.has_argument_grads
end

function Gen.is_discrete(dist::HeterogeneousMixture)
    return dist.is_discrete
end

# Error raised when the weights vector length differs from the component count.
const MIXTURE_WRONG_NUM_COMPONENTS_ERR = "the length of the weights vector does not match the number of mixture components"
180+
181+
# Build a heterogeneous mixture from a vector of component distributions that
# share the output type `T`.
#
# Accepts any `AbstractVector` whose element type is a subtype of
# `Distribution{T}`, so a concretely typed vector such as `[normal, normal]`
# works as well as a `Vector{Distribution{T}}`.
#
# Precomputes:
# - has_output_grad: true only if every component has an output gradient;
# - has_argument_grads: `true` for the weights, followed by every component's
#   argument-gradient flags in order;
# - is_discrete: true only if every component is discrete;
# - num_args / starting_args: per-component argument count and 1-based offset
#   into the flattened component arguments.
function HeterogeneousMixture(
        distributions::AbstractVector{<:Distribution{T}}) where {T}
    _has_output_grad = all(has_output_grad, distributions)
    _is_discrete = all(is_discrete, distributions)

    _has_argument_grads = Bool[true] # weights
    num_args = Int[]
    starting_args = Int[]
    next_start = 1 # running offset; avoids re-summing num_args each iteration
    for dist in distributions
        flags = has_argument_grads(dist)
        for has_arg_grad in flags
            push!(_has_argument_grads, has_arg_grad)
        end
        push!(starting_args, next_start)
        push!(num_args, length(flags))
        next_start += length(flags)
    end

    return HeterogeneousMixture{T}(
        length(distributions),
        convert(Vector{Distribution{T}}, distributions),
        _has_output_grad,
        tuple(_has_argument_grads...),
        _is_discrete,
        num_args,
        starting_args)
end
207+
208+
# Slice out the arguments belonging to component `k` from the flattened
# component arguments, using the offsets stored on `dist`.
function extract_args_for_component(dist::HeterogeneousMixture, component_args_flat, k::Int)
    first_idx = dist.starting_args[k]
    last_idx = first_idx + dist.num_args[k] - 1
    return component_args_flat[first_idx:last_idx]
end
213+
214+
# Sample from the mixture: draw a component index from `weights`, then sample
# that component with its own arguments. Errors if the number of weights does
# not match the number of components.
function Gen.random(dist::HeterogeneousMixture{T}, weights, component_args_flat...) where {T}
    if length(weights) != dist.K
        error(MIXTURE_WRONG_NUM_COMPONENTS_ERR)
    end
    k = categorical(weights)
    component_args = extract_args_for_component(dist, component_args_flat, k)
    sampled::T = random(dist.distributions[k], component_args...)
    return sampled
end
222+
223+
# Log-density of `x` under the mixture: logsumexp over components of
# log(weight_k) + logpdf of component k evaluated with its own arguments.
# Errors if the number of weights does not match the number of components.
function Gen.logpdf(dist::HeterogeneousMixture, x, weights, component_args_flat...)
    if length(weights) != dist.K
        error(MIXTURE_WRONG_NUM_COMPONENTS_ERR)
    end
    component_logps = map(1:dist.K) do k
        component_args = extract_args_for_component(dist, component_args_flat, k)
        Gen.logpdf(dist.distributions[k], x, component_args...)
    end
    return logsumexp(component_logps .+ log.(weights))
end
232+
233+
# Gradient of the mixture log-density.
#
# Returns the tuple `(x_grad, weights_grad, component_arg_grads...)`:
# - `x_grad` is `nothing` unless the mixture has an output gradient;
# - `weights_grad[k]` is p_k(x) / sum_j w_j p_j(x);
# - `component_arg_grads` has one entry per flattened component argument:
#   the component's own gradient scaled by its responsibility, or `nothing`
#   when that argument has no gradient.
# Errors if the number of weights does not match the number of components.
function Gen.logpdf_grad(dist::HeterogeneousMixture, x, weights, component_args_flat...)
    (length(weights) != dist.K) && error(MIXTURE_WRONG_NUM_COMPONENTS_ERR)
    log_densities = [Gen.logpdf(
            dist.distributions[k], x,
            extract_args_for_component(dist, component_args_flat, k)...)
        for k in 1:dist.K]
    log_weighted_densities = log_densities .+ log.(weights)
    log_total = logsumexp(log_weighted_densities)

    # responsibilities: w_k p_k(x) / sum_j w_j p_j(x)
    relative_weighted_densities = exp.(log_weighted_densities .- log_total)

    # log_grads[k] contains the gradients for that k in the mixture
    log_grads = [Gen.logpdf_grad(
            dist.distributions[k], x,
            extract_args_for_component(dist, component_args_flat, k)...)
        for k in 1:dist.K]

    # gradient with respect to x
    x_grad = if has_output_grad(dist)
        log_grads_x = [log_grad[1] for log_grad in log_grads]
        sum(log_grads_x .* relative_weighted_densities)
    else
        nothing
    end

    # gradients with respect to the weights
    weights_grad = exp.(log_densities .- log_total)

    # gradients with respect to each argument of each component
    component_arg_grads = Any[]
    # dist.has_argument_grads[1] is the flag for the weights, so the flags for
    # the component arguments start at index 2.
    flag_idx = 2
    for k in 1:dist.K
        for i in 1:dist.num_args[k]
            grad = log_grads[k][i+1]
            if dist.has_argument_grads[flag_idx]
                @assert grad !== nothing
                push!(component_arg_grads, relative_weighted_densities[k] * grad)
            else
                @assert grad === nothing
                push!(component_arg_grads, nothing)
            end
            flag_idx += 1
        end
    end

    return (x_grad, weights_grad, component_arg_grads...)
end
277+
278+
export HeterogeneousMixture

src/modeling_library/modeling_library.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Otherwise, this element contains the gradient with respect to the `i`th argument
4040
"""
4141
function logpdf_grad end
4242

43-
function is_discrete end
43+
is_discrete(::Distribution) = false # default
4444

4545
# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl
4646

@@ -59,6 +59,9 @@ include("distributions/distributions.jl")
5959
# @dist DSL
6060
include("dist_dsl/dist_dsl.jl")
6161

62+
# mixtures of distributions
63+
include("mixture.jl")
64+
6265
###############
6366
# combinators #
6467
###############

0 commit comments

Comments
 (0)