Skip to content

Commit cb3b838

Browse files
authored
fix avoid re-defining the differentiation objective to support AD pre-compilation (#66)
* update interface for objective initialization * improve `RepGradELBO` to not redefine AD forward path * add auxiliary argument to `value_and_gradient!`
1 parent c93b5d7 commit cb3b838

File tree

10 files changed

+146
-44
lines changed

10 files changed

+146
-44
lines changed

ext/AdvancedVIBijectorsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ function AdvancedVI.reparam_with_entropy(
4242
n_samples::Int,
4343
ent_est ::AdvancedVI.AbstractEntropyEstimator
4444
)
45-
transform = q.transform
46-
q_unconst = q.dist
47-
q_unconst_stop = q_stop.dist
45+
transform = q.transform
46+
q_unconst = q.dist
47+
q_unconst_stop = q_stop.dist
4848

4949
# Draw samples and compute entropy of the unconstrained distribution
5050
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(

ext/AdvancedVIForwardDiffExt.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,29 @@ end
1414
getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
1515

1616
function AdvancedVI.value_and_gradient!(
17-
ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
18-
) where {T<:Real}
17+
ad ::ADTypes.AutoForwardDiff,
18+
f,
19+
x ::AbstractVector{<:Real},
20+
out::DiffResults.MutableDiffResult
21+
)
1922
chunk_size = getchunksize(ad)
2023
config = if isnothing(chunk_size)
21-
ForwardDiff.GradientConfig(f, θ)
24+
ForwardDiff.GradientConfig(f, x)
2225
else
23-
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
26+
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
2427
end
25-
ForwardDiff.gradient!(out, f, θ, config)
28+
ForwardDiff.gradient!(out, f, x, config)
2629
return out
2730
end
2831

32+
function AdvancedVI.value_and_gradient!(
33+
ad ::ADTypes.AutoForwardDiff,
34+
f,
35+
x ::AbstractVector,
36+
aux,
37+
out::DiffResults.MutableDiffResult
38+
)
39+
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
40+
end
41+
2942
end

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,24 @@ end
1313

1414
# ReverseDiff without compiled tape
1515
function AdvancedVI.value_and_gradient!(
16-
ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
16+
ad::ADTypes.AutoReverseDiff,
17+
f,
18+
x::AbstractVector{<:Real},
19+
out::DiffResults.MutableDiffResult
1720
)
18-
tp = ReverseDiff.GradientTape(f, θ)
19-
ReverseDiff.gradient!(out, tp, θ)
21+
tp = ReverseDiff.GradientTape(f, x)
22+
ReverseDiff.gradient!(out, tp, x)
2023
return out
2124
end
2225

26+
function AdvancedVI.value_and_gradient!(
27+
ad::ADTypes.AutoReverseDiff,
28+
f,
29+
x::AbstractVector{<:Real},
30+
aux,
31+
out::DiffResults.MutableDiffResult
32+
)
33+
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
34+
end
35+
2336
end

ext/AdvancedVIZygoteExt.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,36 @@ module AdvancedVIZygoteExt
44
if isdefined(Base, :get_extension)
55
using AdvancedVI
66
using AdvancedVI: ADTypes, DiffResults
7+
using ChainRulesCore
78
using Zygote
89
else
910
using ..AdvancedVI
1011
using ..AdvancedVI: ADTypes, DiffResults
12+
using ..ChainRulesCore
1113
using ..Zygote
1214
end
1315

1416
function AdvancedVI.value_and_gradient!(
15-
ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
17+
::ADTypes.AutoZygote,
18+
f,
19+
x::AbstractVector{<:Real},
20+
out::DiffResults.MutableDiffResult
1621
)
17-
y, back = Zygote.pullback(f, θ)
18-
∇θ = back(one(y))
22+
y, back = Zygote.pullback(f, x)
23+
∇x = back(one(y))
1924
DiffResults.value!(out, y)
20-
DiffResults.gradient!(out, only(∇θ))
25+
DiffResults.gradient!(out, only(∇x))
2126
return out
2227
end
2328

29+
function AdvancedVI.value_and_gradient!(
30+
ad::ADTypes.AutoZygote,
31+
f,
32+
x::AbstractVector{<:Real},
33+
aux,
34+
out::DiffResults.MutableDiffResult
35+
)
36+
AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
37+
end
38+
2439
end

src/AdvancedVI.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,35 @@ using StatsBase
2525

2626
# derivatives
2727
"""
28-
value_and_gradient!(ad, f, θ, out)
28+
value_and_gradient!(ad, f, x, out)
29+
value_and_gradient!(ad, f, x, aux, out)
2930
30-
Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
31+
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
32+
`f` may receive auxiliary input as `f(x,aux)`.
3133
3234
# Arguments
3335
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
3436
- `f`: Function subject to differentiation.
35-
- `θ`: The point to evaluate the gradient.
37+
- `x`: The point to evaluate the gradient.
38+
- `aux`: Auxiliary input passed to `f`.
3639
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
3740
"""
3841
function value_and_gradient! end
3942

43+
"""
44+
stop_gradient(x)
45+
46+
Stop the gradient from propagating to `x` if the selected ad backend supports it.
47+
Otherwise, it is equivalent to `identity`.
48+
49+
# Arguments
50+
- `x`: Input
51+
52+
# Returns
53+
- `x`: Same value as the input.
54+
"""
55+
function stop_gradient end
56+
4057
# Update for gradient descent step
4158
"""
4259
update_variational_params!(family_type, opt_st, params, restructure, grad)
@@ -78,22 +95,23 @@ If the estimator is stateful, it can implement `init` to initialize the state.
7895
abstract type AbstractVariationalObjective end
7996

8097
"""
81-
init(rng, obj, λ, restructure)
98+
init(rng, obj, prob, params, restructure)
8299
83100
Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
84101
This function needs to be implemented only if `obj` is stateful.
85102
86103
# Arguments
87104
- `rng::Random.AbstractRNG`: Random number generator.
88105
- `obj::AbstractVariationalObjective`: Variational objective.
89-
- `λ`: Initial variational parameters.
106+
- `params`: Initial variational parameters.
90107
- `restructure`: Function that reconstructs the variational approximation from `λ`.
91108
"""
92109
init(
93110
::Random.AbstractRNG,
94111
::AbstractVariationalObjective,
95-
::AbstractVector,
96-
::Any
112+
::Any,
113+
::Any,
114+
::Any,
97115
) = nothing
98116

99117
"""

src/objectives/elbo/repgradelbo.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,13 @@ function estimate_energy_with_samples(prob, samples)
5656
end
5757

5858
"""
59-
reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
59+
reparam_with_entropy(rng, q, n_samples, ent_est)
6060
6161
Draw `n_samples` from `q` and compute its entropy.
6262
6363
# Arguments
6464
- `rng::Random.AbstractRNG`: Random number generator.
6565
- `q`: Variational approximation.
66-
- `q_stop`: `q` but with its gradient stopped.
6766
- `n_samples::Int`: Number of Monte Carlo samples
6867
- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
6968
@@ -72,7 +71,11 @@ Draw `n_samples` from `q` and compute its entropy.
7271
- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
7372
"""
7473
function reparam_with_entropy(
75-
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
74+
rng ::Random.AbstractRNG,
75+
q,
76+
q_stop,
77+
n_samples::Int,
78+
ent_est ::AbstractEntropyEstimator
7679
)
7780
samples = rand(rng, q, n_samples)
7881
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
@@ -94,28 +97,31 @@ end
9497
estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
9598
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
9699

100+
function estimate_repgradelbo_ad_forward(params′, aux)
101+
@unpack rng, obj, problem, restructure, q_stop = aux
102+
q = restructure(params′)
103+
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
104+
energy = estimate_energy_with_samples(problem, samples)
105+
elbo = energy + entropy
106+
-elbo
107+
end
108+
97109
function estimate_gradient!(
98110
rng ::Random.AbstractRNG,
99111
obj ::RepGradELBO,
100112
adtype::ADTypes.AbstractADType,
101113
out ::DiffResults.MutableDiffResult,
102114
prob,
103-
λ,
115+
params,
104116
restructure,
105117
state,
106118
)
107-
q_stop = restructure(λ)
108-
function f(λ′)
109-
q = restructure(λ′)
110-
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
111-
energy = estimate_energy_with_samples(prob, samples)
112-
elbo = energy + entropy
113-
-elbo
114-
end
115-
value_and_gradient!(adtype, f, λ, out)
116-
119+
q_stop = restructure(params)
120+
aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
121+
value_and_gradient!(
122+
adtype, estimate_repgradelbo_ad_forward, params, aux, out
123+
)
117124
nelbo = DiffResults.value(out)
118125
stat = (elbo=-nelbo,)
119-
120126
out, nothing, stat
121127
end

src/optimize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function optimize(
6666
)
6767
params, restructure = Optimisers.destructure(deepcopy(q_init))
6868
opt_st = maybe_init_optimizer(state_init, optimizer, params)
69-
obj_st = maybe_init_objective(state_init, rng, objective, params, restructure)
69+
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
7070
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
7171
stats = NamedTuple[]
7272

src/utils.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,28 @@ end
66
function maybe_init_optimizer(
77
state_init::NamedTuple,
88
optimizer ::Optimisers.AbstractRule,
9-
params ::AbstractVector
9+
params
1010
)
11-
haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, params)
11+
if haskey(state_init, :optimizer)
12+
state_init.optimizer
13+
else
14+
Optimisers.setup(optimizer, params)
15+
end
1216
end
1317

