Skip to content

Commit d1ec834

Browse files
committed
fixing test
1 parent 104e5dd commit d1ec834

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/optimize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
6666
"""
6767
function optimize(
6868
adbackend,
69-
loss::Function,
69+
loss,
7070
θ₀::AbstractVector{<:Real},
71-
reconstruct::Function,
71+
reconstruct,
7272
args...;
7373
max_iters::Int=10000,
7474
optimiser::Optimisers.AbstractRule=Optimisers.ADAM(),

test/interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
target = MvNormal(μ, Σ)
1414
logp(z) = logpdf(target, z)
1515

16+
@leaf MvNormal
1617
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
1718
flow = Bijectors.transformed(
1819
q₀, Bijectors.Shift(zero.(μ)) Bijectors.Scale(ones(T, 2))
@@ -21,7 +22,7 @@
2122
sample_per_iter = 10
2223
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,)
2324
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3
24-
flow_trained, stats, _ = train_flow(
25+
flow_trained, stats, _, _ = train_flow(
2526
elbo,
2627
flow,
2728
logp,

0 commit comments

Comments
 (0)