Skip to content

Commit a67d0ce

Browse files
torfjeldeyebai
andauthored
Use LogDensityFunction for variational inference (#1944)
* use LogDensityFunction for vi * bump minor version * removed undefined reference to Bijectors.setadbacked * removed incorrect previous commit * Update VariationalInference.jl --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 0dee0f8 commit a67d0ce

File tree

1 file changed

+13
-25
lines changed

1 file changed

+13
-25
lines changed

src/variational/VariationalInference.jl

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
module Variational
22

3-
using AdvancedVI
4-
using Bijectors
5-
using DistributionsAD
6-
using DynamicPPL
7-
using StatsBase
8-
using StatsFuns
3+
using DistributionsAD: DistributionsAD
4+
using DynamicPPL: DynamicPPL
5+
using StatsBase: StatsBase
6+
using StatsFuns: StatsFuns
7+
using LogDensityProblems: LogDensityProblems
98
using Distributions
109

11-
using Random
10+
using Random: Random
11+
12+
import AdvancedVI
13+
import Bijectors
14+
1215

1316
# Reexports
1417
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
@@ -34,26 +37,11 @@ function make_logjoint(model::DynamicPPL.Model; weight = 1.0)
3437
DynamicPPL.DefaultContext(),
3538
weight
3639
)
37-
varinfo_init = DynamicPPL.VarInfo(model, ctx)
38-
39-
function logπ(z)
40-
varinfo = DynamicPPL.VarInfo(varinfo_init, DynamicPPL.SampleFromUniform(), z)
41-
model(varinfo)
42-
43-
return DynamicPPL.getlogp(varinfo)
44-
end
45-
46-
return logπ
40+
model_contextualized = DynamicPPL.contextualize(model, ctx)
41+
f = DynamicPPL.LogDensityFunction(model_contextualized)
42+
return Base.Fix1(LogDensityProblems.logdensity, f)
4743
end
4844

49-
function logjoint(model::DynamicPPL.Model, varinfo, z)
50-
varinfo = DynamicPPL.VarInfo(varinfo, DynamicPPL.SampleFromUniform(), z)
51-
model(varinfo)
52-
53-
return DynamicPPL.getlogp(varinfo)
54-
end
55-
56-
5745
# objectives
5846
function (elbo::ELBO)(
5947
rng::Random.AbstractRNG,

0 commit comments

Comments
 (0)