@@ -46,6 +46,29 @@ function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
46
46
end
47
47
48
48
49
+ """
50
+ _batched_elbos(flow, logp, xs)
51
+
52
+ Batched ELBO estimates that transforms a matrix of samples (each column represents a single
53
+ sample) in one call.
54
+ This is more efficient for invertible neural-network flows (RealNVP/NSF) as it leverages
55
+ the batched operation of the neural networks.
56
+
57
+ Inputs
58
+ - `flow::Bijectors.MultivariateTransformed`
59
+ - `logp`: function returning log-density of target
60
+ - `xs`: column-wise sample batch
61
+
62
+ Returns
63
+ - a vector of ELBO estimates for each sample in the batch
64
+ """
65
+ function _batched_elbos (flow:: Bijectors.MultivariateTransformed , logp, xs:: AbstractMatrix )
66
+ # requires the flow transformation to be able to handle batched inputs
67
+ ys, logabsdetjac = with_logabsdet_jacobian (flow. transform, xs)
68
+ elbos = logp (ys) .- logpdf (flow. dist, xs) .+ logabsdetjac
69
+ return elbos
70
+ end
71
+
49
72
"""
50
73
elbo_batch(flow, logp, xs)
51
74
elbo_batch([rng, ] flow, logp, n_samples)
@@ -64,14 +87,12 @@ Returns
64
87
- Scalar estimate of the ELBO
65
88
"""
66
89
function elbo_batch (flow:: Bijectors.MultivariateTransformed , logp, xs:: AbstractMatrix )
67
- # requires the flow transformation to be able to handle batched inputs
68
- ys, logabsdetjac = with_logabsdet_jacobian (flow. transform, xs)
69
- elbos = logp (ys) .- logpdf (flow. dist, xs) .+ logabsdetjac
70
- return elbos
90
+ elbos = _batched_elbos (flow, logp, xs)
91
+ return mean (elbos)
71
92
end
72
93
function elbo_batch (rng:: AbstractRNG , flow:: Bijectors.MultivariateTransformed , logp, n_samples)
73
94
xs = _device_specific_rand (rng, flow. dist, n_samples)
74
- elbos = elbo_batch (flow, logp, xs)
95
+ elbos = _batched_elbos (flow, logp, xs)
75
96
return mean (elbos)
76
97
end
77
98
elbo_batch (flow:: Bijectors.TransformedDistribution , logp, n_samples) =
0 commit comments