
Commit 6220e96

switch to differentiationinterface from diffresults
1 parent 3e9b668 commit 6220e96

5 files changed: +142 −37 lines changed


src/NormalizingFlows.jl

Lines changed: 13 additions & 33 deletions
@@ -4,14 +4,12 @@ using Bijectors
 using Optimisers
 using LinearAlgebra, Random, Distributions, StatsBase
 using ProgressMeter
-using ADTypes, DiffResults
+using ADTypes
+using DifferentiationInterface
 
 using DocStringExtensions
 
-export train_flow, elbo, loglikelihood, value_and_gradient!
-
-using ADTypes
-using DiffResults
+export train_flow, elbo, loglikelihood
 
 """
     train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
@@ -56,47 +54,29 @@ function train_flow(
     # use FunctionChains instead of simple compositions to construct the flow when many flow layers are involved
     # otherwise the compilation time for destructure will be too long
     θ_flat, re = Optimisers.destructure(flow)
+
+    loss(θ, rng, args...) = -vo(rng, re(θ), args...)
 
     # Normalizing flow training loop
-    θ_flat_trained, opt_stats, st = optimize(
-        rng,
+    θ_flat_trained, opt_stats, st, time_elapsed = optimize(
         ADbackend,
-        vo,
+        loss,
         θ_flat,
-        re,
-        args...;
+        re,
+        (rng, args...)...;
         max_iters=max_iters,
        optimiser=optimiser,
        kwargs...,
    )
 
    flow_trained = re(θ_flat_trained)
-    return flow_trained, opt_stats, st
+    return flow_trained, opt_stats, st, time_elapsed
end
 
-include("train.jl")
+
+
+include("optimize.jl")
include("objectives.jl")
 
-# optional dependencies
-if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
-    using Requires
-end
 
-# Question: should Exts be loaded here or in train.jl?
-function __init__()
-    @static if !isdefined(Base, :get_extension)
-        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
-            "../ext/NormalizingFlowsForwardDiffExt.jl"
-        )
-        @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
-            "../ext/NormalizingFlowsReverseDiffExt.jl"
-        )
-        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
-            "../ext/NormalizingFlowsEnzymeExt.jl"
-        )
-        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
-            "../ext/NormalizingFlowsZygoteExt.jl"
-        )
-    end
-end
end
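
For orientation, a minimal usage sketch of `train_flow` after this change, showing the new fourth return value `time_elapsed`. The toy flow, the target `logp`, and the `ADbackend` keyword name are assumptions for illustration only (the full `train_flow` signature is outside this hunk):

using NormalizingFlows
using ADTypes, Bijectors, Distributions, LinearAlgebra, Random
import ForwardDiff  # the chosen ADTypes backend needs its AD package loaded

# toy target: standard 2-D Gaussian log-density
logp(x) = -0.5 * sum(abs2, x)

# illustrative flow: a diagonal Gaussian reference pushed through a trainable shift
q₀ = MvNormal(zeros(2), I)
flow = transformed(q₀, Bijectors.Shift(zeros(2)))

# train_flow now also returns the elapsed training time
flow_trained, opt_stats, st, time_elapsed = train_flow(
    Random.default_rng(), elbo, flow, logp, 10;   # 10 Monte Carlo samples for the ELBO
    max_iters=1_000, ADbackend=AutoForwardDiff(), # keyword name assumed from the hunk above
)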

src/objectives.jl

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 include("objectives/elbo.jl")
-include("objectives/loglikelihood.jl")
+include("objectives/loglikelihood.jl") # not tested

src/objectives/elbo.jl

Lines changed: 1 addition & 1 deletion
@@ -42,4 +42,4 @@ end
 
 function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
     return elbo(Random.default_rng(), flow, logp, n_samples)
-end
+end
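
The change above only touches the file's trailing line; for completeness, the convenience method it shows forwards to the rng-accepting form, so the two calls below are equivalent (using the illustrative `flow` and `logp` from the sketch after the first file):

elbo(flow, logp, 10)                        # uses Random.default_rng() internally
elbo(Random.default_rng(), flow, logp, 10)  # explicit RNG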

src/objectives/loglikelihood.jl

Lines changed: 5 additions & 2 deletions
@@ -2,29 +2,32 @@
 # training by minimizing forward KL (MLE)
 ####################################
 """
-    loglikelihood(flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
+    loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
 
 Compute the log-likelihood for variational distribution flow at a batch of samples xs from
 the target distribution p.
 
 # Arguments
+- `rng`: random number generator (unused; only included so the signature matches the other variational objectives)
 - `flow`: variational distribution to be trained. In particular
     "flow = transformed(q₀, T::Bijectors.Bijector)",
     q₀ is a reference distribution that one can easily sample and compute logpdf
 - `xs`: samples from the target distribution p.
 
 """
 function loglikelihood(
+    rng::AbstractRNG, # unused argument
     flow::Bijectors.UnivariateTransformed, # variational distribution to be trained
     xs::AbstractVector, # sample batch from target dist p
 )
     return mean(Base.Fix1(logpdf, flow), xs)
 end
 
 function loglikelihood(
+    rng::AbstractRNG, # unused argument
     flow::Bijectors.MultivariateTransformed, # variational distribution to be trained
     xs::AbstractMatrix, # sample batch from target dist p
 )
     llhs = map(x -> logpdf(flow, x), eachcol(xs))
     return mean(llhs)
-end
+end
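
A brief sketch of the new calling convention: the leading `rng` is ignored, but it keeps `loglikelihood` interchangeable with `elbo` as a variational objective. The toy flow and data are illustrative assumptions:

using Bijectors, Distributions, LinearAlgebra, Random

rng = Random.default_rng()
q₀ = MvNormal(zeros(2), I)
flow = transformed(q₀, Bijectors.Shift(ones(2)))

xs = randn(rng, 2, 100)  # pretend target samples, one column per sample
NormalizingFlows.loglikelihood(rng, flow, xs)  # rng only matches the elbo signature; it is not used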

src/optimize.jl

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+#######################################################
+# training loop for variational objectives
+#######################################################
+function pm_next!(pm, stats::NamedTuple)
+    return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+end
+
+_wrap_in_DI_context(args...) = DifferentiationInterface.Constant.([args...])
+
+function _prepare_gradient(loss, adbackend, θ, args...)
+    if isempty(args)
+        return DifferentiationInterface.prepare_gradient(loss, adbackend, θ)
+    end
+    return DifferentiationInterface.prepare_gradient(
+        loss, adbackend, θ, _wrap_in_DI_context(args)...
+    )
+end
+
+function _value_and_gradient(loss, prep, adbackend, θ, args...)
+    if isempty(args)
+        return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
+    end
+    return DifferentiationInterface.value_and_gradient(
+        loss, prep, adbackend, θ, _wrap_in_DI_context(args)...
+    )
+end
+
+"""
+    optimize(
+        ad::ADTypes.AbstractADType,
+        loss,
+        θ₀::AbstractVector{T},
+        re,
+        args...;
+        kwargs...
+    )
+
+Iteratively update the parameters `θ` of the normalizing flow `re(θ)` by computing the value
+and gradient of `loss` via DifferentiationInterface and using the given `optimiser` to compute
+the steps.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: automatic differentiation backend
+- `loss`: a general loss function θ -> loss(θ, args...) returning a scalar loss value that will be minimised
+- `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, this is the flattened flow parameters)
+- `re`: reconstruction function that maps the flattened parameters to the normalizing flow
+- `args...`: additional arguments for `loss` (will be set as `DifferentiationInterface.Constant`)
+
+# Keyword Arguments
+- `max_iters::Int=10000`: maximum number of iterations
+- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
+- `show_progress::Bool=true`: whether to show the progress bar. The default
+    information printed in the progress bar is the iteration number, the loss value,
+    and the gradient norm.
+- `callback=nothing`: callback function with signature `cb(iter, opt_stats, re, θ)`
+    which returns a dictionary-like object of statistics to be displayed in the progress bar.
+    `re` and `θ` are passed so that the user can reconstruct the normalizing flow and
+    further examine its status.
+- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the
+    training has converged. The default is to always return false.
+- `prog=ProgressMeter.Progress(
+        max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+    )`: progress bar configuration
+
+# Returns
+- `θ`: trained parameters of the normalizing flow
+- `opt_stats`: statistics of the optimiser
+- `st`: optimiser state for potential continuation of training
+- `time_elapsed`: wall-clock time spent in the training loop
+"""
+function optimize(
+    adbackend,
+    loss::Function,
+    θ₀::AbstractVector{<:Real},
+    reconstruct::Function,
+    args...;
+    max_iters::Int=10000,
+    optimiser::Optimisers.AbstractRule=Optimisers.ADAM(),
+    show_progress::Bool=true,
+    callback=nothing,
+    hasconverged=(i, stats, re, θ, st) -> false,
+    prog=ProgressMeter.Progress(
+        max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+    ),
+)
+    time_elapsed = @elapsed begin
+        opt_stats = []
+
+        # prepare loss and autograd
+        θ = copy(θ₀)
+        prep = _prepare_gradient(loss, adbackend, θ₀, args...)
+
+        # initialise optimiser state
+        st = Optimisers.setup(optimiser, θ)
+
+        # general `hasconverged(...)` approach to allow early termination.
+        converged = false
+        i = 1
+        while (i ≤ max_iters) && !converged
+            ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
+
+            # save stats
+            stat = (iteration=i, loss=ls, gradient_norm=norm(g))
+
+            # callback
+            if !isnothing(callback)
+                new_stat = callback(i, opt_stats, reconstruct, θ)
+                stat = !isnothing(new_stat) ? merge(stat, new_stat) : stat
+            end
+            push!(opt_stats, stat)
+
+            # update optimiser state and parameters
+            st, θ = Optimisers.update!(st, θ, g)
+
+            # check convergence
+            i += 1
+            converged = hasconverged(i, stat, reconstruct, θ, st)
+            pm_next!(prog, stat)
+        end
+    end
+    # return status of the optimiser for potential continuation of training
+    return θ, map(identity, opt_stats), st, time_elapsed
+end
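
To illustrate the DifferentiationInterface pattern that `_prepare_gradient` and `_value_and_gradient` wrap, here is a standalone sketch, independent of this package and assuming a DifferentiationInterface version that provides `Constant` contexts:

using ADTypes, DifferentiationInterface
import ForwardDiff  # any AD package with an ADTypes backend works

# one differentiated argument θ and one non-differentiated context c
loss(θ, c) = sum(abs2, θ .- c)

backend = AutoForwardDiff()
θ, c = randn(3), ones(3)

# extra arguments are passed as Constant contexts, mirroring _wrap_in_DI_context above
prep = prepare_gradient(loss, backend, θ, Constant(c))
val, grad = value_and_gradient(loss, prep, backend, θ, Constant(c))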
