|
| 1 | +####################################################### |
| 2 | +# training loop for variational objectives |
| 3 | +####################################################### |
"""
    pm_next!(pm, stats::NamedTuple)

Advance the progress meter `pm` by one step, displaying each entry of
`stats` as a `(name, value)` pair in the progress bar.
"""
function pm_next!(pm, stats::NamedTuple)
    showvalues = [(name, value) for (name, value) in pairs(stats)]
    return ProgressMeter.next!(pm; showvalues=showvalues)
end
| 7 | + |
# Wrap each extra loss argument as a `DifferentiationInterface.Constant`
# context, returning a `Vector` of wrapped values.
_wrap_in_DI_context(args...) = map(DifferentiationInterface.Constant, collect(args))
| 9 | + |
"""
    _prepare_gradient(loss, adbackend, θ, args...)

Build a `DifferentiationInterface` gradient preparation object for `loss` at `θ`.
Any extra `args` are wrapped as `DifferentiationInterface.Constant` contexts.
"""
function _prepare_gradient(loss, adbackend, θ, args...)
    # BUGFIX: previously `isempty(args...)`, which splats the varargs tuple —
    # a MethodError for zero or 2+ extra arguments, and the wrong predicate
    # (emptiness of the argument itself) for exactly one. Test the tuple.
    if isempty(args)
        return DifferentiationInterface.prepare_gradient(loss, adbackend, θ)
    end
    return DifferentiationInterface.prepare_gradient(
        loss, adbackend, θ, _wrap_in_DI_context(args...)...
    )
end
| 16 | + |
"""
    _value_and_gradient(loss, prep, adbackend, θ, args...)

Evaluate `loss` at `θ` and its gradient in one call, using the preparation
object `prep` from [`_prepare_gradient`](@ref). Extra `args` are wrapped as
`DifferentiationInterface.Constant` contexts, matching the preparation call.
"""
function _value_and_gradient(loss, prep, adbackend, θ, args...)
    # BUGFIX: previously `isempty(args...)` — splatting the varargs tuple is a
    # MethodError for zero or 2+ extra arguments; the tuple itself must be tested.
    if isempty(args)
        return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
    end
    return DifferentiationInterface.value_and_gradient(
        loss, prep, adbackend, θ, _wrap_in_DI_context(args...)...
    )
end
| 23 | + |
| 24 | + |
| 25 | +""" |
| 26 | + optimize( |
| 27 | + ad::ADTypes.AbstractADType, |
| 28 | + loss, |
| 29 | + θ₀::AbstractVector{T}, |
| 30 | + re, |
| 31 | + args...; |
| 32 | + kwargs... |
| 33 | + ) |
| 34 | +
|
| 35 | +Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by calling `grad!` |
| 36 | + and using the given `optimiser` to compute the steps. |
| 37 | +
|
| 38 | +# Arguments |
| 39 | +- `ad::ADTypes.AbstractADType`: automatic differentiation backend |
| 40 | +- `loss`: a general loss function θ -> loss(θ, args...) returning a scalar loss value that will be minimised |
| 41 | +- `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, it will be the flattened flow parameters) |
| 42 | +- `re`: reconstruction function that maps the flattened parameters to the normalizing flow |
| 43 | +- `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant) |
| 44 | +
|
| 45 | +
|
| 46 | +# Keyword Arguments |
| 47 | +- `max_iters::Int=10000`: maximum number of iterations |
| 48 | +- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps |
| 49 | +- `show_progress::Bool=true`: whether to show the progress bar. The default |
| 50 | + information printed in the progress bar is the iteration number, the loss value, |
| 51 | + and the gradient norm. |
| 52 | +- `callback=nothing`: callback function with signature `cb(iter, opt_state, re, θ)` |
| 53 | + which returns a dictionary-like object of statistics to be displayed in the progress bar. |
| 54 | + re and θ are used for reconstructing the normalizing flow in case that user |
| 55 | + want to further axamine the status of the flow. |
| 56 | +- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the |
| 57 | + training has converged. The default is to always return false. |
| 58 | +- `prog=ProgressMeter.Progress( |
| 59 | + max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress |
| 60 | + )`: progress bar configuration |
| 61 | +
|
| 62 | +# Returns |
| 63 | +- `θ`: trained parameters of the normalizing flow |
| 64 | +- `opt_stats`: statistics of the optimiser |
| 65 | +- `st`: optimiser state for potential continuation of training |
| 66 | +""" |
| 67 | +function optimize( |
| 68 | + adbackend, |
| 69 | + loss::Function, |
| 70 | + θ₀::AbstractVector{<:Real}, |
| 71 | + reconstruct::Function, |
| 72 | + args...; |
| 73 | + max_iters::Int=10000, |
| 74 | + optimiser::Optimisers.AbstractRule=Optimisers.ADAM(), |
| 75 | + show_progress::Bool=true, |
| 76 | + callback=nothing, |
| 77 | + hasconverged=(i, stats, re, θ, st) -> false, |
| 78 | + prog=ProgressMeter.Progress( |
| 79 | + max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress |
| 80 | + ), |
| 81 | +) |
| 82 | + time_elapsed = @elapsed begin |
| 83 | + opt_stats = [] |
| 84 | + |
| 85 | + # prepare loss and autograd |
| 86 | + θ = copy(θ₀) |
| 87 | + # grad = similar(θ) |
| 88 | + prep = _prepare_gradient(loss, adbackend, θ₀, args...) |
| 89 | + |
| 90 | + |
| 91 | + # initialise optimiser state |
| 92 | + st = Optimisers.setup(optimiser, θ) |
| 93 | + |
| 94 | + # general `hasconverged(...)` approach to allow early termination. |
| 95 | + converged = false |
| 96 | + i = 1 |
| 97 | + while (i ≤ max_iters) && !converged |
| 98 | + # ls, g = DifferentiationInterface.value_and_gradient!(loss, grad, prep, adbackend, θ) |
| 99 | + ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...) |
| 100 | + |
| 101 | + # Save stats |
| 102 | + stat = (iteration=i, loss=ls, gradient_norm=norm(g)) |
| 103 | + |
| 104 | + # callback |
| 105 | + if !isnothing(callback) |
| 106 | + new_stat = callback(i, opt_stats, reconstruct, θ) |
| 107 | + stat = !isnothing(new_stat) ? merge(stat, new_stat) : stat |
| 108 | + end |
| 109 | + push!(opt_stats, stat) |
| 110 | + |
| 111 | + # update optimiser state and parameters |
| 112 | + st, θ = Optimisers.update!(st, θ, g) |
| 113 | + |
| 114 | + # check convergence |
| 115 | + i += 1 |
| 116 | + converged = hasconverged(i, stat, reconstruct, θ, st) |
| 117 | + pm_next!(prog, stat) |
| 118 | + end |
| 119 | + end |
| 120 | + # return status of the optimiser for potential continuation of training |
| 121 | + return θ, map(identity, opt_stats), st, time_elapsed |
| 122 | +end |
0 commit comments