diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 6f51a9f0..20f1c345 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -279,7 +279,6 @@ include("optimize.jl")
 
 ## Parameter Space SGD
 include("algorithms/paramspacesgd/abstractobjective.jl")
-include("algorithms/paramspacesgd/paramspacesgd.jl")
 
 export ParamSpaceSGD
 
@@ -319,6 +318,7 @@ export RepGradELBO, SubsampledObjective
 
 
 include("algorithms/paramspacesgd/constructors.jl")
+include("algorithms/paramspacesgd/paramspacesgd.jl")
 
 export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI
 
diff --git a/src/algorithms/paramspacesgd/constructors.jl b/src/algorithms/paramspacesgd/constructors.jl
index 2ec0ae41..d3f7bef6 100644
--- a/src/algorithms/paramspacesgd/constructors.jl
+++ b/src/algorithms/paramspacesgd/constructors.jl
@@ -18,6 +18,22 @@ KL divergence minimization by running stochastic gradient descent with the repar
 - `operator::AbstractOperator`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
 - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)
 
+# Output
+- `q_averaged`: The variational approximation formed by the averaged SGD iterates.
+
+# Callback
+The callback function `callback` has a signature of
+
+    callback(; rng, iteration, restructure, params, averaged_params, gradient)
+
+The arguments are as follows:
+- `rng`: Random number generator internally used by the algorithm.
+- `iteration`: The index of the current iteration.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
+- `params`: Current variational parameters.
+- `averaged_params`: Variational parameters averaged according to the averaging strategy.
+- `gradient`: The estimated (possibly stochastic) gradient.
+
 # Requirements
 - The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
 - The variational approximation ``q_{\\lambda}`` implements `rand`.
@@ -25,6 +41,30 @@ KL divergence minimization by running stochastic gradient descent with the repar
 - The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
 - Additonal requirements on `q` may apply depending on the choice of `entropy`.
