Skip to content

Commit 06e7651

Browse files
committed
reduce the iter number
1 parent cc3c761 commit 06e7651

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

example/demo_RealNVP.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ flow_trained, stats, _ = train_flow(
163163
flow,
164164
logp,
165165
sample_per_iter;
166-
max_iters=50_000,
166+
max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
167167
optimiser=Optimisers.Adam(5e-4),
168168
ADbackend=adtype,
169169
show_progress=true,

example/demo_hamiltonian_flow.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ flow_trained, stats, _ = train_flow(
162162
flow,
163163
logp_joint,
164164
sample_per_iter;
165-
max_iters=50_000,
165+
max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
166166
optimiser=Optimisers.Adam(3e-4),
167167
ADbackend=adtype,
168168
show_progress=true,

example/demo_neural_spline_flow.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ flow_trained, stats, _ = train_flow(
155155
flow,
156156
logp,
157157
sample_per_iter;
158-
max_iters=50_000,
158+
max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
159159
optimiser=Optimisers.Adam(5e-5),
160160
ADbackend=adtype,
161161
show_progress=true,

example/demo_planar_flow.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ flow_untrained = deepcopy(flow)
3535
######################################
3636
# start training
3737
######################################
38-
sample_per_iter = 30
38+
sample_per_iter = 32
3939

4040
# callback function to log training progress
4141
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
4242
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
4343
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
4444
flow_trained, stats, _ = train_flow(
45-
elbo,
45+
elbo_batch,
4646
flow,
4747
logp,
4848
sample_per_iter;

example/demo_radial_flow.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ flow_untrained = deepcopy(flow)
3636
######################################
3737
# start training
3838
######################################
39-
sample_per_iter = 30
39+
sample_per_iter = 32
4040

4141
# callback function to log training progress
4242
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
4343
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
4444
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
4545
flow_trained, stats, _ = train_flow(
46-
elbo,
46+
elbo_batch,
4747
flow,
4848
logp,
4949
sample_per_iter;

0 commit comments

Comments
 (0)