diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 812f9b73..f21ba541 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,15 +1,18 @@ name: CI + on: push: branches: - main tags: ['*'] pull_request: + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -19,17 +22,17 @@ jobs: matrix: version: - '1' - - '1.6' + - 'min' os: - ubuntu-latest arch: - x64 steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 diff --git a/Project.toml b/Project.toml index 0ea94d38..0b5ec051 100644 --- a/Project.toml +++ b/Project.toml @@ -1,50 +1,26 @@ name = "NormalizingFlows" uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256" -version = "0.1.1" +version = "0.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -[weakdeps] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[extensions] -NormalizingFlowsEnzymeExt = "Enzyme" -NormalizingFlowsForwardDiffExt = "ForwardDiff" -NormalizingFlowsReverseDiffExt = "ReverseDiff" -NormalizingFlowsZygoteExt = "Zygote" - [compat] -ADTypes = "0.1, 0.2, 1" -Bijectors = "0.12.6, 0.13, 0.14" -DiffResults = "1" +ADTypes = "1" +Bijectors = "0.12.6, 0.13, 0.14, 0.15" +DifferentiationInterface = "0.6.42" Distributions = "0.25" DocStringExtensions = "0.9" -Enzyme = "0.11, 0.12, 0.13" -ForwardDiff = "0.10.25" -Optimisers = "0.2.16, 0.3" +Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.0.0" -Requires = "1" -ReverseDiff = "1.14" StatsBase = "0.33, 0.34" -Zygote = "0.6" -julia = "1.6" - -[extras] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +julia = "1.10" diff --git a/README.md b/README.md index 918449ff..46ac6262 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Build Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) -**Last updated: 2023-Aug-23** +**Last updated: 2025-Mar-04** A normalizing flow library for Julia. 
@@ -21,7 +21,7 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor To install the package, run the following command in the Julia REPL: ```julia ] # enter Pkg mode -(@v1.9) pkg> add git@github.com:TuringLang/NormalizingFlows.jl.git +(@v1.11) pkg> add NormalizingFlows ``` Then simply run the following command to use the package: ```julia @@ -29,8 +29,8 @@ using NormalizingFlows ``` ## Quick recap of normalizing flows -Normalizing flows transform a simple reference distribution $q_0$ (sometimes known as base distribution) to -a complex distribution $q$ using invertible functions. +Normalizing flows transform a simple reference distribution $q_0$ (sometimes referred to as the base distribution) +to a complex distribution $q$ using invertible functions. In more details, given the base distribution, usually a standard Gaussian distribution, i.e., $q_0 = \mathcal{N}(0, I)$, we apply a series of parameterized invertible transformations (called flow layers), $T_{1, \theta_1}, \cdots, T_{N, \theta_k}$, yielding that @@ -56,7 +56,7 @@ Given the feasibility of i.i.d. sampling and density evaluation, normalizing flo \text{Reverse KL:}\quad &\arg\min _{\theta} \mathbb{E}_{q_{\theta}}\left[\log q_{\theta}(Z)-\log p(Z)\right] \\ &= \arg\min _{\theta} \mathbb{E}_{q_0}\left[\log \frac{q_\theta(T_N\circ \cdots \circ T_1(Z_0))}{p(T_N\circ \cdots \circ T_1(Z_0))}\right] \\ -&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(X)+\sum_{n=1}^N \log J_n\left(F_n \circ \cdots \circ F_1(X)\right)\right] +&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(Z_0)+\sum_{n=1}^N \log J_n\left(T_n \circ \cdots \circ T_1(Z_0)\right)\right] \end{aligned} ``` and @@ -76,10 +76,12 @@ normalizing constant. In contrast, forward KL minimization is typically used for **generative modeling**, where one wants to learn the underlying distribution of some data. -## Current status and TODOs +## Current status and to-dos - [x] general interface development - [x] documentation +- [ ] integrating [Lux.jl](https://lux.csail.mit.edu/stable/tutorials/intermediate/7_RealNVP) and [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl). +This could potentially solve the GPU compatibility issue as well. - [ ] including more NF examples/Tutorials - WIP: [PR#11](https://github.com/TuringLang/NormalizingFlows.jl/pull/11) - [ ] GPU compatibility diff --git a/docs/src/api.md b/docs/src/api.md index f8028b91..eb128863 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows: ```@julia using Distributions, Bijectors T= Float32 +@leaf MvNormal # to prevent params in q₀ from being optimized q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2))) ``` @@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood ```@docs NormalizingFlows.optimize ``` - - -## Utility Functions for Taking Gradient -```@docs -NormalizingFlows.grad! -NormalizingFlows.value_and_gradient! 
-``` - diff --git a/docs/src/example.md b/docs/src/example.md index 346c15a0..01a9a671 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a ```julia using Bijectors, FunctionChains +using Functors function create_planar_flow(n_layers::Int, q₀) d = length(q₀) @@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀) end # create a 20-layer planar flow -flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) +@leaf MvNormal # to prevent params in q₀ from being optimized +q₀ = MvNormal(zeros(Float32, 2), I) +flow = create_planar_flow(20, q₀) flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison ``` *Notice that here the flow layers are chained together using `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl). @@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel= ## Reference -- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning \ No newline at end of file +- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl deleted file mode 100644 index 1b59cad8..00000000 --- a/ext/NormalizingFlowsEnzymeExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module NormalizingFlowsEnzymeExt - -if isdefined(Base, :get_extension) - using Enzyme - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults -else - using ..Enzyme - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults -end - -# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - y = f(θ) - DiffResults.value!(out, y) - ∇θ = DiffResults.gradient(out) - fill!(∇θ, zero(T)) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) - return out -end - -end \ No newline at end of file diff --git a/ext/NormalizingFlowsForwardDiffExt.jl b/ext/NormalizingFlowsForwardDiffExt.jl deleted file mode 100644 index 500d54f4..00000000 --- a/ext/NormalizingFlowsForwardDiffExt.jl +++ /dev/null @@ -1,28 +0,0 @@ -module NormalizingFlowsForwardDiffExt - -if isdefined(Base, :get_extension) - using ForwardDiff - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults -else - using ..ForwardDiff - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults -end - -# extract chunk size from AutoForwardDiff -getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - chunk_size = getchunksize(ad) - config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, θ) - else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) - end - ForwardDiff.gradient!(out, f, θ, config) - return out -end - -end \ No newline at end of file diff --git a/ext/NormalizingFlowsReverseDiffExt.jl b/ext/NormalizingFlowsReverseDiffExt.jl deleted file mode 100644 index 1bd39dc4..00000000 --- a/ext/NormalizingFlowsReverseDiffExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module 
NormalizingFlowsReverseDiffExt
-
-if isdefined(Base, :get_extension)
-    using NormalizingFlows
-    using NormalizingFlows: ADTypes, DiffResults
-    using ReverseDiff
-else
-    using ..NormalizingFlows
-    using ..NormalizingFlows: ADTypes, DiffResults
-    using ..ReverseDiff
-end
-
-# ReverseDiff without compiled tape
-function NormalizingFlows.value_and_gradient!(
-    ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
-    tp = ReverseDiff.GradientTape(f, θ)
-    ReverseDiff.gradient!(out, tp, θ)
-    return out
-end
-
-end
\ No newline at end of file
diff --git a/ext/NormalizingFlowsZygoteExt.jl b/ext/NormalizingFlowsZygoteExt.jl
deleted file mode 100644
index 0eee943c..00000000
--- a/ext/NormalizingFlowsZygoteExt.jl
+++ /dev/null
@@ -1,23 +0,0 @@
-module NormalizingFlowsZygoteExt
-
-if isdefined(Base, :get_extension)
-    using NormalizingFlows
-    using NormalizingFlows: ADTypes, DiffResults
-    using Zygote
-else
-    using ..NormalizingFlows
-    using ..NormalizingFlows: ADTypes, DiffResults
-    using ..Zygote
-end
-
-function NormalizingFlows.value_and_gradient!(
-    ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
-    y, back = Zygote.pullback(f, θ)
-    ∇θ = back(one(T))
-    DiffResults.value!(out, y)
-    DiffResults.gradient!(out, first(∇θ))
-    return out
-end
-
-end
\ No newline at end of file
diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl
index 16efb899..709f7586 100644
--- a/src/NormalizingFlows.jl
+++ b/src/NormalizingFlows.jl
@@ -4,14 +4,12 @@ using Bijectors
 using Optimisers
 using LinearAlgebra, Random, Distributions, StatsBase
 using ProgressMeter
-using ADTypes, DiffResults
+using ADTypes
+import DifferentiationInterface as DI
 using DocStringExtensions
-export train_flow, elbo, loglikelihood, value_and_gradient!
-
-using ADTypes
-using DiffResults
+export train_flow, elbo, loglikelihood
 """
     train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
@@ -30,7 +28,13 @@ Train the given normalizing flow `flow` by calling `optimize`.
 - `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
 - `ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote()`: automatic differentiation backend, currently supports
-  `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`.
+  `ADTypes.AutoZygote()`, `ADTypes.AutoForwardDiff()`, `ADTypes.AutoReverseDiff()`,
+  `ADTypes.AutoMooncake()` and
+  `ADTypes.AutoEnzyme(;
+      mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+      function_annotation=Enzyme.Const,
+  )`.
+  To use `AutoEnzyme`, make sure to include the `set_runtime_activity` and `function_annotation` options as shown above.
 - `kwargs...`: additional keyword arguments for `optimize` (See [`optimize`](@ref) for details)
 # Returns
@@ -57,13 +61,15 @@ function train_flow(
     # otherwise the compilation time for destructure will be too long
     θ_flat, re = Optimisers.destructure(flow)
+    loss(θ, rng, args...) = -vo(rng, re(θ), args...)
+
     # Normalizing flow training loop
     θ_flat_trained, opt_stats, st = optimize(
-        rng,
         ADbackend,
-        vo,
+        loss,
         θ_flat,
         re,
+        rng,
         args...;
         max_iters=max_iters,
         optimiser=optimiser,
@@ -74,29 +80,7 @@ function train_flow(
     return flow_trained, opt_stats, st
 end
-include("train.jl")
+include("optimize.jl")
 include("objectives.jl")
-# optional dependencies
-if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
-    using Requires
-end
-
-# Question: should Exts be loaded here or in train.jl?
-function __init__()
-    @static if !isdefined(Base, :get_extension)
-        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
-            "../ext/NormalizingFlowsForwardDiffExt.jl"
-        )
-        @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
-            "../ext/NormalizingFlowsReverseDiffExt.jl"
-        )
-        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
-            "../ext/NormalizingFlowsEnzymeExt.jl"
-        )
-        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
-            "../ext/NormalizingFlowsZygoteExt.jl"
-        )
-    end
-end
 end
diff --git a/src/objectives.jl b/src/objectives.jl
index ddf129bf..1d7ac5a2 100644
--- a/src/objectives.jl
+++ b/src/objectives.jl
@@ -1,2 +1,2 @@
 include("objectives/elbo.jl")
-include("objectives/loglikelihood.jl")
\ No newline at end of file
+include("objectives/loglikelihood.jl") # not fully tested
diff --git a/src/objectives/elbo.jl b/src/objectives/elbo.jl
index 68545b54..2751ed90 100644
--- a/src/objectives/elbo.jl
+++ b/src/objectives/elbo.jl
@@ -42,4 +42,4 @@ end
 function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
     return elbo(Random.default_rng(), flow, logp, n_samples)
-end
\ No newline at end of file
+end
diff --git a/src/objectives/loglikelihood.jl b/src/objectives/loglikelihood.jl
index 4097ae15..ab5eb961 100644
--- a/src/objectives/loglikelihood.jl
+++ b/src/objectives/loglikelihood.jl
@@ -2,12 +2,13 @@
 # training by minimizing forward KL (MLE)
 ####################################
 """
-    loglikelihood(flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
+    loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
 Compute the log-likelihood for variational distribution flow at a batch of samples xs from
 the target distribution p.
 # Arguments
+- `rng`: random number generator (unused; only needed so the signature matches the other variational objectives)
 - `flow`: variational distribution to be trained. In particular
 "flow = transformed(q₀, T::Bijectors.Bijector)",
 q₀ is a reference distribution that one can easily sample and compute logpdf
@@ -15,6 +16,7 @@ the target distribution p.
 """
 function loglikelihood(
+    ::AbstractRNG,  # unused argument, kept so all objectives share the same signature
     flow::Bijectors.UnivariateTransformed,    # variational distribution to be trained
     xs::AbstractVector,                       # sample batch from target dist p
 )
@@ -22,9 +24,20 @@ end
 function loglikelihood(
+    ::AbstractRNG,  # unused argument, kept so all objectives share the same signature
     flow::Bijectors.MultivariateTransformed,  # variational distribution to be trained
     xs::AbstractMatrix,                       # sample batch from target dist p
 )
     llhs = map(x -> logpdf(flow, x), eachcol(xs))
     return mean(llhs)
-end
\ No newline at end of file
+end
+
+## TODO: will need to implement a version that takes a dataloader
+# function loglikelihood(
+#     rng::AbstractRNG,
+#     flow::Bijectors.TransformedDistribution,
+#     dataloader
+# )
+#     xs = dataloader(rng)
+#     return loglikelihood(rng, flow, collect(dataloader))
+# end
diff --git a/src/optimize.jl b/src/optimize.jl
new file mode 100644
index 00000000..b4adad91
--- /dev/null
+++ b/src/optimize.jl
@@ -0,0 +1,108 @@
+#######################################################
+# training loop for variational objectives
+#######################################################
+function pm_next!(pm, stats::NamedTuple)
+    return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats)))
+end
+
+function _prepare_gradient(loss, adbackend, θ, args...)
+    return DI.prepare_gradient(loss, adbackend, θ, map(DI.Constant, args)...)
+end
+
+function _value_and_gradient(loss, prep, adbackend, θ, args...)
+    return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...)
+end
+
+"""
+    optimize(
+        ad::ADTypes.AbstractADType,
+        loss,
+        θ₀::AbstractVector{T},
+        re,
+        args...;
+        kwargs...
+    )
+
+Iteratively update the parameters `θ` of the normalizing flow `re(θ)` by computing the
+value and gradient of `loss` (via DifferentiationInterface) and using the given
+`optimiser` to take the steps.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: automatic differentiation backend
+- `loss`: a general loss function `θ -> loss(θ, args...)` returning a scalar loss value that will be minimised
+- `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, it will be the flattened flow parameters)
+- `re`: reconstruction function that maps the flattened parameters to the normalizing flow
+- `args...`: additional arguments for `loss` (each is wrapped as a `DI.Constant`)
+
+# Keyword Arguments
+- `max_iters::Int=10000`: maximum number of iterations
+- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
+- `show_progress::Bool=true`: whether to show the progress bar. The default
+  information printed in the progress bar is the iteration number, the loss value,
+  and the gradient norm.
+- `callback=nothing`: callback function with signature `cb(iter, opt_state, re, θ)`
+  which returns a dictionary-like object of statistics to be displayed in the progress bar.
+  `re` and `θ` are used for reconstructing the normalizing flow in case the user
+  wants to further examine the status of the flow.
+- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the
+  training has converged. The default is to always return false.
+- `prog=ProgressMeter.Progress(
+    max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+  )`: progress bar configuration
+
+# Returns
+- `θ`: trained parameters of the normalizing flow
+- `opt_stats`: statistics of the optimiser
+- `st`: optimiser state for potential continuation of training
+"""
+function optimize(
+    adbackend,
+    loss,
+    θ₀::AbstractVector{<:Real},
+    reconstruct,
+    args...;
+    max_iters::Int=10000,
+    optimiser::Optimisers.AbstractRule=Optimisers.ADAM(),
+    show_progress::Bool=true,
+    callback=nothing,
+    hasconverged=(i, stats, re, θ, st) -> false,
+    prog=ProgressMeter.Progress(
+        max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+    ),
+)
+    opt_stats = []
+
+    # prepare loss and autograd
+    θ = deepcopy(θ₀)
+    # grad = similar(θ)
+    prep = _prepare_gradient(loss, adbackend, θ₀, args...)
+
+    # initialise optimiser state
+    st = Optimisers.setup(optimiser, θ)
+
+    # general `hasconverged(...)` approach to allow early termination.
+    converged = false
+    i = 1
+    while (i ≤ max_iters) && !converged
+        ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
+
+        # Save stats
+        stat = (iteration=i, loss=ls, gradient_norm=norm(g))
+
+        # callback
+        if callback !== nothing
+            new_stat = callback(i, opt_stats, reconstruct, θ)
+            stat = new_stat !== nothing ?
merge(stat, new_stat) : stat + end + push!(opt_stats, stat) + + # update optimiser state and parameters + st, θ = Optimisers.update!(st, θ, g) + + # check convergence + i += 1 + converged = hasconverged(i, stat, reconstruct, θ, st) + pm_next!(prog, stat) + end + # return status of the optimiser for potential continuation of training + return θ, map(identity, opt_stats), st +end diff --git a/src/train.jl b/src/train.jl deleted file mode 100644 index 3a286350..00000000 --- a/src/train.jl +++ /dev/null @@ -1,164 +0,0 @@ -""" - value_and_gradient!( - ad::ADTypes.AbstractADType, - f, - θ::AbstractVector{T}, - out::DiffResults.MutableDiffResult - ) where {T<:Real} - -Compute the value and gradient of a function `f` at `θ` using the automatic -differentiation backend `ad`. The result is stored in `out`. -The function `f` must return a scalar value. The gradient is stored in `out` as a -vector of the same length as `θ`. -""" -function value_and_gradient! end - -""" - grad!( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ_flat::AbstractVector{<:Real}, - reconstruct, - out::DiffResults.MutableDiffResult, - args... - ) - -Compute the value and gradient for negation of the variational objective `vo` -at `θ_flat` using the automatic differentiation backend `ad`. - -Default implementation is provided for `ad` where `ad` is one of `AutoZygote`, -`AutoForwardDiff`, `AutoReverseDiff` (with no compiled tape), and `AutoEnzyme`. -The result is stored in `out`. - -# Arguments -- `rng::AbstractRNG`: random number generator -- `ad::ADTypes.AbstractADType`: automatic differentiation backend, currently supports - `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`. -- `vo`: variational objective -- `θ_flat::AbstractVector{<:Real}`: flattened parameters of the normalizing flow -- `reconstruct`: function that reconstructs the normalizing flow from the flattened parameters -- `out::DiffResults.MutableDiffResult`: mutable diff result to store the value and gradient -- `args...`: additional arguments for `vo` -""" -function grad!( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ_flat::AbstractVector{<:Real}, - reconstruct, - out::DiffResults.MutableDiffResult, - args...; -) - # define opt loss function - loss(θ_) = -vo(rng, reconstruct(θ_), args...) - # compute loss value and gradient - out = value_and_gradient!(ad, loss, θ_flat, out) - return out -end - -####################################################### -# training loop for variational objectives -####################################################### -function pm_next!(pm, stats::NamedTuple) - return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) -end - -""" - optimize( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ₀::AbstractVector{T}, - re, - args...; - kwargs... - ) - -Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by calling `grad!` - and using the given `optimiser` to compute the steps. 
- -# Arguments -- `rng::AbstractRNG`: random number generator -- `ad::ADTypes.AbstractADType`: automatic differentiation backend -- `vo`: variational objective -- `θ₀::AbstractVector{T}`: initial parameters of the normalizing flow -- `re`: function that reconstructs the normalizing flow from the flattened parameters -- `args...`: additional arguments for `vo` - - -# Keyword Arguments -- `max_iters::Int=10000`: maximum number of iterations -- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps -- `show_progress::Bool=true`: whether to show the progress bar. The default - information printed in the progress bar is the iteration number, the loss value, - and the gradient norm. -- `callback=nothing`: callback function with signature `cb(iter, opt_state, re, θ)` - which returns a dictionary-like object of statistics to be displayed in the progress bar. - re and θ are used for reconstructing the normalizing flow in case that user - want to further axamine the status of the flow. -- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the - training has converged. The default is to always return false. -- `prog=ProgressMeter.Progress( - max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress - )`: progress bar configuration - -# Returns -- `θ`: trained parameters of the normalizing flow -- `opt_stats`: statistics of the optimiser -- `st`: optimiser state for potential continuation of training -""" -function optimize( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ₀::AbstractVector{<:Real}, - re, - args...; - max_iters::Int=10000, - optimiser::Optimisers.AbstractRule=Optimisers.ADAM(), - show_progress::Bool=true, - callback=nothing, - hasconverged=(i, stats, re, θ, st) -> false, - prog=ProgressMeter.Progress( - max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress - ), -) - opt_stats = [] - - θ = copy(θ₀) - diff_result = DiffResults.GradientResult(θ) - # initialise optimiser state - st = Optimisers.setup(optimiser, θ) - - # general `hasconverged(...)` approach to allow early termination. - converged = false - i = 1 - time_elapsed = @elapsed while (i ≤ max_iters) && !converged - # Compute gradient and objective value; results are stored in `diff_results` - grad!(rng, ad, vo, θ, re, diff_result, args...) - - # Save stats - ls = DiffResults.value(diff_result) - g = DiffResults.gradient(diff_result) - stat = (iteration=i, loss=ls, gradient_norm=norm(g)) - push!(opt_stats, stat) - - # callback - if !isnothing(callback) - new_stat = callback(i, opt_stats, re, θ) - stat = !isnothing(new_stat) ? 
merge(new_stat, stat) : stat - end - - # update optimiser state and parameters - st, θ = Optimisers.update!(st, θ, DiffResults.gradient(diff_result)) - - # check convergence - i += 1 - converged = hasconverged(i, stat, re, θ, st) - pm_next!(prog, stat) - end - - # return status of the optimiser for potential continuation of training - return θ, map(identity, opt_stats), st -end diff --git a/test/Project.toml b/test/Project.toml index c474adda..be5ffa0e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,13 +1,20 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Mooncake = "0.4.101" diff --git a/test/ad.jl b/test/ad.jl index a394d806..725b3547 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,31 +1,41 @@ -@testset "AD correctness" begin - f(x) = sum(abs2, x) +@testset "DI.AD with context wrapper" begin + f(x, y, z) = sum(abs2, x .+ y .+ z) @testset "$T" for T in [Float32, Float64] x = randn(T, 10) + y = randn(T, 10) + z = randn(T, 10) chunksize = size(x, 1) @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), - ADTypes.AutoEnzyme(), + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ADTypes.AutoMooncake(; config=Mooncake.Config()), ] - out = DiffResults.GradientResult(x) - NormalizingFlows.value_and_gradient!(at, f, x, out) - @test DiffResults.value(out) ≈ f(x) - @test DiffResults.gradient(out) ≈ 2x + prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) + value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) + @test value ≈ f(x, y, z) + @test grad ≈ 2 * (x .+ y .+ z) end end end -@testset "AD for ELBO" begin +@testset "AD for ELBO on mean-field Gaussian VI" begin @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), - # ADTypes.AutoEnzyme(), # not working now + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -33,20 +43,33 @@ end target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) + # necessary for Zygote/mooncake to differentiate through the flow + # prevent updating params of q0 + @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) - flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ))) + flow = Bijectors.transformed( + q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)) + ) - sample_per_iter = 10 θ, re = 
Optimisers.destructure(flow) - out = DiffResults.GradientResult(θ) # check grad computation for elbo - NormalizingFlows.grad!( - Random.default_rng(), at, elbo, θ, re, out, logp, sample_per_iter + function loss(θ, rng, logp, sample_per_iter) + return -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter) + end + + rng = Random.default_rng() + sample_per_iter = 10 + + prep = NormalizingFlows._prepare_gradient( + loss, at, θ, rng, logp, sample_per_iter + ) + value, grad = NormalizingFlows._value_and_gradient( + loss, prep, at, θ, rng, logp, sample_per_iter ) - @test DiffResults.value(out) != nothing - @test all(DiffResults.gradient(out) .!= nothing) + @test value !== nothing + @test all(grad .!= nothing) end end -end \ No newline at end of file +end diff --git a/test/interface.jl b/test/interface.jl index a3540979..947d4f37 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,33 +1,39 @@ -@testset "learining 2d Gaussian" begin +@testset "testing mean-field Gaussian VI" begin chunksize = 4 @testset "$adtype" for adtype in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), - # ADTypes.AutoEnzyme(), # doesn't work for Enzyme + ADTypes.AutoReverseDiff(), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ADTypes.AutoMooncake(; config = Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) Σ = Diagonal(4 * ones(T, 2)) + target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) + @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed( - q₀, Bijectors.Shift(zero.(μ)) ∘ Bijectors.Scale(ones(T, 2)) + q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)) ) sample_per_iter = 10 - cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) - checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3 + cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) + checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 flow_trained, stats, _ = train_flow( elbo, flow, logp, sample_per_iter; max_iters=5_000, - optimiser=Optimisers.ADAM(0.01 * one(T)), + optimiser=Optimisers.Adam(one(T)/100), ADbackend=adtype, show_progress=false, callback=cb, @@ -44,4 +50,62 @@ @test el_trained > -1 end end -end \ No newline at end of file +end + +# function create_planar_flow(n_layers::Int, q₀, T) +# d = length(q₀) +# if T == Float32 +# Ls = reduce(∘, [f32(PlanarLayer(d)) for _ in 1:n_layers]) +# else +# Ls = reduce(∘, [PlanarLayer(d) for _ in 1:n_layers]) +# end +# return Bijectors.transformed(q₀, Ls) +# end + +# @testset "testing planar flow" begin +# chunksize = 4 +# @testset "$adtype" for adtype in [ +# ADTypes.AutoZygote(), +# ADTypes.AutoForwardDiff(; chunksize=chunksize), +# ADTypes.AutoForwardDiff(), +# ADTypes.AutoReverseDiff(), +# ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), +# ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 +# ] +# @testset "$T" for T in [Float32, Float64] +# μ = 10 * ones(T, 2) +# Σ = Diagonal(4 * ones(T, 2)) + +# target = MvNormal(μ, Σ) +# logp(z) = logpdf(target, z) + +# @leaf MvNormal +# q₀ = MvNormal(zeros(T, 2), ones(T, 2)) +# nlayers = 10 +# flow = create_planar_flow(nlayers, q₀, T) + +# sample_per_iter = 10 +# cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +# checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +# 
flow_trained, stats, _, _ = train_flow( +# elbo, +# flow, +# logp, +# sample_per_iter; +# max_iters=10_000, +# optimiser=Optimisers.Adam(one(T)/100), +# ADbackend=adtype, +# show_progress=false, +# callback=cb, +# hasconverged=checkconv, +# ) +# θ, re = Optimisers.destructure(flow_trained) + +# el_untrained = elbo(flow, logp, 1000) +# el_trained = elbo(flow_trained, logp, 1000) + +# @test el_trained > el_untrained +# @test el_trained > -1 +# end +# end +# end diff --git a/test/objectives.jl b/test/objectives.jl index 072286d1..4641b3cd 100644 --- a/test/objectives.jl +++ b/test/objectives.jl @@ -9,9 +9,10 @@ flow = Bijectors.transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(sqrt.(Σ))) x = randn(T, 2) + rng = Random.default_rng() @testset "elbo" begin - el = elbo(Random.default_rng(), flow, logp, 10) + el = elbo(rng, flow, logp, 10) @test abs(el) ≤ 1e-5 @test logpdf(flow, x) + el ≈ logp(x) @@ -20,8 +21,8 @@ @testset "likelihood" begin sample_trained = rand(flow, 1000) sample_untrained = rand(q₀, 1000) - llh_trained = NormalizingFlows.loglikelihood(flow, sample_trained) - llh_untrained = NormalizingFlows.loglikelihood(flow, sample_untrained) + llh_trained = NormalizingFlows.loglikelihood(rng, flow, sample_trained) + llh_untrained = NormalizingFlows.loglikelihood(rng, flow, sample_untrained) @test llh_trained > llh_untrained end diff --git a/test/runtests.jl b/test/runtests.jl index e050a645..33a98085 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,10 +3,14 @@ using Distributions using Bijectors, Optimisers using LinearAlgebra using Random -using ADTypes, DiffResults -using ForwardDiff, Zygote, Enzyme, ReverseDiff +using ADTypes +using Functors +using ForwardDiff, Zygote, ReverseDiff, Enzyme, Mooncake +using Flux: f32 +import DifferentiationInterface as DI + using Test include("ad.jl") include("objectives.jl") -include("interface.jl") \ No newline at end of file +include("interface.jl")
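
For reviewers who want to try the refactored interface end-to-end, below is a minimal, hypothetical usage sketch assembled from the README quick start, the new `train_flow`/`optimize` signatures, and `test/interface.jl` in this diff. The target density, the flow construction, and the hyperparameter values are illustrative only and are not part of the patch.

```julia
using ADTypes, Bijectors, Distributions, Functors, LinearAlgebra, Optimisers
using NormalizingFlows

# Illustrative target: a 2d Gaussian with mean 10 and variance 4.
target = MvNormal(10 * ones(2), Diagonal(4 * ones(2)))
logp(z) = logpdf(target, z)

# Reference distribution; @leaf (from Functors.jl) keeps the parameters of q₀
# out of the optimisation, as done in the updated docs and tests.
@leaf MvNormal
q₀ = MvNormal(zeros(2), ones(2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(2)) ∘ Bijectors.Scale(ones(2)))

# Train by maximising the ELBO with 10 Monte Carlo samples per iteration.
flow_trained, opt_stats, _ = train_flow(
    elbo, flow, logp, 10;
    max_iters=5_000,
    optimiser=Optimisers.Adam(1e-2),
    ADbackend=ADTypes.AutoZygote(),
    show_progress=false,
)
```

Because gradients now go through DifferentiationInterface instead of the removed package extensions, switching backends only means changing `ADbackend`, e.g. to `ADTypes.AutoMooncake(; config=Mooncake.Config())` or to the `AutoEnzyme` configuration shown in the `train_flow` docstring.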