
Commit 0fe536f

rm time_elapsed from train_flow
1 parent db4872b commit 0fe536f

3 files changed: +38 -35 lines

src/NormalizingFlows.jl

Lines changed: 9 additions & 3 deletions
@@ -28,7 +28,13 @@ Train the given normalizing flow `flow` by calling `optimize`.
 - `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
 - `ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote()`:
     automatic differentiation backend, currently supports
-    `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`.
+    `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`,
+    `ADTypes.AutoMooncake()` and
+    `ADTypes.AutoEnzyme(;
+        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+        function_annotation=Enzyme.Const,
+    )`.
+    If user wants to use `AutoEnzyme`, please make sure to include the `set_runtime_activity` and `function_annotation` as shown above.
 - `kwargs...`: additional keyword arguments for `optimize` (See [`optimize`](@ref) for details)

 # Returns
@@ -58,7 +64,7 @@ function train_flow(
     loss(θ, rng, args...) = -vo(rng, re(θ), args...)

     # Normalizing flow training loop
-    θ_flat_trained, opt_stats, st, time_elapsed = optimize(
+    θ_flat_trained, opt_stats, st = optimize(
         ADbackend,
         loss,
         θ_flat,
@@ -71,7 +77,7 @@ function train_flow(
     )

     flow_trained = re(θ_flat_trained)
-    return flow_trained, opt_stats, st, time_elapsed
+    return flow_trained, opt_stats, st
 end

 include("optimize.jl")

src/optimize.jl

Lines changed: 28 additions & 31 deletions
@@ -13,7 +13,6 @@ function _value_and_gradient(loss, prep, adbackend, θ, args...)
     return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...)
 end

-
 """
     optimize(
         ad::ADTypes.AbstractADType,
@@ -58,7 +57,7 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
 function optimize(
     adbackend,
     loss,
-    θ₀::AbstractVector{<:Real},
+    θ₀::AbstractVector{<:Real},
     reconstruct,
     args...;
     max_iters::Int=10000,
@@ -70,42 +69,40 @@ function optimize(
         max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
     ),
 )
-    time_elapsed = @elapsed begin
-        opt_stats = []
+    opt_stats = []

-        # prepare loss and autograd
-        θ = deepcopy(θ₀)
-        # grad = similar(θ)
-        prep = _prepare_gradient(loss, adbackend, θ₀, args...)
+    # prepare loss and autograd
+    θ = deepcopy(θ₀)
+    # grad = similar(θ)
+    prep = _prepare_gradient(loss, adbackend, θ₀, args...)

-        # initialise optimiser state
-        st = Optimisers.setup(optimiser, θ)
+    # initialise optimiser state
+    st = Optimisers.setup(optimiser, θ)

-        # general `hasconverged(...)` approach to allow early termination.
-        converged = false
-        i = 1
-        while (i ≤ max_iters) && !converged
-            ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
+    # general `hasconverged(...)` approach to allow early termination.
+    converged = false
+    i = 1
+    while (i ≤ max_iters) && !converged
+        ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)

-            # Save stats
-            stat = (iteration=i, loss=ls, gradient_norm=norm(g))
+        # Save stats
+        stat = (iteration=i, loss=ls, gradient_norm=norm(g))

-            # callback
-            if callback !== nothing
-                new_stat = callback(i, opt_stats, reconstruct, θ)
-                stat = new_stat !== nothing ? merge(stat, new_stat) : stat
-            end
-            push!(opt_stats, stat)
+        # callback
+        if callback !== nothing
+            new_stat = callback(i, opt_stats, reconstruct, θ)
+            stat = new_stat !== nothing ? merge(stat, new_stat) : stat
+        end
+        push!(opt_stats, stat)

-            # update optimiser state and parameters
-            st, θ = Optimisers.update!(st, θ, g)
+        # update optimiser state and parameters
+        st, θ = Optimisers.update!(st, θ, g)

-            # check convergence
-            i += 1
-            converged = hasconverged(i, stat, reconstruct, θ, st)
-            pm_next!(prog, stat)
-        end
+        # check convergence
+        i += 1
+        converged = hasconverged(i, stat, reconstruct, θ, st)
+        pm_next!(prog, stat)
     end
     # return status of the optimiser for potential continuation of training
-    return θ, map(identity, opt_stats), st, time_elapsed
+    return θ, map(identity, opt_stats), st
 end
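
With the `@elapsed` wrapper gone, `optimize` no longer reports wall-clock time itself; a caller who still needs it can time the call externally. A minimal sketch, assuming `adbackend`, `loss`, `θ₀`, `reconstruct`, and `args` are defined as in the signature above (qualify the call as `NormalizingFlows.optimize` if the function is not exported):

# Wall-clock timing is now the caller's responsibility.
time_elapsed = @elapsed begin
    θ_trained, opt_stats, st = optimize(
        adbackend, loss, θ₀, reconstruct, args...;
        max_iters=10_000,
    )
end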

test/interface.jl

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
     sample_per_iter = 10
     cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
     checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
-    flow_trained, stats, _, _ = train_flow(
+    flow_trained, stats, _ = train_flow(
        elbo,
        flow,
        logp,
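
The test defines a per-iteration callback `cb` and a convergence check `checkconv` matching the signatures the training loop invokes, `callback(i, opt_stats, reconstruct, θ)` and `hasconverged(i, stat, reconstruct, θ, st)`. A sketch of how they might be wired in, assuming they are forwarded through keyword arguments named `callback` and `hasconverged` (names taken from the loop bindings; the keyword list itself is not shown in this diff) and that `T`, `adtype`, `flow`, and `logp` are defined as in the test:

# Extra per-iteration fields merged into the logged stats by the training loop.
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)

# Stop early once the gradient norm drops below 1e-3.
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T) / 1000

flow_trained, stats, _ = train_flow(
    elbo, flow, logp, sample_per_iter;
    ADbackend=adtype,
    callback=cb,
    hasconverged=checkconv,
)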
