@@ -36,7 +36,7 @@ function AffineCoupling(
36
36
return AffineCoupling (dim, mask, s, t)
37
37
end
38
38
39
- function Bijectors. transform (af:: AffineCoupling , x:: AbstractVector )
39
+ function Bijectors. transform (af:: AffineCoupling , x:: AbstractVecOrMat )
40
40
# partition the input using `af.mask::PartitionMask`
41
41
x₁, x₂, x₃ = partition (af. mask, x)
42
42
y₁ = x₁ .* af. s (x₂) .+ af. t (x₂)
50
50
# Forward pass of the affine coupling layer for a single input vector `x`.
# Returns a tuple `(y, logjac)`: the transformed vector and the
# log-absolute-determinant of the Jacobian (a scalar).
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    # Evaluate the scale map once and reuse it for both the transform and the
    # log-Jacobian, instead of calling `af.s` twice on the same input.
    s = af.s(x_2)
    y_1 = s .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, s)  # this is a scalar
    return combine(af.mask, y_1, x_2, x_3), logjac
end
56
56
57
# Batched forward pass of the affine coupling layer for a matrix input `x`.
# Returns a tuple `(y, logjac)` where `logjac` is a vector obtained by summing
# `log(abs(s))` over the first dimension, i.e. one entry per column.
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    # Evaluate the scale map once and reuse it for both the transform and the
    # log-Jacobian, instead of calling `af.s` twice on the same input.
    s = af.s(x_2)
    y_1 = s .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, s; dims=1)  # 1 × size(x, 2)
    return combine(af.mask, y_1, x_2, x_3), vec(logjac)
end
63
+
64
+
57
65
function Bijectors. with_logabsdet_jacobian (
58
66
iaf:: Inverse{<:AffineCoupling} , y:: AbstractVector
59
67
)
@@ -66,10 +74,16 @@ function Bijectors.with_logabsdet_jacobian(
66
74
return combine (af. mask, x_1, y_2, y_3), logjac
67
75
end
68
76
69
- function Bijectors. logabsdetjac (af:: AffineCoupling , x:: AbstractVector )
70
- _, x_2, _ = partition (af. mask, x)
71
- logjac = sum (log ∘ abs, af. s (x_2))
72
- return logjac
77
# Batched inverse pass of the affine coupling layer for a matrix input `y`.
# Returns a tuple `(x, logjac)` where `logjac` is a vector holding the negated
# sum of `log(abs(s))` over the first dimension, i.e. one entry per column.
function Bijectors.with_logabsdet_jacobian(
    iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
)
    af = iaf.orig
    # partition the batch using `af.mask::PartitionMask`
    y_1, y_2, y_3 = partition(af.mask, y)
    # Evaluate shift and scale once each; the scale is reused for the
    # log-Jacobian instead of calling `af.s` a second time.
    t = af.t(y_2)
    s = af.s(y_2)
    # inverse transformation
    x_1 = (y_1 .- t) ./ s
    logjac = -sum(log ∘ abs, s; dims=1)
    return combine(af.mask, x_1, y_2, y_3), vec(logjac)
end
74
88
75
89
# ##################
@@ -126,6 +140,8 @@ q0 = MvNormal(zeros(T, 2), ones(T, 2))
126
140
127
141
d = 2
128
142
hdims = 32
143
+
144
+ # alternating the coupling layers
129
145
Ls = [AffineCoupling (d, hdims, [1 ]) ∘ AffineCoupling (d, hdims, [2 ]) for i in 1 : 3 ]
130
146
131
147
flow = create_flow (Ls, q0)
@@ -142,7 +158,8 @@ cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
142
158
adtype = ADTypes. AutoMooncake (; config = Mooncake. Config ())
143
159
# Returns `true` when the reported gradient norm is below 1/1000 (computed in
# element type `T`). Only `stat` is read; the other arguments are unused here.
# NOTE(review): presumably passed to `train_flow` as a convergence check — confirm.
function checkconv(iter, stat, re, θ, st)
    threshold = one(T) / 1000
    return stat.gradient_norm < threshold
end
144
160
flow_trained, stats, _ = train_flow (
145
- elbo,
161
+ rng,
162
+ elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
146
163
flow,
147
164
logp,
148
165
sample_per_iter;
0 commit comments