Skip to content

Commit 918f6f4

Browse files
torfjeldeyebai
andauthored
Preserve context in LogDensityFunction (#1943)
* preserve context when constructing LogDensityFunction for some samplers * bump patch version * added tests for prior sampling using NUTS * added test * fixed atol for test * fixed MH tests * fixed HMC tests --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 3e71a76 commit 918f6f4

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.24.3"
3+
version = "0.24.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/inference/hmc.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,14 @@ function DynamicPPL.initialstep(
159159
metricT = getmetricT(spl.alg)
160160
metric = metricT(length(theta))
161161
= LogDensityProblemsAD.ADgradient(
162-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
162+
Turing.LogDensityFunction(
163+
vi,
164+
model,
165+
# Use the leaf-context from the `model` in case the user has
166+
# contextualized the model with something like `PriorContext`
167+
# to sample from the prior.
168+
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
169+
)
163170
)
164171
logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
165172
∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
@@ -265,7 +272,11 @@ end
265272
function get_hamiltonian(model, spl, vi, state, n)
266273
metric = gen_metric(n, spl, state)
267274
= LogDensityProblemsAD.ADgradient(
268-
Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
275+
Turing.LogDensityFunction(
276+
vi,
277+
model,
278+
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context))
279+
)
269280
)
270281
ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
271282
∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ)
@@ -538,7 +549,12 @@ function HMCState(
538549

539550
# Get the initial log pdf and gradient functions.
540551
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
541-
logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
552+
logπ = Turing.LogDensityFunction(
553+
vi,
554+
model,
555+
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
556+
)
557+
542558

543559
# Get the metric type.
544560
metricT = getmetricT(spl.alg)

src/inference/mh.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,14 @@ function propose!!(
375375

376376
# Make a new transition.
377377
densitymodel = AMH.DensityModel(
378-
Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
378+
Base.Fix1(
379+
LogDensityProblems.logdensity,
380+
Turing.LogDensityFunction(
381+
vi,
382+
model,
383+
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
384+
)
385+
)
379386
)
380387
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
381388

@@ -403,7 +410,14 @@ function propose!!(
403410

404411
# Make a new transition.
405412
densitymodel = AMH.DensityModel(
406-
Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
413+
Base.Fix1(
414+
LogDensityProblems.logdensity,
415+
Turing.LogDensityFunction(
416+
vi,
417+
model,
418+
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
419+
)
420+
)
407421
)
408422
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
409423

test/inference/hmc.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,11 @@
216216
res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
217217
@test Array(res1) == Array(res2) == Array(res3)
218218
end
219+
220+
@turing_testset "prior" begin
221+
alg = NUTS(1000, 0.8)
222+
gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
223+
chain = sample(gdemo_default_prior, alg, 10_000)
224+
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.2)
225+
end
219226
end

test/inference/mh.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,16 @@
216216
vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)
217217
@test !DynamicPPL.islinked(vi, spl)
218218
end
219+
220+
@turing_testset "prior" begin
221+
# HACK: MH can be so bad for this prior model for some reason that it's difficult to
222+
# find a non-trivial `atol` where the tests will pass for all seeds. Hence we fix it :/
223+
rng = StableRNG(10)
224+
alg = MH()
225+
gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
226+
burnin = 10_000
227+
n = 10_000
228+
chain = sample(rng, gdemo_default_prior, alg, n; discard_initial = burnin, thinning=10)
229+
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
230+
end
219231
end

0 commit comments

Comments
 (0)