1418
function maybe_init_objective(
1519
state_init::NamedTuple,
1620
rng ::Random.AbstractRNG,
1721
objective ::AbstractVariationalObjective,
18-
params ::AbstractVector,
22+
problem,
23+
params,
1924
restructure
2025
)
21-
haskey(state_init, :objective) ? state_init.objective : init(rng, objective, params, restructure)
26+
if haskey(state_init, :objective)
27+
state_init.objective
28+
else
29+
init(rng, objective, problem, params, restructure)
30+
end
2231
end
2332

2433
eachsample(samples::AbstractMatrix) = eachcol(samples)

test/Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
4+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
45
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
56
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
6-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -26,7 +26,6 @@ ADTypes = "0.2.1, 1"
2626
Bijectors = "0.13"
2727
Distributions = "0.25.100"
2828
DistributionsAD = "0.6.45"
29-
Enzyme = "0.12"
3029
FillArrays = "1.6.1"
3130
ForwardDiff = "0.10.36"
3231
Functors = "0.4.5"

test/interface/repgradelbo.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,32 @@ using Test
2626
@test elbo ≈ elbo_ref rtol=0.1
2727
end
2828
end
29+
30+
@testset "interface RepGradELBO STL variance reduction" begin
31+
seed = (0x38bef07cf9cc549d)
32+
rng = StableRNG(seed)
33+
34+
modelstats = normal_meanfield(rng, Float64)
35+
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
36+
37+
@testset for ad in [
38+
ADTypes.AutoForwardDiff(),
39+
ADTypes.AutoReverseDiff(),
40+
ADTypes.AutoZygote()
41+
]
42+
q_true = MeanFieldGaussian(
43+
Vector{eltype(μ_true)}(μ_true),
44+
Diagonal(Vector{eltype(L_true)}(diag(L_true)))
45+
)
46+
params, re = Optimisers.destructure(q_true)
47+
obj = RepGradELBO(10; entropy=StickingTheLandingEntropy())
48+
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))
49+
50+
aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true)
51+
AdvancedVI.value_and_gradient!(
52+
ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
53+
)
54+
grad = DiffResults.gradient(out)
55+
@test norm(grad) ≈ 0 atol=1e-5
56+
end
57+
end

0 commit comments

Comments
 (0)