@@ -46,6 +46,29 @@ function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
4646end
4747
4848
49+ """
50+ _batched_elbos(flow, logp, xs)
51+
52+ Batched ELBO estimates that transforms a matrix of samples (each column represents a single
53+ sample) in one call.
54+ This is more efficient for invertible neural-network flows (RealNVP/NSF) as it leverages
55+ the batched operation of the neural networks.
56+
57+ Inputs
58+ - `flow::Bijectors.MultivariateTransformed`
59+ - `logp`: function returning log-density of target
60+ - `xs`: column-wise sample batch
61+
62+ Returns
63+ - a vector of ELBO estimates for each sample in the batch
64+ """
65+ function _batched_elbos (flow:: Bijectors.MultivariateTransformed , logp, xs:: AbstractMatrix )
66+ # requires the flow transformation to be able to handle batched inputs
67+ ys, logabsdetjac = with_logabsdet_jacobian (flow. transform, xs)
68+ elbos = logp (ys) .- logpdf (flow. dist, xs) .+ logabsdetjac
69+ return elbos
70+ end
71+
4972"""
5073 elbo_batch(flow, logp, xs)
5174 elbo_batch([rng, ] flow, logp, n_samples)
@@ -64,14 +87,12 @@ Returns
6487- Scalar estimate of the ELBO
6588"""
6689function elbo_batch (flow:: Bijectors.MultivariateTransformed , logp, xs:: AbstractMatrix )
67- # requires the flow transformation to be able to handle batched inputs
68- ys, logabsdetjac = with_logabsdet_jacobian (flow. transform, xs)
69- elbos = logp (ys) .- logpdf (flow. dist, xs) .+ logabsdetjac
70- return elbos
90+ elbos = _batched_elbos (flow, logp, xs)
91+ return mean (elbos)
7192end
7293function elbo_batch (rng:: AbstractRNG , flow:: Bijectors.MultivariateTransformed , logp, n_samples)
7394 xs = _device_specific_rand (rng, flow. dist, n_samples)
74- elbos = elbo_batch (flow, logp, xs)
95+ elbos = _batched_elbos (flow, logp, xs)
7596 return mean (elbos)
7697end
7798elbo_batch (flow:: Bijectors.TransformedDistribution , logp, n_samples) =
0 commit comments