Skip to content

Commit e257fab

Browse files
committed
add elbo_batch implementation; much faster for invertible NN based flows
1 parent 6b1500c commit e257fab

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/NormalizingFlows.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import DifferentiationInterface as DI
1212

1313
using DocStringExtensions
1414

15-
export train_flow, elbo, loglikelihood
15+
export train_flow, elbo, elbo_batch, loglikelihood
1616

1717
"""
1818
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

src/objectives/elbo.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,39 @@ end
4343
function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
4444
return elbo(Random.default_rng(), flow, logp, n_samples)
4545
end
46+
47+
48+
"""
49+
elbo_batch(flow, logp, xs)
50+
elbo_batch([rng, ]flow, logp, n_samples)
51+
52+
Instead of broadcasting over elbo_single_sample, this function directly
53+
computes the ELBO in a batched manner, which requires the flow.transform to be able to
54+
handle batched transformation directly.
55+
56+
This will be more efficient than `elbo` for invertible neural networks such as RealNVP,
57+
Neural Spline Flow, etc.
58+
59+
# Arguments
60+
- `rng`: random number generator
61+
- `flow`: variational distribution to be trained. In particular
62+
`flow = transformed(q₀, T::Bijectors.Bijector)`,
63+
q₀ is a reference distribution that one can easily sample and compute logpdf
64+
- `logp`: log-pdf of the target distribution (not necessarily normalized)
65+
- `xs`: samples from reference dist q₀
66+
- `n_samples`: number of samples from reference dist q₀
67+
68+
"""
69+
function elbo_batch(flow::Bijectors.MultivariateTransformed, logp, xs::AbstractMatrix)
70+
# requires the flow transformation to be able to handle batched inputs
71+
ys, logabsdetjac = with_logabsdet_jacobian(flow.transform, xs)
72+
elbos = logp(ys) .- logpdf(flow.dist, xs) .+ logabsdetjac
73+
return elbos
74+
end
75+
function elbo_batch(rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, n_samples)
76+
xs = _device_specific_rand(rng, flow.dist, n_samples)
77+
elbos = elbo_batch(flow, logp, xs)
78+
return mean(elbos)
79+
end
80+
elbo_batch(flow::Bijectors.UnivariateTransformed, logp, n_samples) =
81+
elbo_batch(Random.default_rng(), flow, logp, n_samples)

0 commit comments

Comments
 (0)