|
| 1 | +################################################################## |
| 2 | +# homogeneous mixture: arbitrary number of the same distribution # |
| 3 | +################################################################## |
| 4 | + |
| 5 | +""" |
| 6 | + HomogeneousMixture(distribution::Distribution, dims::Vector{Int}) |
| 7 | +
|
| 8 | +Define a new distribution that is a mixture of some number of instances of single base distributions. |
| 9 | +
|
| 10 | +The first argument defines the base distribution of each component in the mixture. |
| 11 | +
|
| 12 | +The second argument must have length equal |
| 13 | +to the number of arguments taken by the base distribution. A value of 0 |
| 14 | +at a position in the vector an indicates that the corresponding argument to the |
| 15 | +base distribution is a scalar, and integer values of i for i >= 1 indicate that |
| 16 | +the corresponding argument is an i-dimensional array. |
| 17 | +
|
| 18 | +Example: |
| 19 | +
|
| 20 | +```julia |
| 21 | +mixture_of_normals = HomogeneousMixture(normal, [0, 0]) |
| 22 | +``` |
| 23 | +
|
| 24 | +The resulting distribution (e.g. `mixture_of_normals` above) can then be used like the built-in distribution values like `normal`. |
| 25 | +The distribution takes `n+1` arguments where `n` is the number of arguments taken by the base distribution. |
| 26 | +The first argument to the distribution is a vector of non-negative mixture weights, which must sum to 1.0. |
| 27 | +The remaining arguments to the distribution correspond to the arguments of the base distribution, but have a different type: |
| 28 | +If an argument to the base distribution is a scalar of type `T`, then the corresponding argument to the mixture distribution is a `Vector{T}`, where each element of this vector is the argument to the corresponding mixture component. |
| 29 | +If an argument to the base distribution is an `Array{T,N}` for some `N`, then the corresponding argument to the mixture distribution is of the form `arr::Array{T,N+1}`, where each slice of the array of the form `arr[:,:,...,i]` is the argument for the `i`th mixture component. |
| 30 | +
|
| 31 | +Example: |
| 32 | +
|
| 33 | +```julia |
| 34 | +mixture_of_normals = HomogeneousMixture(normal, [0, 0]) |
| 35 | +mixture_of_mvnormals = HomogeneousMixture(mvnormal, [1, 2]) |
| 36 | +
|
| 37 | +@gen function foo() |
| 38 | + # mixture of two normal distributions |
| 39 | + # with means -1.0 and 1.0 |
| 40 | + # and standard deviations 0.1 and 10.0 |
| 41 | + # the first normal distribution has weight 0.4; the second has weight 0.6 |
| 42 | + x ~ mixture_of_normals([0.4, 0.6], [-1.0, 1.0], [0.1, 10.0]) |
| 43 | +
|
| 44 | + # mixture of two multivariate normal distributions |
| 45 | + # with means: [0.0, 0.0] and [1.0, 1.0] |
| 46 | + # and covariance matrices: [1.0 0.0; 0.0 1.0] and [10.0 0.0; 0.0 10.0] |
| 47 | + # the first multivariate normal distribution has weight 0.4; |
| 48 | + # the second has weight 0.6 |
| 49 | + means = [0.0 1.0; 0.0 1.0] # or, cat([0.0, 0.0], [1.0, 1.0], dims=2) |
| 50 | + covs = cat([1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0], dims=3) |
| 51 | + y ~ mixture_of_mvnormals([0.4, 0.6], means, covs) |
| 52 | +end |
| 53 | +``` |
| 54 | +""" |
| 55 | +struct HomogeneousMixture{T} <: Distribution{T} |
| 56 | + base_dist::Distribution{T} |
| 57 | + dims::Vector{Int} |
| 58 | +end |
| 59 | + |
| 60 | +(dist::HomogeneousMixture)(args...) = random(dist, args...) |
| 61 | + |
| 62 | +Gen.has_output_grad(dist::HomogeneousMixture) = has_output_grad(dist.base_dist) |
| 63 | +Gen.has_argument_grads(dist::HomogeneousMixture) = (true, has_argument_grads(dist.base_dist)...) |
| 64 | +Gen.is_discrete(dist::HomogeneousMixture) = is_discrete(dist.base_dist) |
| 65 | + |
| 66 | +function args_for_component(dist::HomogeneousMixture, k::Int, args) |
| 67 | + # returns a generator |
| 68 | + return (arg[fill(Colon(), dim)..., k] |
| 69 | + for (arg, dim) in zip(args, dist.dims)) |
| 70 | +end |
| 71 | + |
| 72 | +function Gen.random(dist::HomogeneousMixture, weights, args...) |
| 73 | + k = categorical(weights) |
| 74 | + return random(dist.base_dist, args_for_component(dist, k, args)...) |
| 75 | +end |
| 76 | + |
| 77 | +function Gen.logpdf(dist::HomogeneousMixture, x, weights, args...) |
| 78 | + K = length(weights) |
| 79 | + log_densities = [Gen.logpdf(dist.base_dist, x, args_for_component(dist, k, args)...) for k in 1:K] |
| 80 | + log_densities = log_densities .+ log.(weights) |
| 81 | + return logsumexp(log_densities) |
| 82 | +end |
| 83 | + |
| 84 | +function Gen.logpdf_grad(dist::HomogeneousMixture, x, weights, args...) |
| 85 | + K = length(weights) |
| 86 | + log_densities = [Gen.logpdf(dist.base_dist, x, args_for_component(dist, k, args)...) for k in 1:K] |
| 87 | + log_weighted_densities = log_densities .+ log.(weights) |
| 88 | + relative_weighted_densities = exp.(log_weighted_densities .- logsumexp(log_weighted_densities)) |
| 89 | + |
| 90 | + # log_grads[k] contains the gradients for the k'th component |
| 91 | + log_grads = [Gen.logpdf_grad(dist.base_dist, x, args_for_component(dist, k, args)...) for k in 1:K] |
| 92 | + |
| 93 | + # compute gradient with respect to x |
| 94 | + log_grads_x = [log_grad[1] for log_grad in log_grads] |
| 95 | + x_grad = if has_output_grad(dist.base_dist) |
| 96 | + sum(log_grads_x .* relative_weighted_densities) |
| 97 | + else |
| 98 | + nothing |
| 99 | + end |
| 100 | + |
| 101 | + # compute gradients with respect to the weights |
| 102 | + weights_grad = exp.(log_densities .- logsumexp(log_weighted_densities)) |
| 103 | + |
| 104 | + # compute gradients with respect to each argument |
| 105 | + arg_grads = Any[] |
| 106 | + for (i, (has_grad, arg, dim)) in enumerate(zip(has_argument_grads(dist)[2:end], args, dist.dims)) |
| 107 | + if has_grad |
| 108 | + if dim == 0 |
| 109 | + grads = [log_grad[i+1] for log_grad in log_grads] |
| 110 | + grad_weights = relative_weighted_densities |
| 111 | + else |
| 112 | + grads = cat( |
| 113 | + [log_grad[i+1] for log_grad in log_grads]..., |
| 114 | + dims=dist.dims[i]+1) |
| 115 | + grad_weights = reshape( |
| 116 | + relative_weighted_densities, |
| 117 | + (1 for d in 1:dist.dims[i])..., length(dist.dims)) |
| 118 | + end |
| 119 | + push!(arg_grads, grads .* grad_weights) |
| 120 | + else |
| 121 | + push!(arg_grads, nothing) |
| 122 | + end |
| 123 | + end |
| 124 | + |
| 125 | + return (x_grad, weights_grad, arg_grads...) |
| 126 | +end |
| 127 | + |
| 128 | +export HomogeneousMixture |
| 129 | + |
| 130 | + |
| 131 | +############################################################################## |
| 132 | +# heterogeneous mixture: fixed number of potentially different distributions # |
| 133 | +############################################################################## |
| 134 | + |
| 135 | +""" |
| 136 | + HeterogeneousMixture(distributions::Vector{Distribution{T}}) where {T} |
| 137 | +
|
| 138 | +Define a new distribution that is a mixture of a given list of base distributions. |
| 139 | +
|
| 140 | +The argument is the vector of base distributions, one for each mixture component. |
| 141 | +
|
| 142 | +Note that the base distributions must have the same output type. |
| 143 | +
|
| 144 | +Example: |
| 145 | +```julia |
| 146 | +uniform_beta_mixture = HeterogeneousMixture([uniform, beta]) |
| 147 | +``` |
| 148 | +
|
| 149 | +The resulting mixture distribution takes `n+1` arguments, where `n` is the sum of the number of arguments taken by each distribution in the list. |
| 150 | +The first argument to the mixture distribution is a vector of non-negative mixture weights, which must sum to 1.0. |
| 151 | +The remaining arguments are the arguments to each mixture component distribution, in order in which the distributions are passed into the constructor. |
| 152 | +
|
| 153 | +Example: |
| 154 | +```julia |
| 155 | +@gen function foo() |
| 156 | + # mixure of a uniform distribution on the interval [`lower`, `upper`] |
| 157 | + # and a beta distribution with alpha parameter `a` and beta parameter `b` |
| 158 | + # the uniform as weight 0.4 and the beta has weight 0.6 |
| 159 | + x ~ uniform_beta_mixture([0.4, 0.6], lower, upper, a, b) |
| 160 | +end |
| 161 | +``` |
| 162 | +""" |
| 163 | +struct HeterogeneousMixture{T} <: Distribution{T} |
| 164 | + K::Int |
| 165 | + distributions::Vector{Distribution{T}} |
| 166 | + has_output_grad::Bool |
| 167 | + has_argument_grads::Tuple |
| 168 | + is_discrete::Bool |
| 169 | + num_args::Vector{Int} |
| 170 | + starting_args::Vector{Int} |
| 171 | +end |
| 172 | + |
| 173 | +(dist::HeterogeneousMixture)(args...) = random(dist, args...) |
| 174 | + |
| 175 | +Gen.has_output_grad(dist::HeterogeneousMixture) = dist.has_output_grad |
| 176 | +Gen.has_argument_grads(dist::HeterogeneousMixture) = dist.has_argument_grads |
| 177 | +Gen.is_discrete(dist::HeterogeneousMixture) = dist.is_discrete |
| 178 | + |
| 179 | +const MIXTURE_WRONG_NUM_COMPONENTS_ERR = "the length of the weights vector does not match the number of mixture components" |
| 180 | + |
| 181 | +function HeterogeneousMixture(distributions::Vector{Distribution{T}}) where {T} |
| 182 | + _has_output_grad = true |
| 183 | + _has_argument_grads = Bool[true] # weights |
| 184 | + _is_discrete = true |
| 185 | + for dist in distributions |
| 186 | + _has_output_grad = _has_output_grad && has_output_grad(dist) |
| 187 | + for has_arg_grad in has_argument_grads(dist) |
| 188 | + push!(_has_argument_grads, has_arg_grad) |
| 189 | + end |
| 190 | + _is_discrete = _is_discrete && is_discrete(dist) |
| 191 | + end |
| 192 | + num_args = Int[] |
| 193 | + starting_args = Int[] |
| 194 | + for dist in distributions |
| 195 | + push!(starting_args, sum(num_args) + 1) |
| 196 | + push!(num_args, length(has_argument_grads(dist))) |
| 197 | + end |
| 198 | + K = length(distributions) |
| 199 | + return HeterogeneousMixture{T}( |
| 200 | + K, distributions, |
| 201 | + _has_output_grad, |
| 202 | + tuple(_has_argument_grads...), |
| 203 | + _is_discrete, |
| 204 | + num_args, |
| 205 | + starting_args) |
| 206 | +end |
| 207 | + |
| 208 | +function extract_args_for_component(dist::HeterogeneousMixture, component_args_flat, k::Int) |
| 209 | + start_arg = dist.starting_args[k] |
| 210 | + n = dist.num_args[k] |
| 211 | + return component_args_flat[start_arg:start_arg+n-1] |
| 212 | +end |
| 213 | + |
| 214 | +function Gen.random(dist::HeterogeneousMixture{T}, weights, component_args_flat...) where {T} |
| 215 | + (length(weights) != dist.K) && error(MIXTURE_WRONG_NUM_COMPONENTS_ERR) |
| 216 | + k = categorical(weights) |
| 217 | + value::T = random( |
| 218 | + dist.distributions[k], |
| 219 | + extract_args_for_component(dist, component_args_flat, k)...) |
| 220 | + return value |
| 221 | +end |
| 222 | + |
| 223 | +function Gen.logpdf(dist::HeterogeneousMixture, x, weights, component_args_flat...) |
| 224 | + (length(weights) != dist.K) && error(MIXTURE_WRONG_NUM_COMPONENTS_ERR) |
| 225 | + log_densities = [Gen.logpdf( |
| 226 | + dist.distributions[k], x, |
| 227 | + extract_args_for_component(dist, component_args_flat, k)...) |
| 228 | + for k in 1:dist.K] |
| 229 | + log_densities = log_densities .+ log.(weights) |
| 230 | + return logsumexp(log_densities) |
| 231 | +end |
| 232 | + |
| 233 | +function Gen.logpdf_grad(dist::HeterogeneousMixture, x, weights, component_args_flat...) |
| 234 | + (length(weights) != dist.K) && error(MIXTURE_WRONG_NUM_COMPONENTS_ERR) |
| 235 | + log_densities = [Gen.logpdf( |
| 236 | + dist.distributions[k], x, |
| 237 | + extract_args_for_component(dist, component_args_flat, k)...) |
| 238 | + for k in 1:dist.K] |
| 239 | + log_weighted_densities = log_densities .+ log.(weights) |
| 240 | + relative_weighted_densities = exp.(log_weighted_densities .- logsumexp(log_weighted_densities)) |
| 241 | + |
| 242 | + # log_grads[k] contains the gradients for that k in the mixture |
| 243 | + log_grads = [Gen.logpdf_grad( |
| 244 | + dist.distributions[k], x, |
| 245 | + extract_args_for_component(dist, component_args_flat, k)...) |
| 246 | + for k in 1:dist.K] |
| 247 | + |
| 248 | + # gradient with respect to x |
| 249 | + log_grads_x = [log_grad[1] for log_grad in log_grads] |
| 250 | + x_grad = if has_output_grad(dist) |
| 251 | + sum(log_grads_x .* relative_weighted_densities) |
| 252 | + else |
| 253 | + nothing |
| 254 | + end |
| 255 | + |
| 256 | + # gradients with respect to the weights |
| 257 | + weights_grad = exp.(log_densities .- logsumexp(log_weighted_densities)) |
| 258 | + |
| 259 | + # gradients with respect to each argument of each component |
| 260 | + component_arg_grads = Any[] |
| 261 | + cur = 1 |
| 262 | + for k in 1:dist.K |
| 263 | + for i in 1:dist.num_args[k] |
| 264 | + if dist.has_argument_grads[cur] |
| 265 | + @assert log_grads[k][i+1] != nothing |
| 266 | + push!(component_arg_grads, relative_weighted_densities[k] * log_grads[k][i+1]) |
| 267 | + else |
| 268 | + @assert log_grads[k][i+1] == nothing |
| 269 | + push!(component_arg_grads, nothing) |
| 270 | + end |
| 271 | + cur += 1 |
| 272 | + end |
| 273 | + end |
| 274 | + |
| 275 | + return (x_grad, weights_grad, component_arg_grads...) |
| 276 | +end |
| 277 | + |
| 278 | +export HeterogeneousMixture |
0 commit comments