From 6220e96ee270d4de8b38c835da2c4f44df8dfe44 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:12:23 -0800 Subject: [PATCH 01/43] switch to differentiationinterface from diffresults --- src/NormalizingFlows.jl | 46 ++++-------- src/objectives.jl | 2 +- src/objectives/elbo.jl | 2 +- src/objectives/loglikelihood.jl | 7 +- src/optimize.jl | 122 ++++++++++++++++++++++++++++++++ 5 files changed, 142 insertions(+), 37 deletions(-) create mode 100644 src/optimize.jl diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 16efb899..89e0ce3c 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -4,14 +4,12 @@ using Bijectors using Optimisers using LinearAlgebra, Random, Distributions, StatsBase using ProgressMeter -using ADTypes, DiffResults +using ADTypes +using DifferentiationInterface using DocStringExtensions -export train_flow, elbo, loglikelihood, value_and_gradient! - -using ADTypes -using DiffResults +export train_flow, elbo, loglikelihood """ train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) @@ -56,47 +54,29 @@ function train_flow( # use FunctionChains instead of simple compositions to construct the flow when many flow layers are involved # otherwise the compilation time for destructure will be too long θ_flat, re = Optimisers.destructure(flow) + + loss(θ, rng, args...) = -vo(rng, re(θ), args...) # Normalizing flow training loop - θ_flat_trained, opt_stats, st = optimize( - rng, + θ_flat_trained, opt_stats, st, time_elapsed = optimize( ADbackend, - vo, + loss, θ_flat, - re, - args...; + re, + (rng, args...)...; max_iters=max_iters, optimiser=optimiser, kwargs..., ) flow_trained = re(θ_flat_trained) - return flow_trained, opt_stats, st + return flow_trained, opt_stats, st, time_elapsed end -include("train.jl") + + +include("optimize.jl") include("objectives.jl") -# optional dependencies -if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base - using Requires -end -# Question: should Exts be loaded here or in train.jl? 
-function __init__() - @static if !isdefined(Base, :get_extension) - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( - "../ext/NormalizingFlowsForwardDiffExt.jl" - ) - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( - "../ext/NormalizingFlowsReverseDiffExt.jl" - ) - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include( - "../ext/NormalizingFlowsEnzymeExt.jl" - ) - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include( - "../ext/NormalizingFlowsZygoteExt.jl" - ) - end -end end diff --git a/src/objectives.jl b/src/objectives.jl index ddf129bf..e5df463b 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -1,2 +1,2 @@ include("objectives/elbo.jl") -include("objectives/loglikelihood.jl") \ No newline at end of file +include("objectives/loglikelihood.jl") # not tested diff --git a/src/objectives/elbo.jl b/src/objectives/elbo.jl index 68545b54..2751ed90 100644 --- a/src/objectives/elbo.jl +++ b/src/objectives/elbo.jl @@ -42,4 +42,4 @@ end function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples) return elbo(Random.default_rng(), flow, logp, n_samples) -end \ No newline at end of file +end diff --git a/src/objectives/loglikelihood.jl b/src/objectives/loglikelihood.jl index 4097ae15..861cd1bc 100644 --- a/src/objectives/loglikelihood.jl +++ b/src/objectives/loglikelihood.jl @@ -2,12 +2,13 @@ # training by minimizing forward KL (MLE) #################################### """ - loglikelihood(flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat) + loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat) Compute the log-likelihood for variational distribution flow at a batch of samples xs from the target distribution p. # Arguments +- `rng`: random number generator (empty argument, only needed to ensure the same signature as other variational objectives) - `flow`: variational distribution to be trained. In particular "flow = transformed(q₀, T::Bijectors.Bijector)", q₀ is a reference distribution that one can easily sample and compute logpdf @@ -15,6 +16,7 @@ the target distribution p. """ function loglikelihood( + rng::AbstractRNG, # empty argument flow::Bijectors.UnivariateTransformed, # variational distribution to be trained xs::AbstractVector, # sample batch from target dist p ) @@ -22,9 +24,10 @@ function loglikelihood( end function loglikelihood( + rng::AbstractRNG, # empty argument flow::Bijectors.MultivariateTransformed, # variational distribution to be trained xs::AbstractMatrix, # sample batch from target dist p ) llhs = map(x -> logpdf(flow, x), eachcol(xs)) return mean(llhs) -end \ No newline at end of file +end diff --git a/src/optimize.jl b/src/optimize.jl new file mode 100644 index 00000000..a4c02370 --- /dev/null +++ b/src/optimize.jl @@ -0,0 +1,122 @@ +####################################################### +# training loop for variational objectives +####################################################### +function pm_next!(pm, stats::NamedTuple) + return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end + +_wrap_in_DI_context(args...) = DifferentiationInterface.Constant.([args...]) + +function _prepare_gradient(loss, adbackend, θ, args...) + if isempty(args...) + return DifferentiationInterface.prepare_gradient(loss, adbackend, θ) + end + return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, _wrap_in_DI_context(args)...) +end + +function _value_and_gradient(loss, prep, adbackend, θ, args...) + if isempty(args...) 
+        return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
+    end
+    return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, _wrap_in_DI_context(args)...)
+end
+
+
+"""
+    optimize(
+        ad::ADTypes.AbstractADType,
+        loss,
+        θ₀::AbstractVector{T},
+        re,
+        args...;
+        kwargs...
+    )
+
+Iteratively updates the parameters `θ` of the normalizing flow `re(θ)` by computing the gradient of `loss`
+ and using the given `optimiser` to compute the steps.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: automatic differentiation backend
+- `loss`: a general loss function θ -> loss(θ, args...) returning a scalar loss value that will be minimised
+- `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, it will be the flattened flow parameters)
+- `re`: reconstruction function that maps the flattened parameters to the normalizing flow
+- `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant)
+
+
+# Keyword Arguments
+- `max_iters::Int=10000`: maximum number of iterations
+- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
+- `show_progress::Bool=true`: whether to show the progress bar. The default
+    information printed in the progress bar is the iteration number, the loss value,
+    and the gradient norm.
+- `callback=nothing`: callback function with signature `cb(iter, opt_state, re, θ)`
+    which returns a dictionary-like object of statistics to be displayed in the progress bar.
+    `re` and `θ` are provided so that the user can reconstruct the normalizing flow
+    and further examine its status during training.
+- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the
+    training has converged. The default is to always return false.
+- `prog=ProgressMeter.Progress(
+        max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+    )`: progress bar configuration
+
+# Returns
+- `θ`: trained parameters of the normalizing flow
+- `opt_stats`: statistics of the optimiser
+- `st`: optimiser state for potential continuation of training (the elapsed training time is also returned as a fourth value)
+"""
+function optimize(
+    adbackend,
+    loss::Function,
+    θ₀::AbstractVector{<:Real},
+    reconstruct::Function,
+    args...;
+    max_iters::Int=10000,
+    optimiser::Optimisers.AbstractRule=Optimisers.ADAM(),
+    show_progress::Bool=true,
+    callback=nothing,
+    hasconverged=(i, stats, re, θ, st) -> false,
+    prog=ProgressMeter.Progress(
+        max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress
+    ),
+)
+    time_elapsed = @elapsed begin
+        opt_stats = []
+
+        # prepare loss and autograd
+        θ = copy(θ₀)
+        # grad = similar(θ)
+        prep = _prepare_gradient(loss, adbackend, θ₀, args...)
+
+
+        # initialise optimiser state
+        st = Optimisers.setup(optimiser, θ)
+
+        # general `hasconverged(...)` approach to allow early termination.
+        converged = false
+        i = 1
+        while (i ≤ max_iters) && !converged
+            # ls, g = DifferentiationInterface.value_and_gradient!(loss, grad, prep, adbackend, θ)
+            ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...)
+
+            # Save stats
+            stat = (iteration=i, loss=ls, gradient_norm=norm(g))
+
+            # callback
+            if !isnothing(callback)
+                new_stat = callback(i, opt_stats, reconstruct, θ)
+                stat = !isnothing(new_stat) ?
merge(stat, new_stat) : stat + end + push!(opt_stats, stat) + + # update optimiser state and parameters + st, θ = Optimisers.update!(st, θ, g) + + # check convergence + i += 1 + converged = hasconverged(i, stat, reconstruct, θ, st) + pm_next!(prog, stat) + end + end + # return status of the optimiser for potential continuation of training + return θ, map(identity, opt_stats), st, time_elapsed +end From 7b4fb85d3f286a868fa5f89aa6aae19ea515f7bb Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:13:03 -0800 Subject: [PATCH 02/43] rename train.jl to optimize.jl --- src/train.jl | 164 --------------------------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 src/train.jl diff --git a/src/train.jl b/src/train.jl deleted file mode 100644 index 3a286350..00000000 --- a/src/train.jl +++ /dev/null @@ -1,164 +0,0 @@ -""" - value_and_gradient!( - ad::ADTypes.AbstractADType, - f, - θ::AbstractVector{T}, - out::DiffResults.MutableDiffResult - ) where {T<:Real} - -Compute the value and gradient of a function `f` at `θ` using the automatic -differentiation backend `ad`. The result is stored in `out`. -The function `f` must return a scalar value. The gradient is stored in `out` as a -vector of the same length as `θ`. -""" -function value_and_gradient! end - -""" - grad!( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ_flat::AbstractVector{<:Real}, - reconstruct, - out::DiffResults.MutableDiffResult, - args... - ) - -Compute the value and gradient for negation of the variational objective `vo` -at `θ_flat` using the automatic differentiation backend `ad`. - -Default implementation is provided for `ad` where `ad` is one of `AutoZygote`, -`AutoForwardDiff`, `AutoReverseDiff` (with no compiled tape), and `AutoEnzyme`. -The result is stored in `out`. - -# Arguments -- `rng::AbstractRNG`: random number generator -- `ad::ADTypes.AbstractADType`: automatic differentiation backend, currently supports - `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`. -- `vo`: variational objective -- `θ_flat::AbstractVector{<:Real}`: flattened parameters of the normalizing flow -- `reconstruct`: function that reconstructs the normalizing flow from the flattened parameters -- `out::DiffResults.MutableDiffResult`: mutable diff result to store the value and gradient -- `args...`: additional arguments for `vo` -""" -function grad!( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ_flat::AbstractVector{<:Real}, - reconstruct, - out::DiffResults.MutableDiffResult, - args...; -) - # define opt loss function - loss(θ_) = -vo(rng, reconstruct(θ_), args...) - # compute loss value and gradient - out = value_and_gradient!(ad, loss, θ_flat, out) - return out -end - -####################################################### -# training loop for variational objectives -####################################################### -function pm_next!(pm, stats::NamedTuple) - return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) -end - -""" - optimize( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ₀::AbstractVector{T}, - re, - args...; - kwargs... - ) - -Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by calling `grad!` - and using the given `optimiser` to compute the steps. 
- -# Arguments -- `rng::AbstractRNG`: random number generator -- `ad::ADTypes.AbstractADType`: automatic differentiation backend -- `vo`: variational objective -- `θ₀::AbstractVector{T}`: initial parameters of the normalizing flow -- `re`: function that reconstructs the normalizing flow from the flattened parameters -- `args...`: additional arguments for `vo` - - -# Keyword Arguments -- `max_iters::Int=10000`: maximum number of iterations -- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps -- `show_progress::Bool=true`: whether to show the progress bar. The default - information printed in the progress bar is the iteration number, the loss value, - and the gradient norm. -- `callback=nothing`: callback function with signature `cb(iter, opt_state, re, θ)` - which returns a dictionary-like object of statistics to be displayed in the progress bar. - re and θ are used for reconstructing the normalizing flow in case that user - want to further axamine the status of the flow. -- `hasconverged = (iter, opt_stats, re, θ, st) -> false`: function that checks whether the - training has converged. The default is to always return false. -- `prog=ProgressMeter.Progress( - max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress - )`: progress bar configuration - -# Returns -- `θ`: trained parameters of the normalizing flow -- `opt_stats`: statistics of the optimiser -- `st`: optimiser state for potential continuation of training -""" -function optimize( - rng::AbstractRNG, - ad::ADTypes.AbstractADType, - vo, - θ₀::AbstractVector{<:Real}, - re, - args...; - max_iters::Int=10000, - optimiser::Optimisers.AbstractRule=Optimisers.ADAM(), - show_progress::Bool=true, - callback=nothing, - hasconverged=(i, stats, re, θ, st) -> false, - prog=ProgressMeter.Progress( - max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress - ), -) - opt_stats = [] - - θ = copy(θ₀) - diff_result = DiffResults.GradientResult(θ) - # initialise optimiser state - st = Optimisers.setup(optimiser, θ) - - # general `hasconverged(...)` approach to allow early termination. - converged = false - i = 1 - time_elapsed = @elapsed while (i ≤ max_iters) && !converged - # Compute gradient and objective value; results are stored in `diff_results` - grad!(rng, ad, vo, θ, re, diff_result, args...) - - # Save stats - ls = DiffResults.value(diff_result) - g = DiffResults.gradient(diff_result) - stat = (iteration=i, loss=ls, gradient_norm=norm(g)) - push!(opt_stats, stat) - - # callback - if !isnothing(callback) - new_stat = callback(i, opt_stats, re, θ) - stat = !isnothing(new_stat) ? 
merge(new_stat, stat) : stat - end - - # update optimiser state and parameters - st, θ = Optimisers.update!(st, θ, DiffResults.gradient(diff_result)) - - # check convergence - i += 1 - converged = hasconverged(i, stat, re, θ, st) - pm_next!(prog, stat) - end - - # return status of the optimiser for potential continuation of training - return θ, map(identity, opt_stats), st -end From 5ff60419d22a427662d4881e5525d7fdabdf0ea6 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:13:30 -0800 Subject: [PATCH 03/43] fix some compat issue and bump version --- Project.toml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 376200ca..7a55f041 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,11 @@ name = "NormalizingFlows" uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256" -version = "0.1.1" +version = "0.1.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,13 +16,12 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -NormalizingFlowsEnzymeExt = "Enzyme" NormalizingFlowsForwardDiffExt = "ForwardDiff" NormalizingFlowsReverseDiffExt = "ReverseDiff" NormalizingFlowsZygoteExt = "Zygote" @@ -30,10 +29,10 @@ NormalizingFlowsZygoteExt = "Zygote" [compat] ADTypes = "0.1, 0.2, 1" Bijectors = "0.12.6, 0.13, 0.14" -DiffResults = "1" +DifferentiationInterface = "0.6" Distributions = "0.25" DocStringExtensions = "0.9" -Enzyme = "0.11, 0.12" +Mooncake = "0.4.95" ForwardDiff = "0.10.25" Optimisers = "0.2.16, 0.3" ProgressMeter = "1.0.0" @@ -44,7 +43,6 @@ Zygote = "0.6" julia = "1.6" [extras] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From 3010669f350a2c8d84da40ad5aba926079a1a1c8 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:14:06 -0800 Subject: [PATCH 04/43] update tests to new interface --- test/Project.toml | 4 ++-- test/ad.jl | 29 +++++++++++++++-------------- test/interface.jl | 6 +++--- test/objectives.jl | 7 ++++--- test/runtests.jl | 5 +++-- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index c474adda..f2a03298 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,11 +1,11 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = 
"9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/ad.jl b/test/ad.jl index a394d806..fc4c9b0f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,8 +1,10 @@ -@testset "AD correctness" begin - f(x) = sum(abs2, x) +@testset "DI.AD with context wrapper" begin + f(x, y, z) = sum(abs2, x .+ y .+ z) @testset "$T" for T in [Float32, Float64] x = randn(T, 10) + y = randn(T, 10) + z = randn(T, 10) chunksize = size(x, 1) @testset "$at" for at in [ @@ -10,12 +12,11 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(false), - ADTypes.AutoEnzyme(), + ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()), ] - out = DiffResults.GradientResult(x) - NormalizingFlows.value_and_gradient!(at, f, x, out) - @test DiffResults.value(out) ≈ f(x) - @test DiffResults.gradient(out) ≈ 2x + value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z) + @test DiffResults.value(out) ≈ f(x, y, z) + @test DiffResults.gradient(out) ≈ 2 * (x .+ y .+ z) end end end @@ -25,7 +26,7 @@ end ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(false), - # ADTypes.AutoEnzyme(), # not working now + ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -38,15 +39,15 @@ end sample_per_iter = 10 θ, re = Optimisers.destructure(flow) - out = DiffResults.GradientResult(θ) # check grad computation for elbo - NormalizingFlows.grad!( - Random.default_rng(), at, elbo, θ, re, out, logp, sample_per_iter + loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...) + value, grad = NormalizingFlows._value_and_gradient( + loss, at, θ, logp, randn(T, 2, sample_per_iter) ) - @test DiffResults.value(out) != nothing - @test all(DiffResults.gradient(out) .!= nothing) + @test !isnothing(value) + @test all(grad .!= nothing) end end -end \ No newline at end of file +end diff --git a/test/interface.jl b/test/interface.jl index a3540979..b630b5b6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -4,8 +4,8 @@ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), - # ADTypes.AutoEnzyme(), # doesn't work for Enzyme + ADTypes.AutoReverseDiff(), + ADTypes.AutoMooncake(; config = ADTypes.Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -44,4 +44,4 @@ @test el_trained > -1 end end -end \ No newline at end of file +end diff --git a/test/objectives.jl b/test/objectives.jl index 072286d1..4641b3cd 100644 --- a/test/objectives.jl +++ b/test/objectives.jl @@ -9,9 +9,10 @@ flow = Bijectors.transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(sqrt.(Σ))) x = randn(T, 2) + rng = Random.default_rng() @testset "elbo" begin - el = elbo(Random.default_rng(), flow, logp, 10) + el = elbo(rng, flow, logp, 10) @test abs(el) ≤ 1e-5 @test logpdf(flow, x) + el ≈ logp(x) @@ -20,8 +21,8 @@ @testset "likelihood" begin sample_trained = rand(flow, 1000) sample_untrained = rand(q₀, 1000) - llh_trained = NormalizingFlows.loglikelihood(flow, sample_trained) - llh_untrained = NormalizingFlows.loglikelihood(flow, sample_untrained) + llh_trained = NormalizingFlows.loglikelihood(rng, flow, sample_trained) + llh_untrained = NormalizingFlows.loglikelihood(rng, flow, sample_untrained) @test llh_trained > llh_untrained end diff --git a/test/runtests.jl b/test/runtests.jl index e050a645..05957180 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,10 +3,11 @@ using 
Distributions using Bijectors, Optimisers using LinearAlgebra using Random -using ADTypes, DiffResults +using ADTypes +import DifferentiationInterface as DI using ForwardDiff, Zygote, Enzyme, ReverseDiff using Test include("ad.jl") include("objectives.jl") -include("interface.jl") \ No newline at end of file +include("interface.jl") From f7ee84b6862f1a6bf686e979559e1dc68f87c6f9 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:45:13 -0800 Subject: [PATCH 05/43] add Moonkcake to extras --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 7a55f041..edbb2f0d 100644 --- a/Project.toml +++ b/Project.toml @@ -43,6 +43,7 @@ Zygote = "0.6" julia = "1.6" [extras] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From e68ef5f67e9ae2403fa6e8bc67081d58f242bd88 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:46:38 -0800 Subject: [PATCH 06/43] rm all ext for now --- Project.toml | 3 --- ext/NormalizingFlowsEnzymeExt.jl | 25 ------------------------ ext/NormalizingFlowsForwardDiffExt.jl | 28 --------------------------- ext/NormalizingFlowsReverseDiffExt.jl | 22 --------------------- ext/NormalizingFlowsZygoteExt.jl | 23 ---------------------- src/NormalizingFlows.jl | 3 +-- 6 files changed, 1 insertion(+), 103 deletions(-) delete mode 100644 ext/NormalizingFlowsEnzymeExt.jl delete mode 100644 ext/NormalizingFlowsForwardDiffExt.jl delete mode 100644 ext/NormalizingFlowsReverseDiffExt.jl delete mode 100644 ext/NormalizingFlowsZygoteExt.jl diff --git a/Project.toml b/Project.toml index edbb2f0d..f831e3e0 100644 --- a/Project.toml +++ b/Project.toml @@ -22,9 +22,6 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -NormalizingFlowsForwardDiffExt = "ForwardDiff" -NormalizingFlowsReverseDiffExt = "ReverseDiff" -NormalizingFlowsZygoteExt = "Zygote" [compat] ADTypes = "0.1, 0.2, 1" diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl deleted file mode 100644 index 1b59cad8..00000000 --- a/ext/NormalizingFlowsEnzymeExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module NormalizingFlowsEnzymeExt - -if isdefined(Base, :get_extension) - using Enzyme - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults -else - using ..Enzyme - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults -end - -# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - y = f(θ) - DiffResults.value!(out, y) - ∇θ = DiffResults.gradient(out) - fill!(∇θ, zero(T)) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) - return out -end - -end \ No newline at end of file diff --git a/ext/NormalizingFlowsForwardDiffExt.jl b/ext/NormalizingFlowsForwardDiffExt.jl deleted file mode 100644 index 500d54f4..00000000 --- a/ext/NormalizingFlowsForwardDiffExt.jl +++ /dev/null @@ -1,28 +0,0 @@ -module NormalizingFlowsForwardDiffExt - -if isdefined(Base, :get_extension) - using ForwardDiff - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults -else - using ..ForwardDiff - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults -end - -# extract chunk size 
from AutoForwardDiff -getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - chunk_size = getchunksize(ad) - config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, θ) - else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) - end - ForwardDiff.gradient!(out, f, θ, config) - return out -end - -end \ No newline at end of file diff --git a/ext/NormalizingFlowsReverseDiffExt.jl b/ext/NormalizingFlowsReverseDiffExt.jl deleted file mode 100644 index 1bd39dc4..00000000 --- a/ext/NormalizingFlowsReverseDiffExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module NormalizingFlowsReverseDiffExt - -if isdefined(Base, :get_extension) - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults - using ReverseDiff -else - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults - using ..ReverseDiff -end - -# ReverseDiff without compiled tape -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - tp = ReverseDiff.GradientTape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out -end - -end \ No newline at end of file diff --git a/ext/NormalizingFlowsZygoteExt.jl b/ext/NormalizingFlowsZygoteExt.jl deleted file mode 100644 index 0eee943c..00000000 --- a/ext/NormalizingFlowsZygoteExt.jl +++ /dev/null @@ -1,23 +0,0 @@ -module NormalizingFlowsZygoteExt - -if isdefined(Base, :get_extension) - using NormalizingFlows - using NormalizingFlows: ADTypes, DiffResults - using Zygote -else - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DiffResults - using ..Zygote -end - -function NormalizingFlows.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} - y, back = Zygote.pullback(f, θ) - ∇θ = back(one(T)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, first(∇θ)) - return out -end - -end \ No newline at end of file diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 89e0ce3c..40999948 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -11,8 +11,7 @@ using DocStringExtensions export train_flow, elbo, loglikelihood -""" - train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) +""" train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) Train the given normalizing flow `flow` by calling `optimize`. 
From 9bdf1f75927035ee379e9c21761a7742c8b30320 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:50:38 -0800 Subject: [PATCH 07/43] rm enzyme test, and import mooncake for test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 05957180..8a611158 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using LinearAlgebra using Random using ADTypes import DifferentiationInterface as DI -using ForwardDiff, Zygote, Enzyme, ReverseDiff +using ForwardDiff, Zygote, ReverseDiff, Mooncake using Test include("ad.jl") From b7f9f087869fc317e9cf7bd0e3c673be545b5bd1 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 15:57:38 -0800 Subject: [PATCH 08/43] fixing compat and test with mooncake --- Project.toml | 6 +++--- test/ad.jl | 4 ++-- test/interface.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index f831e3e0..ccaa178f 100644 --- a/Project.toml +++ b/Project.toml @@ -25,18 +25,18 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.1, 0.2, 1" -Bijectors = "0.12.6, 0.13, 0.14" +Bijectors = "0.12.6, 0.13, 0.14, 0.15" DifferentiationInterface = "0.6" Distributions = "0.25" DocStringExtensions = "0.9" Mooncake = "0.4.95" ForwardDiff = "0.10.25" -Optimisers = "0.2.16, 0.3" +Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.0.0" Requires = "1" ReverseDiff = "1.14" StatsBase = "0.33, 0.34" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.6" [extras] diff --git a/test/ad.jl b/test/ad.jl index fc4c9b0f..cc15c933 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -12,7 +12,7 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(false), - ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()), + ADTypes.AutoMooncake(; config=Mooncake.Config()), ] value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z) @test DiffResults.value(out) ≈ f(x, y, z) @@ -26,7 +26,7 @@ end ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(false), - ADTypes.AutoMooncake(; config=ADTypes.Mooncake.Config()), + ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) diff --git a/test/interface.jl b/test/interface.jl index b630b5b6..d53b1ab2 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,7 +5,7 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), - ADTypes.AutoMooncake(; config = ADTypes.Mooncake.Config()), + ADTypes.AutoMooncake(; config = Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) From 1970b09a4ed835e711e695b1b502be08f2c58b7f Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 16:02:54 -0800 Subject: [PATCH 09/43] fixing test bug --- Project.toml | 2 +- test/ad.jl | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ccaa178f..ed5d283c 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ Requires = "1" ReverseDiff = "1.14" StatsBase = "0.33, 0.34" Zygote = "0.6, 0.7" -julia = "1.6" +julia = "1.10" [extras] Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" diff --git a/test/ad.jl b/test/ad.jl index cc15c933..6048f860 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -14,7 +14,8 @@ ADTypes.AutoReverseDiff(false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] - value, grad = NormalizingFlows._value_and_gradient(f, at, x, y, z) + prep = 
NormalizingFlows._prepare_gradient(f, at, x, y, z) + value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) @test DiffResults.value(out) ≈ f(x, y, z) @test DiffResults.gradient(out) ≈ 2 * (x .+ y .+ z) end @@ -42,8 +43,9 @@ end # check grad computation for elbo loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...) + prep = NormalizingFlows._prepare_gradient(loss, at, θ, logp, randn(T, 2, sample_per_iter)) value, grad = NormalizingFlows._value_and_gradient( - loss, at, θ, logp, randn(T, 2, sample_per_iter) + loss, prep, at, θ, logp, randn(T, 2, sample_per_iter) ) @test !isnothing(value) From b0390f7bed9cb3c1e18f222be98a8878056d2ecc Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 16:21:01 -0800 Subject: [PATCH 10/43] fix _value_and_grad wrapper bug --- src/optimize.jl | 6 +++--- test/Project.toml | 1 + test/ad.jl | 11 ++++++++--- test/runtests.jl | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index a4c02370..efceaa15 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -5,17 +5,17 @@ function pm_next!(pm, stats::NamedTuple) return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) end -_wrap_in_DI_context(args...) = DifferentiationInterface.Constant.([args...]) +_wrap_in_DI_context(args) = DifferentiationInterface.Constant.([args...]) function _prepare_gradient(loss, adbackend, θ, args...) - if isempty(args...) + if isempty(args) return DifferentiationInterface.prepare_gradient(loss, adbackend, θ) end return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, _wrap_in_DI_context(args)...) end function _value_and_gradient(loss, prep, adbackend, θ, args...) - if isempty(args...) + if isempty(args) return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ) end return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, _wrap_in_DI_context(args)...) 
diff --git a/test/Project.toml b/test/Project.toml index f2a03298..6a443ccf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/ad.jl b/test/ad.jl index 6048f860..540dfa20 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,6 @@ @testset "DI.AD with context wrapper" begin f(x, y, z) = sum(abs2, x .+ y .+ z) + T = Float32 @testset "$T" for T in [Float32, Float64] x = randn(T, 10) @@ -11,9 +12,10 @@ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), + ADTypes.AutoReverseDiff(; false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] + at = ADTypes.AutoMooncake(; config=Mooncake.Config()) prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) @test DiffResults.value(out) ≈ f(x, y, z) @@ -26,7 +28,7 @@ end @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(false), + ADTypes.AutoReverseDiff(; false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] @@ -34,7 +36,10 @@ end Σ = Diagonal(4 * ones(T, 2)) target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) - + + # necessary for Zygote/mooncake to differentiate through the flow + # prevent opt q0 + @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ))) diff --git a/test/runtests.jl b/test/runtests.jl index 8a611158..8058bbfb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using Bijectors, Optimisers using LinearAlgebra using Random using ADTypes -import DifferentiationInterface as DI +# import DifferentiationInterface as DI using ForwardDiff, Zygote, ReverseDiff, Mooncake using Test From 3b40c3329a2f1ee8eb69332b686cbfbdc04d4816 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 16:25:32 -0800 Subject: [PATCH 11/43] fix AutoReverseDiff argument typo --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 540dfa20..e4e07c94 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -12,7 +12,7 @@ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(; false), + ADTypes.AutoReverseDiff(; compile=false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] at = ADTypes.AutoMooncake(; config=Mooncake.Config()) @@ -28,7 +28,7 @@ end @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(; false), + ADTypes.AutoReverseDiff(; compile = false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] From 9552de146044527a54ae4031c7fc22574dd99dc1 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 17:01:48 -0800 Subject: [PATCH 12/43] minor ed --- test/ad.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index e4e07c94..04c0bf4a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -15,11 +15,11 @@ ADTypes.AutoReverseDiff(; 
compile=false), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] - at = ADTypes.AutoMooncake(; config=Mooncake.Config()) + # at = ADTypes.AutoMooncake(; config=Mooncake.Config()) prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) - @test DiffResults.value(out) ≈ f(x, y, z) - @test DiffResults.gradient(out) ≈ 2 * (x .+ y .+ z) + @test value ≈ f(x, y, z) + @test grad ≈ 2 * (x .+ y .+ z) end end end From 104e5dddb31d4d9c44779d9f792de5594718b403 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 17:06:50 -0800 Subject: [PATCH 13/43] minor ed --- test/Project.toml | 1 + test/runtests.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 6a443ccf..8d763c17 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" diff --git a/test/runtests.jl b/test/runtests.jl index 8058bbfb..8471efca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Bijectors, Optimisers using LinearAlgebra using Random using ADTypes +using Functors # import DifferentiationInterface as DI using ForwardDiff, Zygote, ReverseDiff, Mooncake using Test From d1ec8347d621fddc59dbef2fb8af74aa8b065ed1 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 17:24:39 -0800 Subject: [PATCH 14/43] fixing test --- src/optimize.jl | 4 ++-- test/interface.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index efceaa15..7a87a344 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -66,9 +66,9 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal """ function optimize( adbackend, - loss::Function, + loss, θ₀::AbstractVector{<:Real}, - reconstruct::Function, + reconstruct, args...; max_iters::Int=10000, optimiser::Optimisers.AbstractRule=Optimisers.ADAM(), diff --git a/test/interface.jl b/test/interface.jl index d53b1ab2..7611c6da 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -13,6 +13,7 @@ target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) + @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed( q₀, Bijectors.Shift(zero.(μ)) ∘ Bijectors.Scale(ones(T, 2)) @@ -21,7 +22,7 @@ sample_per_iter = 10 cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3 - flow_trained, stats, _ = train_flow( + flow_trained, stats, _, _ = train_flow( elbo, flow, logp, From deba738a6265ebc0befff46d90ced674413e8e23 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 19:16:07 -0800 Subject: [PATCH 15/43] minor ed --- test/interface.jl | 8 ++++++-- test/runtests.jl | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 7611c6da..3cf6511f 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -8,6 +8,10 @@ ADTypes.AutoMooncake(; config = Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] + # adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) + # T = Float32 + + Random.seed!(1234) μ 
= 10 * ones(T, 2) Σ = Diagonal(4 * ones(T, 2)) target = MvNormal(μ, Σ) @@ -28,9 +32,9 @@ logp, sample_per_iter; max_iters=5_000, - optimiser=Optimisers.ADAM(0.01 * one(T)), + optimiser=Optimisers.Adam(0.01 * one(T)), ADbackend=adtype, - show_progress=false, + show_progress=true, callback=cb, hasconverged=checkconv, ) diff --git a/test/runtests.jl b/test/runtests.jl index 8471efca..b1131002 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,6 @@ using LinearAlgebra using Random using ADTypes using Functors -# import DifferentiationInterface as DI using ForwardDiff, Zygote, ReverseDiff, Mooncake using Test From 8976307b92e19afa2a4f14ee780491ab1b87a81e Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 19:44:35 -0800 Subject: [PATCH 16/43] rm test for mooncake --- test/interface.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 3cf6511f..3d6c60fa 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,27 +5,24 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), - ADTypes.AutoMooncake(; config = Mooncake.Config()), + # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 ] @testset "$T" for T in [Float32, Float64] - # adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) - # T = Float32 - - Random.seed!(1234) μ = 10 * ones(T, 2) Σ = Diagonal(4 * ones(T, 2)) + target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed( - q₀, Bijectors.Shift(zero.(μ)) ∘ Bijectors.Scale(ones(T, 2)) + q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)) ) sample_per_iter = 10 - cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) - checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3 + cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) + checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 flow_trained, stats, _, _ = train_flow( elbo, flow, @@ -34,7 +31,7 @@ max_iters=5_000, optimiser=Optimisers.Adam(0.01 * one(T)), ADbackend=adtype, - show_progress=true, + show_progress=false, callback=cb, hasconverged=checkconv, ) From cb3db4a1f32f878f4b97a97a46737ddf54901970 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 20:15:50 -0800 Subject: [PATCH 17/43] fix doc --- docs/src/api.md | 11 ++--------- docs/src/example.md | 7 +++++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f8028b91..8d386ae8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows: ```@julia using Distributions, Bijectors T= Float32 +@leaf MvNormal # to prevent params in q₀ from being optimized q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2))) ``` @@ -23,7 +24,7 @@ To train the Gaussian VI targeting at distirbution $p$ via ELBO maiximization, w using NormalizingFlows sample_per_iter = 10 -flow_trained, stats, _ = train_flow( +flow_trained, stats, _ , _ = train_flow( elbo, flow, logp, @@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood ```@docs NormalizingFlows.optimize ``` - - -## Utility Functions for Taking Gradient -```@docs -NormalizingFlows.grad! -NormalizingFlows.value_and_gradient! 
-``` - diff --git a/docs/src/example.md b/docs/src/example.md index 346c15a0..01a9a671 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a ```julia using Bijectors, FunctionChains +using Functors function create_planar_flow(n_layers::Int, q₀) d = length(q₀) @@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀) end # create a 20-layer planar flow -flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) +@leaf MvNormal # to prevent params in q₀ from being optimized +q₀ = MvNormal(zeros(Float32, 2), I) +flow = create_planar_flow(20, q₀) flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison ``` *Notice that here the flow layers are chained together using `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl). @@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel= ## Reference -- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning \ No newline at end of file +- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning From 1615009a5c03568d36b261d9ad499644ffdc95da Mon Sep 17 00:00:00 2001 From: Zuheng Date: Sun, 16 Feb 2025 20:16:43 -0800 Subject: [PATCH 18/43] chagne CI --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 812f9b73..91513651 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,7 @@ jobs: matrix: version: - '1' - - '1.6' + - '1.10' os: - ubuntu-latest arch: From 1166e6b94e2e6ed4c57228a8274ceb27b1bd4116 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 26 Feb 2025 13:01:05 +0000 Subject: [PATCH 19/43] update CI --- .github/workflows/CI.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 91513651..f21ba541 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,15 +1,18 @@ name: CI + on: push: branches: - main tags: ['*'] pull_request: + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. 
group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -19,17 +22,17 @@ jobs: matrix: version: - '1' - - '1.10' + - 'min' os: - ubuntu-latest arch: - x64 steps: - - uses: actions/checkout@v3 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 From 906d788e45571bfb686d1c1addf847087aba0dfc Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 26 Feb 2025 13:15:23 +0000 Subject: [PATCH 20/43] streamline project toml --- Project.toml | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 18259e22..8b7cf48b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NormalizingFlows" uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256" -version = "0.1.2" +version = "0.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -15,33 +15,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -[weakdeps] -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[extensions] - [compat] ADTypes = "0.1, 0.2, 1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" DifferentiationInterface = "0.6" Distributions = "0.25" DocStringExtensions = "0.9" -Mooncake = "0.4.95" -Enzyme = "0.11, 0.12, 0.13" -ForwardDiff = "0.10.25" Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.0.0" Requires = "1" -ReverseDiff = "1.14" StatsBase = "0.33, 0.34" -Zygote = "0.6, 0.7" julia = "1.10" - -[extras] -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From 0d6302b9a8e035009ba09ca99a13a331fc24a504 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:16:54 +0000 Subject: [PATCH 21/43] Apply suggestions from code review Co-authored-by: David Widmann --- src/NormalizingFlows.jl | 5 +++-- src/optimize.jl | 13 ++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 40999948..ddb43686 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -11,7 +11,8 @@ using DocStringExtensions export train_flow, elbo, loglikelihood -""" train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) +""" + train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) Train the given normalizing flow `flow` by calling `optimize`. @@ -61,7 +62,7 @@ function train_flow( ADbackend, loss, θ_flat, - re, + re, (rng, args...)...; max_iters=max_iters, optimiser=optimiser, diff --git a/src/optimize.jl b/src/optimize.jl index 7a87a344..eeae5ecb 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -2,23 +2,23 @@ # training loop for variational objectives ####################################################### function pm_next!(pm, stats::NamedTuple) - return ProgressMeter.next!(pm; showvalues=[tuple(s...) 
for s in pairs(stats)]) + return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats))) end -_wrap_in_DI_context(args) = DifferentiationInterface.Constant.([args...]) +_wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args) function _prepare_gradient(loss, adbackend, θ, args...) if isempty(args) return DifferentiationInterface.prepare_gradient(loss, adbackend, θ) end - return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, _wrap_in_DI_context(args)...) + return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...) end function _value_and_gradient(loss, prep, adbackend, θ, args...) if isempty(args) return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ) end - return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, _wrap_in_DI_context(args)...) + return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) end @@ -42,7 +42,6 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal - `re`: reconstruction function that maps the flattened parameters to the normalizing flow - `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant) - # Keyword Arguments - `max_iters::Int=10000`: maximum number of iterations - `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps @@ -102,9 +101,9 @@ function optimize( stat = (iteration=i, loss=ls, gradient_norm=norm(g)) # callback - if !isnothing(callback) + if callback !== nothing new_stat = callback(i, opt_stats, reconstruct, θ) - stat = !isnothing(new_stat) ? merge(stat, new_stat) : stat + stat = new_stat !== nothing ? 
merge(stat, new_stat) : stat end push!(opt_stats, stat) From 63321c03d3024bb458161285d37ea8fc3172affe Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 26 Feb 2025 13:37:39 +0000 Subject: [PATCH 22/43] add enzyme to tests --- Project.toml | 2 +- test/Project.toml | 1 + test/ad.jl | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8b7cf48b..d382c93e 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -ADTypes = "0.1, 0.2, 1" +ADTypes = "1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" DifferentiationInterface = "0.6" Distributions = "0.25" diff --git a/test/Project.toml b/test/Project.toml index 8d763c17..a6203532 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/ad.jl b/test/ad.jl index 04c0bf4a..9c8bd7d5 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -13,6 +13,7 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] # at = ADTypes.AutoMooncake(; config=Mooncake.Config()) @@ -29,6 +30,7 @@ end ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(; compile = false), + ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] From 7204391d8f87e69b03be61c33925335d03a4fc90 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 26 Feb 2025 13:44:59 +0000 Subject: [PATCH 23/43] add Enzyme to using list --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b1131002..069c7888 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using LinearAlgebra using Random using ADTypes using Functors -using ForwardDiff, Zygote, ReverseDiff, Mooncake +using ForwardDiff, Zygote, ReverseDiff, Enzyme, Mooncake using Test include("ad.jl") From 77b9a2ee3cb5c3a3dc38d14bdd886b52c0465529 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Mon, 3 Mar 2025 14:31:11 -0800 Subject: [PATCH 24/43] fixing enzyme readonly error by wrapping loss in Const --- src/NormalizingFlows.jl | 1 + src/optimize.jl | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index ddb43686..ee169267 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -6,6 +6,7 @@ using LinearAlgebra, Random, Distributions, StatsBase using ProgressMeter using ADTypes using DifferentiationInterface +using EnzymeCore using DocStringExtensions diff --git a/src/optimize.jl b/src/optimize.jl index eeae5ecb..f61e0b74 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -21,6 +21,21 @@ function _value_and_gradient(loss, prep, adbackend, θ, args...) 
return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) end +# TODO: move to ext +# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 +function _prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) + if isempty(args) + return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ) + end + return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ, map(DifferentiationInterface.Constant, args)...) +end +function _value_and_gradient(loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...) + if isempty(args) + return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ) + end + return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) +end + """ optimize( @@ -82,11 +97,10 @@ function optimize( opt_stats = [] # prepare loss and autograd - θ = copy(θ₀) + θ = deepcopy(θ₀) # grad = similar(θ) prep = _prepare_gradient(loss, adbackend, θ₀, args...) - # initialise optimiser state st = Optimisers.setup(optimiser, θ) From 9a8ed04f1c4639895597dec8d74c52aa64dd6c32 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Mon, 3 Mar 2025 14:48:41 -0800 Subject: [PATCH 25/43] mv enzyme related edits to ext/ and fix tests --- Project.toml | 6 ++++++ ext/NormalizingFlowsEnzymeExt.jl | 28 ++++++++++++++++++++++++++++ src/NormalizingFlows.jl | 16 +++++++++++++++- src/optimize.jl | 15 --------------- test/ad.jl | 18 ++++++++++-------- test/interface.jl | 3 ++- test/runtests.jl | 1 + 7 files changed, 62 insertions(+), 25 deletions(-) create mode 100644 ext/NormalizingFlowsEnzymeExt.jl diff --git a/Project.toml b/Project.toml index d382c93e..a27b8298 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[extensions] +NormalizingFlowsEnzymeExt = "Enzyme" + [compat] ADTypes = "1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl new file mode 100644 index 00000000..d4b8487f --- /dev/null +++ b/ext/NormalizingFlowsEnzymeExt.jl @@ -0,0 +1,28 @@ +module NormalizingFlowsEnzymeExt + +if isdefined(Base, :get_extension) + using EnzymeCore + using NormalizingFlows + using NormalizingFlows: ADTypes, DifferentiationInterface +else + using ..EnzymeCore + using ..NormalizingFlows + using ..NormalizingFlows: ADTypes, DifferentiationInterface +end + + +# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 +function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) + if isempty(args) + return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ) + end + return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ, map(DifferentiationInterface.Constant, args)...) +end +function NormalizingFlows._value_and_gradient(loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...) 
+ if isempty(args) + return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ) + end + return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) +end + +end diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index ee169267..73cc29a2 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -6,7 +6,6 @@ using LinearAlgebra, Random, Distributions, StatsBase using ProgressMeter using ADTypes using DifferentiationInterface -using EnzymeCore using DocStringExtensions @@ -80,4 +79,19 @@ include("optimize.jl") include("objectives.jl") +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end + +# Question: should Exts be loaded here or in train.jl? +function __init__() + @static if !isdefined(Base, :get_extension) + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include( + "../ext/NormalizingFlowsEnzymeExt.jl" + ) + end +end + + end diff --git a/src/optimize.jl b/src/optimize.jl index f61e0b74..d1ee80af 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -21,21 +21,6 @@ function _value_and_gradient(loss, prep, adbackend, θ, args...) return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) end -# TODO: move to ext -# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 -function _prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) - if isempty(args) - return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ) - end - return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ, map(DifferentiationInterface.Constant, args)...) -end -function _value_and_gradient(loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...) - if isempty(args) - return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ) - end - return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) -end - """ optimize( diff --git a/test/ad.jl b/test/ad.jl index 9c8bd7d5..49608bc1 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,5 @@ @testset "DI.AD with context wrapper" begin f(x, y, z) = sum(abs2, x .+ y .+ z) - T = Float32 @testset "$T" for T in [Float32, Float64] x = randn(T, 10) @@ -25,13 +24,13 @@ end end -@testset "AD for ELBO" begin +@testset "AD for ELBO on mean-field Gaussian VI" begin @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(; compile = false), ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), - ADTypes.AutoMooncake(; config=Mooncake.Config()), + # ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -45,17 +44,20 @@ end q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ))) - sample_per_iter = 10 θ, re = Optimisers.destructure(flow) # check grad computation for elbo - loss(θ, args...) = -NormalizingFlows.elbo(re(θ), args...) 
- prep = NormalizingFlows._prepare_gradient(loss, at, θ, logp, randn(T, 2, sample_per_iter)) + loss(θ, rng, logp, sample_per_iter) = -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter) + + rng = Random.default_rng() + sample_per_iter = 10 + + prep = NormalizingFlows._prepare_gradient(loss, at, θ, rng, logp, sample_per_iter) value, grad = NormalizingFlows._value_and_gradient( - loss, prep, at, θ, logp, randn(T, 2, sample_per_iter) + loss, prep, at, θ, rng, logp, sample_per_iter ) - @test !isnothing(value) + @test value !== nothing @test all(grad .!= nothing) end end diff --git a/test/interface.jl b/test/interface.jl index 3d6c60fa..025bfc14 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,6 +5,7 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), + ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 ] @testset "$T" for T in [Float32, Float64] @@ -29,7 +30,7 @@ logp, sample_per_iter; max_iters=5_000, - optimiser=Optimisers.Adam(0.01 * one(T)), + optimiser=Optimisers.Adam(one(T)/100), ADbackend=adtype, show_progress=false, callback=cb, diff --git a/test/runtests.jl b/test/runtests.jl index 069c7888..a2915e66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Random using ADTypes using Functors using ForwardDiff, Zygote, ReverseDiff, Enzyme, Mooncake +import DifferentiationInterface as DI using Test include("ad.jl") From b3487b58740b62f14c7ac46446c5de796196e33f Mon Sep 17 00:00:00 2001 From: Zuheng Date: Mon, 3 Mar 2025 16:30:11 -0800 Subject: [PATCH 26/43] fixing extension loading error --- Project.toml | 7 +++- ext/NormalizingFlowsEnzymeCoreExt.jl | 39 ++++++++++++++++++ ext/NormalizingFlowsEnzymeExt.jl | 28 ------------- src/NormalizingFlows.jl | 14 +++---- src/optimize.jl | 1 - test/Project.toml | 1 + test/ad.jl | 9 +++-- test/interface.jl | 60 +++++++++++++++++++++++++++- test/runtests.jl | 2 + 9 files changed, 117 insertions(+), 44 deletions(-) create mode 100644 ext/NormalizingFlowsEnzymeCoreExt.jl delete mode 100644 ext/NormalizingFlowsEnzymeExt.jl diff --git a/Project.toml b/Project.toml index a27b8298..e23098e1 100644 --- a/Project.toml +++ b/Project.toml @@ -16,10 +16,10 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] -NormalizingFlowsEnzymeExt = "Enzyme" +NormalizingFlowsEnzymeCoreExt = "EnzymeCore" [compat] ADTypes = "1" @@ -32,3 +32,6 @@ ProgressMeter = "1.0.0" Requires = "1" StatsBase = "0.33, 0.34" julia = "1.10" + +[extras] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/NormalizingFlowsEnzymeCoreExt.jl b/ext/NormalizingFlowsEnzymeCoreExt.jl new file mode 100644 index 00000000..f2a378ea --- /dev/null +++ b/ext/NormalizingFlowsEnzymeCoreExt.jl @@ -0,0 +1,39 @@ +module NormalizingFlowsEnzymeCoreExt + +using EnzymeCore +using NormalizingFlows +using NormalizingFlows: ADTypes, DifferentiationInterface + +# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 +function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) 
+ if isempty(args) + return DifferentiationInterface.prepare_gradient( + EnzymeCore.Const(loss), adbackend, θ + ) + end + return DifferentiationInterface.prepare_gradient( + EnzymeCore.Const(loss), + adbackend, + θ, + map(DifferentiationInterface.Constant, args)..., + ) +end + +function NormalizingFlows._value_and_gradient( + loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args... +) + if isempty(args) + return DifferentiationInterface.value_and_gradient( + EnzymeCore.Const(loss), prep, adbackend, θ + ) + end + return DifferentiationInterface.value_and_gradient( + EnzymeCore.Const(loss), + prep, + adbackend, + θ, + map(DifferentiationInterface.Constant, args)..., + ) +end + +end diff --git a/ext/NormalizingFlowsEnzymeExt.jl b/ext/NormalizingFlowsEnzymeExt.jl deleted file mode 100644 index d4b8487f..00000000 --- a/ext/NormalizingFlowsEnzymeExt.jl +++ /dev/null @@ -1,28 +0,0 @@ -module NormalizingFlowsEnzymeExt - -if isdefined(Base, :get_extension) - using EnzymeCore - using NormalizingFlows - using NormalizingFlows: ADTypes, DifferentiationInterface -else - using ..EnzymeCore - using ..NormalizingFlows - using ..NormalizingFlows: ADTypes, DifferentiationInterface -end - - -# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 -function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) - if isempty(args) - return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ) - end - return DifferentiationInterface.prepare_gradient(EnzymeCore.Const(loss), adbackend, θ, map(DifferentiationInterface.Constant, args)...) -end -function NormalizingFlows._value_and_gradient(loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...) - if isempty(args) - return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ) - end - return DifferentiationInterface.value_and_gradient(EnzymeCore.Const(loss), prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) -end - -end diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 73cc29a2..4c3f57db 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -79,19 +79,17 @@ include("optimize.jl") include("objectives.jl") -# optional dependencies -if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base +if !isdefined(Base, :get_extension) using Requires end -# Question: should Exts be loaded here or in train.jl? -function __init__() - @static if !isdefined(Base, :get_extension) - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include( - "../ext/NormalizingFlowsEnzymeExt.jl" + +@static if !isdefined(Base, :get_extension) + function __init__() + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( + joinpath(@__DIR__, "../ext/NormalizingFlowsEnzymeCoreExt.jl") ) end end - end diff --git a/src/optimize.jl b/src/optimize.jl index d1ee80af..932c03b5 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -93,7 +93,6 @@ function optimize( converged = false i = 1 while (i ≤ max_iters) && !converged - # ls, g = DifferentiationInterface.value_and_gradient!(loss, grad, prep, adbackend, θ) ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...) 
# Save stats diff --git a/test/Project.toml b/test/Project.toml index a6203532..66f38e3a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/ad.jl b/test/ad.jl index 49608bc1..d6e7b546 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -15,7 +15,6 @@ ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] - # at = ADTypes.AutoMooncake(; config=Mooncake.Config()) prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) value, grad = NormalizingFlows._value_and_gradient(f, prep, at, x, y, z) @test value ≈ f(x, y, z) @@ -39,11 +38,13 @@ end logp(z) = logpdf(target, z) # necessary for Zygote/mooncake to differentiate through the flow - # prevent opt q0 + # prevent updating params of q0 @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) - flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ))) - + flow = Bijectors.transformed( + q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)) + ) + θ, re = Optimisers.destructure(flow) # check grad computation for elbo diff --git a/test/interface.jl b/test/interface.jl index 025bfc14..2b860795 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,4 +1,4 @@ -@testset "learining 2d Gaussian" begin +@testset "testing mean-field Gaussian VI" begin chunksize = 4 @testset "$adtype" for adtype in [ ADTypes.AutoZygote(), @@ -48,3 +48,61 @@ end end end + +# function create_planar_flow(n_layers::Int, q₀, T) +# d = length(q₀) +# if T == Float32 +# Ls = reduce(∘, [f32(PlanarLayer(d)) for _ in 1:n_layers]) +# else +# Ls = reduce(∘, [PlanarLayer(d) for _ in 1:n_layers]) +# end +# return Bijectors.transformed(q₀, Ls) +# end + +# @testset "testing planar flow" begin +# chunksize = 4 +# @testset "$adtype" for adtype in [ +# ADTypes.AutoZygote(), +# ADTypes.AutoForwardDiff(; chunksize=chunksize), +# ADTypes.AutoForwardDiff(), +# ADTypes.AutoReverseDiff(), +# ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), +# # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 +# ] +# @testset "$T" for T in [Float32, Float64] +# μ = 10 * ones(T, 2) +# Σ = Diagonal(4 * ones(T, 2)) + +# target = MvNormal(μ, Σ) +# logp(z) = logpdf(target, z) + +# @leaf MvNormal +# q₀ = MvNormal(zeros(T, 2), ones(T, 2)) +# nlayers = 10 +# flow = create_planar_flow(nlayers, q₀, T) + +# sample_per_iter = 10 +# cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +# checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +# flow_trained, stats, _, _ = train_flow( +# elbo, +# flow, +# logp, +# sample_per_iter; +# max_iters=10_000, +# optimiser=Optimisers.Adam(one(T)/100), +# ADbackend=adtype, +# show_progress=false, +# callback=cb, +# hasconverged=checkconv, +# ) +# θ, re = Optimisers.destructure(flow_trained) + +# el_untrained = elbo(flow, logp, 1000) +# el_trained = elbo(flow_trained, logp, 1000) + +# @test el_trained > el_untrained +# @test el_trained > -1 +# end +# end +# end diff --git a/test/runtests.jl b/test/runtests.jl index a2915e66..33a98085 100644 --- a/test/runtests.jl 
+++ b/test/runtests.jl @@ -6,7 +6,9 @@ using Random using ADTypes using Functors using ForwardDiff, Zygote, ReverseDiff, Enzyme, Mooncake +using Flux: f32 import DifferentiationInterface as DI + using Test include("ad.jl") From 45756e691c0e09b21d9d57a9e9d29625f7909ac4 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Tue, 4 Mar 2025 09:51:08 -0800 Subject: [PATCH 27/43] Update Project.toml Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e23098e1..e45fc630 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ NormalizingFlowsEnzymeCoreExt = "EnzymeCore" [compat] ADTypes = "1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" -DifferentiationInterface = "0.6" +DifferentiationInterface = "0.6.42" Distributions = "0.25" DocStringExtensions = "0.9" Optimisers = "0.2.16, 0.3, 0.4" From dbe725ca6c69960c078c86435a7daa9e41739bc7 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Tue, 4 Mar 2025 09:56:59 -0800 Subject: [PATCH 28/43] remove Requires Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- src/NormalizingFlows.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 4c3f57db..a0075af7 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -79,9 +79,6 @@ include("optimize.jl") include("objectives.jl") -if !isdefined(Base, :get_extension) - using Requires -end @static if !isdefined(Base, :get_extension) From da8593c02a8ff5ca1d9ff726777cc440c1de66e0 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Tue, 4 Mar 2025 09:57:38 -0800 Subject: [PATCH 29/43] remove explit load ext Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- src/NormalizingFlows.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index a0075af7..4e38dece 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -81,12 +81,5 @@ include("objectives.jl") -@static if !isdefined(Base, :get_extension) - function __init__() - @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( - joinpath(@__DIR__, "../ext/NormalizingFlowsEnzymeCoreExt.jl") - ) - end -end end From da6699683ee0b617b080abbdde792367ad3cdbbb Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Tue, 4 Mar 2025 09:58:20 -0800 Subject: [PATCH 30/43] Update src/objectives/loglikelihood.jl Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- src/objectives/loglikelihood.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/loglikelihood.jl b/src/objectives/loglikelihood.jl index 861cd1bc..6ba24a26 100644 --- a/src/objectives/loglikelihood.jl +++ b/src/objectives/loglikelihood.jl @@ -16,7 +16,7 @@ the target distribution p. 
""" function loglikelihood( - rng::AbstractRNG, # empty argument + ::AbstractRNG, # empty argument flow::Bijectors.UnivariateTransformed, # variational distribution to be trained xs::AbstractVector, # sample batch from target dist p ) From 3e65cde201354e9b71d9ea2b167433553ed0be3a Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:39:44 -0800 Subject: [PATCH 31/43] make ext dep explicit --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e45fc630..5cf9dcc9 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] -NormalizingFlowsEnzymeCoreExt = "EnzymeCore" +NormalizingFlowsEnzymeCoreExt = ["EnzymeCore", "ADTypes", "DifferentiationInterface"] [compat] ADTypes = "1" From eeb9a92c6837d5a3e7121ba096f91837558a1ef0 Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:40:26 -0800 Subject: [PATCH 32/43] rm empty argument specialization for _prep_grad and _value_grad --- ext/NormalizingFlowsEnzymeCoreExt.jl | 10 ---------- src/optimize.jl | 6 ------ 2 files changed, 16 deletions(-) diff --git a/ext/NormalizingFlowsEnzymeCoreExt.jl b/ext/NormalizingFlowsEnzymeCoreExt.jl index f2a378ea..ee7a36ae 100644 --- a/ext/NormalizingFlowsEnzymeCoreExt.jl +++ b/ext/NormalizingFlowsEnzymeCoreExt.jl @@ -6,11 +6,6 @@ using NormalizingFlows: ADTypes, DifferentiationInterface # deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) - if isempty(args) - return DifferentiationInterface.prepare_gradient( - EnzymeCore.Const(loss), adbackend, θ - ) - end return DifferentiationInterface.prepare_gradient( EnzymeCore.Const(loss), adbackend, @@ -22,11 +17,6 @@ end function NormalizingFlows._value_and_gradient( loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args... ) - if isempty(args) - return DifferentiationInterface.value_and_gradient( - EnzymeCore.Const(loss), prep, adbackend, θ - ) - end return DifferentiationInterface.value_and_gradient( EnzymeCore.Const(loss), prep, diff --git a/src/optimize.jl b/src/optimize.jl index 932c03b5..bb284f62 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -8,16 +8,10 @@ end _wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args) function _prepare_gradient(loss, adbackend, θ, args...) - if isempty(args) - return DifferentiationInterface.prepare_gradient(loss, adbackend, θ) - end return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...) end function _value_and_gradient(loss, prep, adbackend, θ, args...) - if isempty(args) - return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ) - end return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) 
end From 4b97585057d95c63c7f8371b90421f7f976bb8ca Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:41:05 -0800 Subject: [PATCH 33/43] signal empty rng arg --- src/objectives.jl | 2 +- src/objectives/loglikelihood.jl | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/objectives.jl b/src/objectives.jl index e5df463b..1d7ac5a2 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -1,2 +1,2 @@ include("objectives/elbo.jl") -include("objectives/loglikelihood.jl") # not tested +include("objectives/loglikelihood.jl") # not fully tested diff --git a/src/objectives/loglikelihood.jl b/src/objectives/loglikelihood.jl index 6ba24a26..ab5eb961 100644 --- a/src/objectives/loglikelihood.jl +++ b/src/objectives/loglikelihood.jl @@ -16,7 +16,7 @@ the target distribution p. """ function loglikelihood( - ::AbstractRNG, # empty argument + ::AbstractRNG, # empty argument flow::Bijectors.UnivariateTransformed, # variational distribution to be trained xs::AbstractVector, # sample batch from target dist p ) @@ -24,10 +24,20 @@ function loglikelihood( end function loglikelihood( - rng::AbstractRNG, # empty argument + ::AbstractRNG, # empty argument flow::Bijectors.MultivariateTransformed, # variational distribution to be trained xs::AbstractMatrix, # sample batch from target dist p ) llhs = map(x -> logpdf(flow, x), eachcol(xs)) return mean(llhs) end + +## TODO:will need to implement the version that takes a dataloader +# function loglikelihood( +# rng::AbstractRNG, +# flow::Bijectors.TransformedDistribution, +# dataloader +# ) +# xs = dataloader(rng) +# return loglikelihood(rng, flow, collect(dataloader)) +# end From 907526607ab30fc1f1331f33548bfc84a0645c11 Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:41:27 -0800 Subject: [PATCH 34/43] drop Requires --- src/NormalizingFlows.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 4e38dece..d4179366 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -54,7 +54,7 @@ function train_flow( # use FunctionChains instead of simple compositions to construct the flow when many flow layers are involved # otherwise the compilation time for destructure will be too long θ_flat, re = Optimisers.destructure(flow) - + loss(θ, rng, args...) = -vo(rng, re(θ), args...) 
# Normalizing flow training loop @@ -73,13 +73,7 @@ function train_flow( return flow_trained, opt_stats, st, time_elapsed end - - include("optimize.jl") include("objectives.jl") - - - - end From eade7a39fc45b6e65b3ad55626364eaf9b705c26 Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:41:50 -0800 Subject: [PATCH 35/43] drop Requires --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5cf9dcc9..bb3acd60 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] @@ -29,7 +28,6 @@ Distributions = "0.25" DocStringExtensions = "0.9" Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.0.0" -Requires = "1" StatsBase = "0.33, 0.34" julia = "1.10" From 91202ff6fc69700e9a239d4b9e4926a6e46c09c7 Mon Sep 17 00:00:00 2001 From: zuhengxu Date: Tue, 4 Mar 2025 10:42:03 -0800 Subject: [PATCH 36/43] update test to include mooncake --- test/Project.toml | 3 +++ test/ad.jl | 2 +- test/interface.jl | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 66f38e3a..be5ffa0e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,3 +15,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Mooncake = "0.4.101" diff --git a/test/ad.jl b/test/ad.jl index d6e7b546..ca6fa4ae 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -29,7 +29,7 @@ end ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(; compile = false), ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), - # ADTypes.AutoMooncake(; config=Mooncake.Config()), + ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) diff --git a/test/interface.jl b/test/interface.jl index 2b860795..1eab5f58 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -6,7 +6,7 @@ ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), - # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 + ADTypes.AutoMooncake(; config = Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] μ = 10 * ones(T, 2) @@ -67,7 +67,7 @@ end # ADTypes.AutoForwardDiff(), # ADTypes.AutoReverseDiff(), # ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), -# # ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 +# ADTypes.AutoMooncake(; config = Mooncake.Config()), # somehow Mooncake does not work with Float64 # ] # @testset "$T" for T in [Float32, Float64] # μ = 10 * ones(T, 2) From dcee3c024aed1cc03a49dad5139b9ec66b9986db Mon Sep 17 00:00:00 2001 From: Zuheng Date: Tue, 4 Mar 2025 16:57:13 -0800 Subject: [PATCH 37/43] rm unnecessary EnzymeCoreExt --- Project.toml | 9 --------- ext/NormalizingFlowsEnzymeCoreExt.jl | 29 ---------------------------- src/NormalizingFlows.jl | 2 +- src/optimize.jl | 8 +++----- test/ad.jl | 26 +++++++++++++++++-------- test/interface.jl | 5 ++++- 6 files changed, 26 insertions(+), 53 deletions(-) delete mode 100644 
ext/NormalizingFlowsEnzymeCoreExt.jl diff --git a/Project.toml b/Project.toml index bb3acd60..0b5ec051 100644 --- a/Project.toml +++ b/Project.toml @@ -14,12 +14,6 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -[weakdeps] -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[extensions] -NormalizingFlowsEnzymeCoreExt = ["EnzymeCore", "ADTypes", "DifferentiationInterface"] - [compat] ADTypes = "1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" @@ -30,6 +24,3 @@ Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.0.0" StatsBase = "0.33, 0.34" julia = "1.10" - -[extras] -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/NormalizingFlowsEnzymeCoreExt.jl b/ext/NormalizingFlowsEnzymeCoreExt.jl deleted file mode 100644 index ee7a36ae..00000000 --- a/ext/NormalizingFlowsEnzymeCoreExt.jl +++ /dev/null @@ -1,29 +0,0 @@ -module NormalizingFlowsEnzymeCoreExt - -using EnzymeCore -using NormalizingFlows -using NormalizingFlows: ADTypes, DifferentiationInterface - -# deal with Enzyme readonly error: see https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012 -function NormalizingFlows._prepare_gradient(loss, adbackend::ADTypes.AutoEnzyme, θ, args...) - return DifferentiationInterface.prepare_gradient( - EnzymeCore.Const(loss), - adbackend, - θ, - map(DifferentiationInterface.Constant, args)..., - ) -end - -function NormalizingFlows._value_and_gradient( - loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args... -) - return DifferentiationInterface.value_and_gradient( - EnzymeCore.Const(loss), - prep, - adbackend, - θ, - map(DifferentiationInterface.Constant, args)..., - ) -end - -end diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index d4179366..d5f487db 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -5,7 +5,7 @@ using Optimisers using LinearAlgebra, Random, Distributions, StatsBase using ProgressMeter using ADTypes -using DifferentiationInterface +import DifferentiationInterface as DI using DocStringExtensions diff --git a/src/optimize.jl b/src/optimize.jl index bb284f62..f91bf4a5 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -5,14 +5,12 @@ function pm_next!(pm, stats::NamedTuple) return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats))) end -_wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args) - function _prepare_gradient(loss, adbackend, θ, args...) - return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...) + return DI.prepare_gradient(loss, adbackend, θ, map(DI.Constant, args)...) end function _value_and_gradient(loss, prep, adbackend, θ, args...) - return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...) + return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...) end @@ -34,7 +32,7 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal - `loss`: a general loss function θ -> loss(θ, args...) 
returning a scalar loss value that will be minimised - `θ₀::AbstractVector{T}`: initial parameters for the loss function (in the context of normalizing flows, it will be the flattened flow parameters) - `re`: reconstruction function that maps the flattened parameters to the normalizing flow -- `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant) +- `args...`: additional arguments for `loss` (will be set as DI.Constant) # Keyword Arguments - `max_iters::Int=10000`: maximum number of iterations diff --git a/test/ad.jl b/test/ad.jl index ca6fa4ae..725b3547 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -12,7 +12,10 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(; compile=false), - ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] prep = NormalizingFlows._prepare_gradient(f, at, x, y, z) @@ -27,8 +30,11 @@ end @testset "$at" for at in [ ADTypes.AutoZygote(), ADTypes.AutoForwardDiff(), - ADTypes.AutoReverseDiff(; compile = false), - ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), ADTypes.AutoMooncake(; config=Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] @@ -36,24 +42,28 @@ end Σ = Diagonal(4 * ones(T, 2)) target = MvNormal(μ, Σ) logp(z) = logpdf(target, z) - + # necessary for Zygote/mooncake to differentiate through the flow # prevent updating params of q0 - @leaf MvNormal + @leaf MvNormal q₀ = MvNormal(zeros(T, 2), ones(T, 2)) flow = Bijectors.transformed( q₀, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2)) ) - + θ, re = Optimisers.destructure(flow) # check grad computation for elbo - loss(θ, rng, logp, sample_per_iter) = -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter) + function loss(θ, rng, logp, sample_per_iter) + return -NormalizingFlows.elbo(rng, re(θ), logp, sample_per_iter) + end rng = Random.default_rng() sample_per_iter = 10 - prep = NormalizingFlows._prepare_gradient(loss, at, θ, rng, logp, sample_per_iter) + prep = NormalizingFlows._prepare_gradient( + loss, at, θ, rng, logp, sample_per_iter + ) value, grad = NormalizingFlows._value_and_gradient( loss, prep, at, θ, rng, logp, sample_per_iter ) diff --git a/test/interface.jl b/test/interface.jl index 1eab5f58..8c0316cb 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,7 +5,10 @@ ADTypes.AutoForwardDiff(; chunksize=chunksize), ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), - ADTypes.AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), ADTypes.AutoMooncake(; config = Mooncake.Config()), ] @testset "$T" for T in [Float32, Float64] From edf2f123a5b9ddd54af044989fe3156b61e08924 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Tue, 4 Mar 2025 17:18:53 -0800 Subject: [PATCH 38/43] minor update of readme --- README.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 918449ff..fbef5a76 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Build 
Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) -**Last updated: 2023-Aug-23** +**Last updated: 2025-Mar-04** A normalizing flow library for Julia. @@ -21,7 +21,7 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor To install the package, run the following command in the Julia REPL: ```julia ] # enter Pkg mode -(@v1.9) pkg> add git@github.com:TuringLang/NormalizingFlows.jl.git +(@v1.11) pkg> add NormalizingFlows ``` Then simply run the following command to use the package: ```julia @@ -29,8 +29,8 @@ using NormalizingFlows ``` ## Quick recap of normalizing flows -Normalizing flows transform a simple reference distribution $q_0$ (sometimes known as base distribution) to -a complex distribution $q$ using invertible functions. +Normalizing flows transform a simple reference distribution $q_0$ (sometimes referred to as the base distribution) +to a complex distribution $q$ using invertible functions. In more details, given the base distribution, usually a standard Gaussian distribution, i.e., $q_0 = \mathcal{N}(0, I)$, we apply a series of parameterized invertible transformations (called flow layers), $T_{1, \theta_1}, \cdots, T_{N, \theta_k}$, yielding that @@ -56,7 +56,7 @@ Given the feasibility of i.i.d. sampling and density evaluation, normalizing flo \text{Reverse KL:}\quad &\arg\min _{\theta} \mathbb{E}_{q_{\theta}}\left[\log q_{\theta}(Z)-\log p(Z)\right] \\ &= \arg\min _{\theta} \mathbb{E}_{q_0}\left[\log \frac{q_\theta(T_N\circ \cdots \circ T_1(Z_0))}{p(T_N\circ \cdots \circ T_1(Z_0))}\right] \\ -&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(X)+\sum_{n=1}^N \log J_n\left(F_n \circ \cdots \circ F_1(X)\right)\right] +&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(X)+\sum_{n=1}^N \log J_n\left(T_n \circ \cdots \circ T_1(X)\right)\right] \end{aligned} ``` and @@ -76,10 +76,12 @@ normalizing constant. In contrast, forward KL minimization is typically used for **generative modeling**, where one wants to learn the underlying distribution of some data. -## Current status and TODOs +## Current status and to-dos - [x] general interface development - [x] documentation +- [ ] integrating [Lux.jl](https://lux.csail.mit.edu/stable/tutorials/intermediate/7_RealNVP) and [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl). +This could potentially solve the GPU compatibility issue as well. - [ ] including more NF examples/Tutorials - WIP: [PR#11](https://github.com/TuringLang/NormalizingFlows.jl/pull/11) - [ ] GPU compatibility From 9fc328d900e42cfccf189ac2d6cb3d9a52a7d3bc Mon Sep 17 00:00:00 2001 From: Zuheng Date: Tue, 4 Mar 2025 17:22:02 -0800 Subject: [PATCH 39/43] typo fix in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fbef5a76..46ac6262 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Given the feasibility of i.i.d. 
sampling and density evaluation, normalizing flo \text{Reverse KL:}\quad &\arg\min _{\theta} \mathbb{E}_{q_{\theta}}\left[\log q_{\theta}(Z)-\log p(Z)\right] \\ &= \arg\min _{\theta} \mathbb{E}_{q_0}\left[\log \frac{q_\theta(T_N\circ \cdots \circ T_1(Z_0))}{p(T_N\circ \cdots \circ T_1(Z_0))}\right] \\ -&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(X)+\sum_{n=1}^N \log J_n\left(T_n \circ \cdots \circ T_1(X)\right)\right] +&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(Z_0)+\sum_{n=1}^N \log J_n\left(T_n \circ \cdots \circ T_1(Z_0)\right)\right] \end{aligned} ``` and From 41d10ef71832c9c3cc81d5c9e2236b41520008f1 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Wed, 5 Mar 2025 08:39:50 -0800 Subject: [PATCH 40/43] Update src/NormalizingFlows.jl Co-authored-by: David Widmann --- src/NormalizingFlows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index d5f487db..3e188e15 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -11,7 +11,7 @@ using DocStringExtensions export train_flow, elbo, loglikelihood -""" +""" train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) Train the given normalizing flow `flow` by calling `optimize`. From db4872b08c6aa63b60c5dcc6932a795b13c1d553 Mon Sep 17 00:00:00 2001 From: David Xu <42751767+zuhengxu@users.noreply.github.com> Date: Wed, 5 Mar 2025 08:40:33 -0800 Subject: [PATCH 41/43] Update src/NormalizingFlows.jl Co-authored-by: David Widmann --- src/NormalizingFlows.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 3e188e15..3323e472 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -63,7 +63,8 @@ function train_flow( loss, θ_flat, re, - (rng, args...)...; + rng, + args...; max_iters=max_iters, optimiser=optimiser, kwargs..., From 0fe536f7958b45c006f0714c9e4491275e6f4a40 Mon Sep 17 00:00:00 2001 From: Zuheng Date: Wed, 5 Mar 2025 08:51:45 -0800 Subject: [PATCH 42/43] rm time_elapsed from train_flow --- src/NormalizingFlows.jl | 12 ++++++--- src/optimize.jl | 59 +++++++++++++++++++---------------------- test/interface.jl | 2 +- 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 3323e472..709f7586 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -28,7 +28,13 @@ Train the given normalizing flow `flow` by calling `optimize`. - `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps - `ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote()`: automatic differentiation backend, currently supports - `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`. + `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`, + `ADTypes.AutoMooncake()` and + `ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + )`. + If user wants to use `AutoEnzyme`, please make sure to include the `set_runtime_activity` and `function_annotation` as shown above. - `kwargs...`: additional keyword arguments for `optimize` (See [`optimize`](@ref) for details) # Returns @@ -58,7 +64,7 @@ function train_flow( loss(θ, rng, args...) = -vo(rng, re(θ), args...) 
# Normalizing flow training loop - θ_flat_trained, opt_stats, st, time_elapsed = optimize( + θ_flat_trained, opt_stats, st = optimize( ADbackend, loss, θ_flat, @@ -71,7 +77,7 @@ function train_flow( ) flow_trained = re(θ_flat_trained) - return flow_trained, opt_stats, st, time_elapsed + return flow_trained, opt_stats, st end include("optimize.jl") diff --git a/src/optimize.jl b/src/optimize.jl index f91bf4a5..b4adad91 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -13,7 +13,6 @@ function _value_and_gradient(loss, prep, adbackend, θ, args...) return DI.value_and_gradient(loss, prep, adbackend, θ, map(DI.Constant, args)...) end - """ optimize( ad::ADTypes.AbstractADType, @@ -58,7 +57,7 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal function optimize( adbackend, loss, - θ₀::AbstractVector{<:Real}, + θ₀::AbstractVector{<:Real}, reconstruct, args...; max_iters::Int=10000, @@ -70,42 +69,40 @@ function optimize( max_iters; desc="Training", barlen=31, showspeed=true, enabled=show_progress ), ) - time_elapsed = @elapsed begin - opt_stats = [] + opt_stats = [] - # prepare loss and autograd - θ = deepcopy(θ₀) - # grad = similar(θ) - prep = _prepare_gradient(loss, adbackend, θ₀, args...) + # prepare loss and autograd + θ = deepcopy(θ₀) + # grad = similar(θ) + prep = _prepare_gradient(loss, adbackend, θ₀, args...) - # initialise optimiser state - st = Optimisers.setup(optimiser, θ) + # initialise optimiser state + st = Optimisers.setup(optimiser, θ) - # general `hasconverged(...)` approach to allow early termination. - converged = false - i = 1 - while (i ≤ max_iters) && !converged - ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...) + # general `hasconverged(...)` approach to allow early termination. + converged = false + i = 1 + while (i ≤ max_iters) && !converged + ls, g = _value_and_gradient(loss, prep, adbackend, θ, args...) - # Save stats - stat = (iteration=i, loss=ls, gradient_norm=norm(g)) + # Save stats + stat = (iteration=i, loss=ls, gradient_norm=norm(g)) - # callback - if callback !== nothing - new_stat = callback(i, opt_stats, reconstruct, θ) - stat = new_stat !== nothing ? merge(stat, new_stat) : stat - end - push!(opt_stats, stat) + # callback + if callback !== nothing + new_stat = callback(i, opt_stats, reconstruct, θ) + stat = new_stat !== nothing ? 
merge(stat, new_stat) : stat + end + push!(opt_stats, stat) - # update optimiser state and parameters - st, θ = Optimisers.update!(st, θ, g) + # update optimiser state and parameters + st, θ = Optimisers.update!(st, θ, g) - # check convergence - i += 1 - converged = hasconverged(i, stat, reconstruct, θ, st) - pm_next!(prog, stat) - end + # check convergence + i += 1 + converged = hasconverged(i, stat, reconstruct, θ, st) + pm_next!(prog, stat) end # return status of the optimiser for potential continuation of training - return θ, map(identity, opt_stats), st, time_elapsed + return θ, map(identity, opt_stats), st end diff --git a/test/interface.jl b/test/interface.jl index 8c0316cb..947d4f37 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -27,7 +27,7 @@ sample_per_iter = 10 cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 - flow_trained, stats, _, _ = train_flow( + flow_trained, stats, _ = train_flow( elbo, flow, logp, From fc935ce7004f880cf97b00b5642f4050bc15ef44 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 5 Mar 2025 20:52:47 +0000 Subject: [PATCH 43/43] Update docs/src/api.md Co-authored-by: David Widmann --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 8d386ae8..eb128863 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -24,7 +24,7 @@ To train the Gaussian VI targeting at distirbution $p$ via ELBO maiximization, w using NormalizingFlows sample_per_iter = 10 -flow_trained, stats, _ , _ = train_flow( +flow_trained, stats, _ = train_flow( elbo, flow, logp,
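
---

For readers following the series end-to-end, here is a minimal usage sketch of the entry point as it stands after PATCH 42/43 (three return values from `train_flow`, `DifferentiationInterface`-based gradients, Enzyme configured via `function_annotation`). It simply mirrors the set-up already used in `test/ad.jl` and `test/interface.jl`; the toy 2-D Gaussian target, the `Shift ∘ Scale` flow, and the hyperparameter values are illustrative placeholders, not part of any patch.

```julia
using NormalizingFlows, Bijectors, Distributions, Optimisers
using ADTypes, Enzyme, Functors, LinearAlgebra

# Illustrative 2-D Gaussian target and a mean-field (Shift ∘ Scale) flow,
# mirroring the set-up in test/ad.jl and test/interface.jl.
T = Float64
target = MvNormal(10 * ones(T, 2), Diagonal(4 * ones(T, 2)))
logp(z) = logpdf(target, z)

@leaf MvNormal  # keep the reference distribution's parameters fixed during training
q0 = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(
    q0, Bijectors.Shift(zeros(T, 2)) ∘ Bijectors.Scale(ones(T, 2))
)

# Enzyme needs runtime activity and a Const function annotation,
# as noted in the updated train_flow docstring.
adtype = ADTypes.AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const,
)

sample_per_iter = 10
flow_trained, opt_stats, st = train_flow(   # three return values after PATCH 42
    elbo, flow, logp, sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(1e-2),
    ADbackend=adtype,
    show_progress=false,
)
```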