@@ -13,7 +13,6 @@ function _value_and_gradient(loss, prep, adbackend, θ, args...)
     return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...)
 end
 
-
 """
     optimize(
         ad::ADTypes.AbstractADType,
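For context on the `DI` calls above, here is a minimal, self-contained sketch of the DifferentiationInterface.jl prepare-then-evaluate pattern that `_prepare_gradient` and `_value_and_gradient` wrap. The `loss`, `θ`, and `data` names are hypothetical stand-ins, not part of this package:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
using ForwardDiff  # the backend package must be loaded for AutoForwardDiff()

# hypothetical loss: θ is differentiated, data is passed as a non-differentiated constant
loss(θ, data) = sum(abs2, θ .- data)

θ, data = randn(3), ones(3)
backend = AutoForwardDiff()

# prepare once, then evaluate repeatedly with the same preparation object
prep = DI.prepare_gradient(loss, backend, θ, DI.Constant(data))
ls, g = DI.value_and_gradient(loss, prep, backend, θ, DI.Constant(data))
```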
@@ -58,7 +57,7 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
 function optimize(
     adbackend,
     loss,
-    θ₀::AbstractVector{<:Real},
+    θ₀::AbstractVector{<:Real},
     reconstruct,
     args...;
     max_iters::Int=10000,
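As a usage sketch of this signature: the loss, data, and hyperparameters below are made up, `identity` stands in for a flow re-constructor (e.g. from `Optimisers.destructure`), and the `optimiser` keyword is assumed to accept an Optimisers.jl rule, as `Optimisers.setup(optimiser, θ)` in the body below suggests:

```julia
using Optimisers
using ADTypes: AutoForwardDiff
using ForwardDiff

# hypothetical loss; `data` is forwarded to it through `args...`
loss(θ, data) = sum(abs2, θ .- data)
θ₀ = randn(5)
data = ones(5)

θ, stats, st = optimize(
    AutoForwardDiff(), loss, θ₀, identity, data;
    max_iters=2_000, optimiser=Optimisers.Adam(1e-2),
)
```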
@@ -70,42 +69,40 @@ function optimize(
         max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
     ),
 )
-    time_elapsed = @elapsed begin
-        opt_stats = []
+    opt_stats = []
 
-        # prepare loss and autograd
-        θ = deepcopy(θ₀)
-        # grad = similar(θ)
-        prep = _prepare_gradient(loss, adbackend, θ₀, args...)
+    # prepare loss and autograd
+    θ = deepcopy(θ₀)
+    # grad = similar(θ)
+    prep = _prepare_gradient(loss, adbackend, θ₀, args...)
 
-        # initialise optimiser state
-        st = Optimisers.setup(optimiser, θ)
+    # initialise optimiser state
+    st = Optimisers.setup(optimiser, θ)
 
-        # general `hasconverged(...)` approach to allow early termination.
-        converged = false
-        i = 1
-        while (i ≤ max_iters) && !converged
-            ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
+    # general `hasconverged(...)` approach to allow early termination.
+    converged = false
+    i = 1
+    while (i ≤ max_iters) && !converged
+        ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
 
-            # Save stats
-            stat = (iteration=i, loss=ls, gradient_norm=norm(g))
+        # Save stats
+        stat = (iteration=i, loss=ls, gradient_norm=norm(g))
 
-            # callback
-            if callback !== nothing
-                new_stat = callback(i, opt_stats, reconstruct, θ)
-                stat = new_stat !== nothing ? merge(stat, new_stat) : stat
-            end
-            push!(opt_stats, stat)
+        # callback
+        if callback !== nothing
+            new_stat = callback(i, opt_stats, reconstruct, θ)
+            stat = new_stat !== nothing ? merge(stat, new_stat) : stat
+        end
+        push!(opt_stats, stat)
 
-            # update optimiser state and parameters
-            st, θ = Optimisers.update!(st, θ, g)
+        # update optimiser state and parameters
+        st, θ = Optimisers.update!(st, θ, g)
 
-            # check convergence
-            i += 1
-            converged = hasconverged(i, stat, reconstruct, θ, st)
-            pm_next!(prog, stat)
-        end
+        # check convergence
+        i += 1
+        converged = hasconverged(i, stat, reconstruct, θ, st)
+        pm_next!(prog, stat)
     end
     # return status of the optimiser for potential continuation of training
-    return θ, map(identity, opt_stats), st, time_elapsed
+    return θ, map(identity, opt_stats), st
 end
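Since `time_elapsed` is no longer part of the return value, callers that still want the wall-clock time can measure it at the call site, e.g. (reusing the hypothetical setup from the sketch above):

```julia
# timing moves to the call site now that `optimize` no longer returns it
time_elapsed = @elapsed begin
    θ, stats, st = optimize(AutoForwardDiff(), loss, θ₀, identity, data)
end
```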
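For completeness, the loop's callback protocol (as read from the body above): `callback(i, opt_stats, reconstruct, θ)` may return a `NamedTuple`, which is merged into that iteration's `stat`, or `nothing`, which leaves `stat` unchanged. A hypothetical example:

```julia
# hypothetical callback: record the parameter mean every 100 iterations;
# returning `nothing` on other iterations leaves `stat` as-is
function my_callback(i, opt_stats, reconstruct, θ)
    i % 100 == 0 || return nothing
    return (param_mean=sum(θ) / length(θ),)
end
```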