|
"""
    elbo(flow, logp, n_samples)

Compute the Monte-Carlo ELBO estimate using the global default RNG.
Forwards to `elbo(rng, flow, logp, n_samples)`.
"""
elbo(flow::Bijectors.TransformedDistribution, logp, n_samples) =
    elbo(Random.default_rng(), flow, logp, n_samples)
|
"""
    elbo_batch(flow, logp, xs)
    elbo_batch([rng, ]flow, logp, n_samples)

Compute the ELBO in a single batched pass instead of broadcasting
`elbo_single_sample` over individual draws. This requires `flow.transform` to
support batched inputs (a matrix whose columns are samples), and is typically
more efficient than `elbo` for invertible neural networks such as RealNVP or
Neural Spline Flows.

# Arguments
- `rng`: random number generator
- `flow`: variational distribution to be trained; in particular
  `flow = transformed(q₀, T::Bijectors.Bijector)`, where `q₀` is a reference
  distribution that is cheap to sample from and evaluate `logpdf` on
- `logp`: log-density of the (possibly unnormalized) target, applied batch-wise
- `xs`: samples from the reference distribution `q₀`, one per column
- `n_samples`: number of samples to draw from `q₀`
"""
function elbo_batch(flow::Bijectors.MultivariateTransformed, logp, xs::AbstractMatrix)
    # The transform must handle the whole batch at once; `logabsdetjacs`
    # holds one log-|det J| correction per column of `xs`.
    ys, logabsdetjacs = with_logabsdet_jacobian(flow.transform, xs)
    # Per-sample ELBO: log p(T(x)) - log q₀(x) + log|det J_T(x)|
    return logp(ys) .+ logabsdetjacs .- logpdf(flow.dist, xs)
end
"""
    elbo_batch(rng, flow, logp, n_samples)

Draw `n_samples` reference samples with `rng`, evaluate the batched per-sample
ELBOs, and return their mean.
"""
function elbo_batch(
    rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, n_samples
)
    # Sampling is routed through `_device_specific_rand` so samples land on
    # the same device (CPU/GPU) as the flow parameters.
    batch = _device_specific_rand(rng, flow.dist, n_samples)
    return mean(elbo_batch(flow, logp, batch))
end
"""
    elbo_batch(flow, logp, n_samples)

Convenience method that uses the global default RNG.

Dispatches on `Bijectors.MultivariateTransformed` (not `UnivariateTransformed`):
the RNG method it forwards to is only defined for multivariate flows, and the
batched code path operates on a matrix of column-samples, which has no
univariate counterpart. The previous `UnivariateTransformed` signature could
never succeed — every call ended in a `MethodError`.
"""
elbo_batch(flow::Bijectors.MultivariateTransformed, logp, n_samples) =
    elbo_batch(Random.default_rng(), flow, logp, n_samples)
0 commit comments