""" +struct KLMinRepGradDescent{ + Obj<:Union{<:RepGradELBO,<:SubsampledObjective}, + AD<:ADTypes.AbstractADType, + Opt<:Optimisers.AbstractRule, + Avg<:AbstractAverager, + Op<:AbstractOperator, +} <: AbstractVariationalAlgorithm + objective::Obj + adtype::AD + optimizer::Opt + averager::Avg + operator::Op +end + +struct KLMinRepGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt} + prob::P + q::Q + iteration::Int + grad_buf::GradBuf + opt_st::OptSt + obj_st::ObjSt + avg_st::AvgSt +end + function KLMinRepGradDescent( adtype::ADTypes.AbstractADType; entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(), @@ -39,7 +79,11 @@ function KLMinRepGradDescent( else SubsampledObjective(RepGradELBO(n_samples; entropy=entropy), subsampling) end - return ParamSpaceSGD(objective, adtype, optimizer, averager, operator) + return KLMinRepGradDescent{ + typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator) + }( + objective, adtype, optimizer, averager, operator + ) end const ADVI = KLMinRepGradDescent @@ -63,12 +107,52 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed. - `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`) - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`) +# Output +- `q_averaged`: The variational approximation formed by the averaged SGD iterates. + +# Callback +The callback function `callback` has a signature of + + callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient) + +The arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation. +- `params`: Current variational parameters. +- `averaged_params`: Variational parameters averaged according to the averaging strategy. +- `gradient`: The estimated (possibly stochastic) gradient. + # Requirements - The variational family is `MvLocationScale`. - The target distribution and the variational approximation have the same support. - The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend. - Additonal requirements on `q` may apply depending on the choice of `entropy_zerograd`. 
""" +struct KLMinRepGradProxDescent{ + Obj<:Union{<:RepGradELBO,<:SubsampledObjective}, + AD<:ADTypes.AbstractADType, + Opt<:Optimisers.AbstractRule, + Avg<:AbstractAverager, + Op<:ProximalLocationScaleEntropy, +} <: AbstractVariationalAlgorithm + objective::Obj + adtype::AD + optimizer::Opt + averager::Avg + operator::Op +end + +struct KLMinRepGradProxDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt} + prob::P + q::Q + iteration::Int + grad_buf::GradBuf + opt_st::OptSt + obj_st::ObjSt + avg_st::AvgSt +end + function KLMinRepGradProxDescent( adtype::ADTypes.AbstractADType; entropy_zerograd::Union{ @@ -85,7 +169,11 @@ function KLMinRepGradProxDescent( else SubsampledObjective(RepGradELBO(n_samples; entropy=entropy_zerograd), subsampling) end - return ParamSpaceSGD(objective, adtype, optimizer, averager, operator) + return KLMinRepGradProxDescent{ + typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator) + }( + objective, adtype, optimizer, averager, operator + ) end """ @@ -106,15 +194,55 @@ KL divergence minimization by running stochastic gradient descent with the score - `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`) - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`) +# Output +- `q_averaged`: The variational approximation formed by the averaged SGD iterates. + +# Callback +The callback function `callback` has a signature of + + callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient) + +The arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation. +- `params`: Current variational parameters. +- `averaged_params`: Variational parameters averaged according to the averaging strategy. +- `gradient`: The estimated (possibly stochastic) gradient. + # Requirements - The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`. - The variational approximation ``q_{\\lambda}`` implements `rand`. - The variational approximation ``q_{\\lambda}`` implements `logpdf(q, x)`, which should also be differentiable with respect to `x`. - The target distribution and the variational approximation have the same support. 
""" +struct KLMinScoreGradDescent{ + Obj<:Union{<:ScoreGradELBO,<:SubsampledObjective}, + AD<:ADTypes.AbstractADType, + Opt<:Optimisers.AbstractRule, + Avg<:AbstractAverager, + Op<:AbstractOperator, +} <: AbstractVariationalAlgorithm + objective::Obj + adtype::AD + optimizer::Opt + averager::Avg + operator::Op +end + +struct KLMinScoreGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt} + prob::P + q::Q + iteration::Int + grad_buf::GradBuf + opt_st::OptSt + obj_st::ObjSt + avg_st::AvgSt +end + function KLMinScoreGradDescent( adtype::ADTypes.AbstractADType; - optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(), + optimizer::Optimisers.AbstractRule=DoWG(), n_samples::Int=1, averager::AbstractAverager=PolynomialAveraging(), operator::AbstractOperator=IdentityOperator(), @@ -125,7 +253,11 @@ function KLMinScoreGradDescent( else SubsampledObjective(ScoreGradELBO(n_samples), subsampling) end - return ParamSpaceSGD(objective, adtype, optimizer, averager, operator) + return KLMinScoreGradDescent{ + typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator) + }( + objective, adtype, optimizer, averager, operator + ) end const BBVI = KLMinScoreGradDescent diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 92bbb0e5..17b3e5e7 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -1,68 +1,7 @@ -""" - ParamSpaceSGD( - objective::AbstractVariationalObjective, - adtype::ADTypes.AbstractADType, - optimizer::Optimisers.AbstractRule, - averager::AbstractAverager, - operator::AbstractOperator, - ) - -This algorithm applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters. - -The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. -This requires the variational approximation to be marked as a functor through `Functors.@functor`. - -!!! note - Different objective may impose different requirements on `adtype`, variational family, `optimizer`, and `operator`. It is therefore important to check the documentation corresponding to each specific objective. Essentially, each objective should be thought as forming its own unique algorithm. - -# Arguments -- `objective`: Variational Objective. -- `adtype`: Automatic differentiation backend. -- `optimizer`: Optimizer used for inference. -- `averager` : Parameter averaging strategy. -- `operator` : Operator applied to the parameters after each optimization step. - -# Output -- `q_averaged`: The variational approximation formed from the averaged SGD iterates. - -# Callback -The callback function `callback` has a signature of - - callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient) - -The arguments are as follows: -- `rng`: Random number generator internally used by the algorithm. -- `iteration`: The index of the current iteration. -- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation. -- `params`: Current variational parameters. -- `averaged_params`: Variational parameters averaged according to the averaging strategy. -- `gradient`: The estimated (possibly stochastic) gradient. 
- -""" -struct ParamSpaceSGD{ - Obj<:AbstractVariationalObjective, - AD<:ADTypes.AbstractADType, - Opt<:Optimisers.AbstractRule, - Avg<:AbstractAverager, - Op<:AbstractOperator, -} <: AbstractVariationalAlgorithm - objective::Obj - adtype::AD - optimizer::Opt - averager::Avg - operator::Op -end - -struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt} - prob::P - q::Q - iteration::Int - grad_buf::GradBuf - opt_st::OptSt - obj_st::ObjSt - avg_st::AvgSt -end +const ParamSpaceSGD = Union{ + <:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent +} function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob) (; adtype, optimizer, averager, objective, operator) = alg @@ -76,7 +15,17 @@ function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob) obj_st = init(rng, objective, adtype, q_init, prob, params, re) avg_st = init(averager, params) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st) + if alg isa KLMinRepGradDescent + return KLMinRepGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st) + elseif alg isa KLMinRepGradProxDescent + return KLMinRepGradProxDescentState( + prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st + ) + elseif alg isa KLMinScoreGradDescent + return KLMinScoreGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st) + else + nothing + end end function output(alg::ParamSpaceSGD, state) @@ -104,9 +53,19 @@ function step( params = apply(operator, typeof(q), opt_st, params, re) avg_st = apply(averager, avg_st, params) - state = ParamSpaceSGDState( - prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st - ) + state = if alg isa KLMinRepGradDescent + KLMinRepGradDescentState(prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st) + elseif alg isa KLMinRepGradProxDescent + KLMinRepGradProxDescentState( + prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st + ) + elseif alg isa KLMinScoreGradDescent + KLMinScoreGradDescentState( + prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st + ) + else + nothing + end if !isnothing(callback) averaged_params = value(averager, avg_st) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 71c3e4fb..126dc2e4 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -9,12 +9,7 @@ (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) - obj = RepGradELBO(10) - - optimizer = Optimisers.Adam(1e-2) - averager = PolynomialAveraging() - - alg = ParamSpaceSGD(obj, AD, optimizer, averager, IdentityOperator()) + alg = KLMinRepGradDescent(AD; optimizer=Optimisers.Adam(1e-2), operator=ClipScale()) @testset "default_rng" begin optimize(alg, T, model, q0; show_progress=false)