Skip to content

Commit 427225e

Browse files
committed
use elbo_batch for real_nvp; achieved 4+ times speed up
1 parent 0672626 commit 427225e

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

example/demo_RealNVP.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function AffineCoupling(
3636
return AffineCoupling(dim, mask, s, t)
3737
end
3838

39-
function Bijectors.transform(af::AffineCoupling, x::AbstractVector)
39+
function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
4040
# partition vector using `af.mask::PartitionMask`
4141
x₁, x₂, x₃ = partition(af.mask, x)
4242
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
@@ -50,10 +50,18 @@ end
5050
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
5151
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
5252
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
53-
logjac = sum(log ∘ abs, af.s(x_2))
53+
logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
5454
return combine(af.mask, y_1, x_2, x_3), logjac
5555
end
5656

57+
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
58+
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
59+
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
60+
logjac = sum(log ∘ abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
61+
return combine(af.mask, y_1, x_2, x_3), vec(logjac)
62+
end
63+
64+
5765
function Bijectors.with_logabsdet_jacobian(
5866
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
5967
)
@@ -66,10 +74,16 @@ function Bijectors.with_logabsdet_jacobian(
6674
return combine(af.mask, x_1, y_2, y_3), logjac
6775
end
6876

69-
function Bijectors.logabsdetjac(af::AffineCoupling, x::AbstractVector)
70-
_, x_2, _ = partition(af.mask, x)
71-
logjac = sum(log ∘ abs, af.s(x_2))
72-
return logjac
77+
function Bijectors.with_logabsdet_jacobian(
78+
iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
79+
)
80+
af = iaf.orig
81+
# partition vector using `af.mask::PartitionMask`
82+
y_1, y_2, y_3 = partition(af.mask, y)
83+
# inverse transformation
84+
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
85+
logjac = -sum(log ∘ abs, af.s(y_2); dims = 1)
86+
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
7387
end
7488

7589
###################
@@ -126,6 +140,8 @@ q0 = MvNormal(zeros(T, 2), ones(T, 2))
126140

127141
d = 2
128142
hdims = 32
143+
144+
# alternating the coupling layers
129145
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]
130146

131147
flow = create_flow(Ls, q0)
@@ -142,7 +158,8 @@ cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
142158
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
143159
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
144160
flow_trained, stats, _ = train_flow(
145-
elbo,
161+
rng,
162+
elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
146163
flow,
147164
logp,
148165
sample_per_iter;

0 commit comments

Comments
 (0)