2 changes: 1 addition & 1 deletion src/AdvancedVI.jl
@@ -279,7 +279,6 @@ include("optimize.jl")

## Parameter Space SGD
include("algorithms/paramspacesgd/abstractobjective.jl")
include("algorithms/paramspacesgd/paramspacesgd.jl")

export ParamSpaceSGD

@@ -319,6 +318,7 @@ export RepGradELBO,
SubsampledObjective

include("algorithms/paramspacesgd/constructors.jl")
include("algorithms/paramspacesgd/paramspacesgd.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

140 changes: 136 additions & 4 deletions src/algorithms/paramspacesgd/constructors.jl
@@ -18,13 +18,53 @@ KL divergence minimization by running stochastic gradient descent with the repar
- `operator::AbstractOperator`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy`.
"""
struct KLMinRepGradDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

struct KLMinRepGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
prob::P
q::Q
iteration::Int
grad_buf::GradBuf
opt_st::OptSt
obj_st::ObjSt
avg_st::AvgSt
end

function KLMinRepGradDescent(
adtype::ADTypes.AbstractADType;
entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(),
@@ -39,7 +79,11 @@ function KLMinRepGradDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const ADVI = KLMinRepGradDescent
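
For orientation, here is a minimal usage sketch of this constructor together with the `callback` keyword documented above. It is a sketch, not a normative snippet from the diff: the target `model`, the iteration count, and the choice of `ADTypes.AutoForwardDiff()` are illustrative assumptions, and the `optimize` call mirrors the one in `test/general/optimize.jl` further down.

    using AdvancedVI, ADTypes, Optimisers, LinearAlgebra, ForwardDiff

    # Initial variational approximation, built as in the test file below.
    q0 = MeanFieldGaussian(zeros(Float64, 2), Diagonal(ones(Float64, 2)))

    # Reparameterization-gradient KL minimization with illustrative keyword choices.
    alg = KLMinRepGradDescent(AutoForwardDiff(); optimizer=Optimisers.Adam(1e-2), operator=ClipScale())

    # A callback matching the keyword signature documented above; unused keywords are absorbed.
    function report(; iteration, restructure, averaged_params, kwargs...)
        if iteration % 100 == 0
            q_avg = restructure(averaged_params)  # averaged iterate as a distribution
            @info "iteration $iteration" q_avg
        end
        return nothing
    end

    # `model` is assumed to be a LogDensityProblems-compatible target density.
    optimize(alg, 10_000, model, q0; callback=report, show_progress=false)
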
@@ -63,12 +107,52 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
- `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The variational family is `MvLocationScale`.
- The target distribution and the variational approximation have the same support.
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
- Additional requirements on `q` may apply depending on the choice of `entropy_zerograd`.
"""
struct KLMinRepGradProxDescent{
Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:ProximalLocationScaleEntropy,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

struct KLMinRepGradProxDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
prob::P
q::Q
iteration::Int
grad_buf::GradBuf
opt_st::OptSt
obj_st::ObjSt
avg_st::AvgSt
end

function KLMinRepGradProxDescent(
adtype::ADTypes.AbstractADType;
entropy_zerograd::Union{
@@ -85,7 +169,11 @@ function KLMinRepGradProxDescent(
else
SubsampledObjective(RepGradELBO(n_samples; entropy=entropy_zerograd), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinRepGradProxDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end
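
As a companion sketch for this proximal variant, only the construction differs: the initial approximation must come from the `MvLocationScale` family per the requirements above, and the AD backend choice is illustrative.

    using AdvancedVI, ADTypes, LinearAlgebra, ForwardDiff

    # The entropy term is handled by a proximal operator, so `q` must be an MvLocationScale
    # family member such as MeanFieldGaussian or FullRankGaussian.
    q0 = MeanFieldGaussian(zeros(Float64, 5), Diagonal(ones(Float64, 5)))

    # Only entropy estimators with a "ZeroGradient" suffix are accepted via `entropy_zerograd`;
    # the default is kept here, along with the documented defaults for the other keywords.
    alg = KLMinRepGradProxDescent(AutoForwardDiff())
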

"""
@@ -106,15 +194,55 @@ KL divergence minimization by running stochastic gradient descent with the score
- `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
- `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)

# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

# Requirements
- The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
- The variational approximation ``q_{\\lambda}`` implements `rand`.
- The variational approximation ``q_{\\lambda}`` implements `logpdf(q, x)`, which should also be differentiable with respect to `x`.
- The target distribution and the variational approximation have the same support.
"""
struct KLMinScoreGradDescent{
Obj<:Union{<:ScoreGradELBO,<:SubsampledObjective},
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

struct KLMinScoreGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
prob::P
q::Q
iteration::Int
grad_buf::GradBuf
opt_st::OptSt
obj_st::ObjSt
avg_st::AvgSt
end

function KLMinScoreGradDescent(
adtype::ADTypes.AbstractADType;
optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(),
optimizer::Optimisers.AbstractRule=DoWG(),
n_samples::Int=1,
averager::AbstractAverager=PolynomialAveraging(),
operator::AbstractOperator=IdentityOperator(),
@@ -125,7 +253,11 @@ function KLMinScoreGradDescent(
else
SubsampledObjective(ScoreGradELBO(n_samples), subsampling)
end
return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
return KLMinScoreGradDescent{
typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
}(
objective, adtype, optimizer, averager, operator
)
end

const BBVI = KLMinScoreGradDescent
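
A matching sketch for the score-gradient variant, which only needs `rand` and a differentiable `logpdf` from `q`, so it also applies to families without a differentiable reparameterization. The keyword values below are illustrative; note that after this change any `Optimisers.AbstractRule` is accepted as the `optimizer`.

    using AdvancedVI, ADTypes, Optimisers, ForwardDiff

    # BBVI is an alias for KLMinScoreGradDescent.
    alg = BBVI(AutoForwardDiff(); optimizer=Optimisers.Adam(1e-2), n_samples=8)
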
95 changes: 27 additions & 68 deletions src/algorithms/paramspacesgd/paramspacesgd.jl
@@ -1,68 +1,7 @@

"""
ParamSpaceSGD(
objective::AbstractVariationalObjective,
adtype::ADTypes.AbstractADType,
optimizer::Optimisers.AbstractRule,
averager::AbstractAverager,
operator::AbstractOperator,
)

This algorithm applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters.

The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
This requires the variational approximation to be marked as a functor through `Functors.@functor`.

!!! note
Different objective may impose different requirements on `adtype`, variational family, `optimizer`, and `operator`. It is therefore important to check the documentation corresponding to each specific objective. Essentially, each objective should be thought as forming its own unique algorithm.

# Arguments
- `objective`: Variational Objective.
- `adtype`: Automatic differentiation backend.
- `optimizer`: Optimizer used for inference.
- `averager` : Parameter averaging strategy.
- `operator` : Operator applied to the parameters after each optimization step.

# Output
- `q_averaged`: The variational approximation formed from the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of

callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)

The arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
- `params`: Current variational parameters.
- `averaged_params`: Variational parameters averaged according to the averaging strategy.
- `gradient`: The estimated (possibly stochastic) gradient.

"""
struct ParamSpaceSGD{
Obj<:AbstractVariationalObjective,
AD<:ADTypes.AbstractADType,
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
averager::Avg
operator::Op
end

struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
prob::P
q::Q
iteration::Int
grad_buf::GradBuf
opt_st::OptSt
obj_st::ObjSt
avg_st::AvgSt
end
const ParamSpaceSGD = Union{
Review comment (Member), suggested change:

    const ParamSpaceSGD = Union{
    """
    This family of algorithms (`<:KLMinRepGradDescent`,`<:KLMinRepGradProxDescent`,`<:KLMinScoreGradDescent`) applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters.
    The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
    This requires the variational approximation to be marked as a functor through `Functors.@functor`.
    """
    const ParamSpaceSGD = Union{

<:KLMinRepGradDescent,<:KLMinRepGradProxDescent,<:KLMinScoreGradDescent
}
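
With this change `ParamSpaceSGD` becomes a type union rather than a concrete struct, so the shared `init`, `step`, and `output` methods below dispatch on the union while each algorithm keeps its own state type. A quick sketch of what that means, assuming the exports listed in `src/AdvancedVI.jl` above:

    using AdvancedVI, ADTypes

    alg = KLMinRepGradDescent(AutoForwardDiff())
    alg isa ParamSpaceSGD            # true: the union covers all three algorithms
    alg isa KLMinRepGradProxDescent  # false: the concrete algorithm types stay distinct
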

function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
(; adtype, optimizer, averager, objective, operator) = alg
@@ -76,7 +15,17 @@ function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
obj_st = init(rng, objective, adtype, q_init, prob, params, re)
avg_st = init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
if alg isa KLMinRepGradDescent
return KLMinRepGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
elseif alg isa KLMinRepGradProxDescent
return KLMinRepGradProxDescentState(
prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st
)
elseif alg isa KLMinScoreGradDescent
return KLMinScoreGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
else
nothing
Review comment (Member): Maybe throw a warning or error message here instead of letting it fail silently?

Reply (@Red-Portal, Member Author, Sep 15, 2025): It should never hit the else condition, so let me use InvalidStateException.

end
end

function output(alg::ParamSpaceSGD, state)
@@ -104,9 +53,19 @@ function step(
params = apply(operator, typeof(q), opt_st, params, re)
avg_st = apply(averager, avg_st, params)

state = ParamSpaceSGDState(
prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
)
state = if alg isa KLMinRepGradDescent
KLMinRepGradDescentState(prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st)
elseif alg isa KLMinRepGradProxDescent
KLMinRepGradProxDescentState(
prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
)
elseif alg isa KLMinScoreGradDescent
KLMinScoreGradDescentState(
prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
)
else
nothing
Review comment (Member): Same as above.

end

if !isnothing(callback)
averaged_params = value(averager, avg_st)
7 changes: 1 addition & 6 deletions test/general/optimize.jl
@@ -9,12 +9,7 @@
(; model, μ_true, L_true, n_dims, is_meanfield) = modelstats

q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims)))
obj = RepGradELBO(10)

optimizer = Optimisers.Adam(1e-2)
averager = PolynomialAveraging()

alg = ParamSpaceSGD(obj, AD, optimizer, averager, IdentityOperator())
alg = KLMinRepGradDescent(AD; optimizer=Optimisers.Adam(1e-2), operator=ClipScale())

@testset "default_rng" begin
optimize(alg, T, model, q0; show_progress=false)