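I have tried to implement multithreaded evaluation of the log density over the Monte Carlo samples by changing `estimate_energy_with_samples` to:

```julia
function estimate_energy_with_samples(prob, samples)
    # Original (serial) version:
    # return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    logdensity_fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    # Spawn one task per sample, wait for all of them, then average the results
    return mean(fetch.([Threads.@spawn logdensity_fn(sample) for sample in eachsample(samples)]))
end
```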
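A minimal sketch that checks whether the threaded estimator returns the same value as the serial one in the forward pass. It assumes the samples are stored as columns of a matrix (so `eachcol` stands in for AdvancedVI's `eachsample`) and uses a hypothetical `toy_logdensity` in place of `LogDensityProblems.logdensity(prob, ·)`. If the two agree, the problem is more likely in the gradient than in the estimate itself.

```julia
using Statistics: mean

# Hypothetical stand-in for LogDensityProblems.logdensity(prob, x):
# a standard-normal log-density, up to an additive constant.
toy_logdensity(x) = -sum(abs2, x) / 2

# Serial estimate, as in the original (commented-out) line.
serial_estimate(f, samples) = mean(f, eachcol(samples))

# Threaded estimate, as in the modified function: one task per sample.
function threaded_estimate(f, samples)
    tasks = [Threads.@spawn f(s) for s in eachcol(samples)]
    return mean(fetch.(tasks))
end

samples = randn(3, 10)  # 10 draws of a 3-dimensional variable, one per column
println(serial_estimate(toy_logdensity, samples), " vs ", threaded_estimate(toy_logdensity, samples))
```

Run Julia with `--threads=auto` (or set `JULIA_NUM_THREADS`) so that the tasks actually run on several threads.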
However, while this works with the `AutoForwardDiff()` AD backend, it fails (silently) with `AutoZygote()`. I am guessing this is because Zygote is not thread-safe here?
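One way to test the thread-safety guess is to differentiate toy versions of both estimators and compare them against a known gradient. The sketch below assumes reparameterized samples x_j = μ + ε_j and a hypothetical `toy_logdensity(x) = -‖x‖²/2`, for which the exact gradient of the Monte Carlo average is -(μ + mean_j ε_j). It asserts agreement only for ForwardDiff (serial and threaded) and Zygote (serial), and merely prints what Zygote returns for the threaded version, since that is the open question. It requires ForwardDiff.jl and Zygote.jl to be installed.

```julia
using ForwardDiff, Zygote
using Statistics: mean

# Hypothetical toy log-density standing in for LogDensityProblems.logdensity(prob, x)
toy_logdensity(x) = -sum(abs2, x) / 2

# Serial objective: average log-density over reparameterized samples μ + ε[:, j]
serial_objective(μ, ε) = mean([toy_logdensity(μ .+ ε[:, j]) for j in axes(ε, 2)])

# Threaded objective: same average, one task per sample, as in the modified function
function threaded_objective(μ, ε)
    tasks = [Threads.@spawn toy_logdensity(μ .+ ε[:, j]) for j in axes(ε, 2)]
    return mean(fetch.(tasks))
end

μ0 = [0.5, -1.0, 2.0]
ε = randn(3, 10)

# Analytic gradient: ∇μ mean_j(-‖μ + ε_j‖² / 2) = -(μ + mean_j ε_j)
g_exact = -(μ0 .+ vec(mean(ε; dims = 2)))

g_fd_serial = ForwardDiff.gradient(μ -> serial_objective(μ, ε), μ0)
g_fd_threaded = ForwardDiff.gradient(μ -> threaded_objective(μ, ε), μ0)
g_zy_serial = Zygote.gradient(μ -> serial_objective(μ, ε), μ0)[1]

# The case in question: report whatever Zygote returns (or throws) here
g_zy_threaded = try
    Zygote.gradient(μ -> threaded_objective(μ, ε), μ0)[1]
catch err
    err
end
println("exact:            ", g_exact)
println("Zygote, threaded: ", g_zy_threaded)
```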
Code:
```julia
using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using ForwardDiff
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

# Data-generating distribution: 3-dimensional Gaussian, mean [2, 3, 4], unit covariance
function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

# Standard-normal prior on the 3-dimensional mean, unit-covariance Gaussian likelihood
@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 5)
model = normal_model(data)

##
d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)
ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy())

q, _, stats, _ = AdvancedVI.optimize(
    ℓπ,
    elbo,
    q,
    500;
    adtype = AutoZygote(),  # swap for AutoForwardDiff() to compare backends
    optimizer = optimizer,
)

##
using PyPlot
fig, ax = PyPlot.subplots()
elbo_trace = [s.elbo for s in stats]  # avoid shadowing the `elbo` objective
ax.plot(elbo_trace)
fig
```
ELBO traces (plots not reproduced):

1. Zygote, no threading
2. Zygote, with threading
3. ForwardDiff, with threading
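Because the model in the reproducer is conjugate (standard-normal prior on `ps`, unit-covariance Gaussian likelihood), its posterior is available in closed form: per coordinate, the posterior precision is 1 + n and the posterior mean is Σᵢ data[:, i] / (1 + n). A sketch of that computation using only Base, as a backend-independent way to judge whether a run has silently failed; the function name `exact_posterior` is hypothetical.

```julia
# Closed-form posterior for normal_model: prior ps ~ N(0, I), likelihood data[:, i] ~ N(ps, I).
# Per coordinate: precision 1 + n, mean sum_i data[:, i] / (1 + n).
function exact_posterior(data::AbstractMatrix)
    n = size(data, 2)
    post_mean = vec(sum(data; dims = 2)) ./ (1 + n)
    post_std = fill(sqrt(1 / (1 + n)), size(data, 1))
    return post_mean, post_std
end

# Example on freshly simulated data: 5 draws from N([2, 3, 4], I), one per column
example_data = [2.0, 3.0, 4.0] .+ randn(3, 5)
post_mean, post_std = exact_posterior(example_data)
println("posterior mean: ", post_mean, ", posterior std: ", post_std)
```

In the reproducer's session, `exact_posterior(data)` can be compared with `mean(q)` (assuming `mean` is defined for the returned `MeanFieldGaussian`, as for other Distributions.jl multivariate distributions). Note that Adam moves each parameter by roughly the learning rate per step, so 500 steps at `1e-3` may not reach the posterior mean even when the gradients are correct.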