Skip to content

Commit 074a2ca

Browse files
committed
Improve performance of Prior()
1 parent d75e6f2 commit 074a2ca

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

src/mcmc/prior.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ struct Prior <: InferenceAlgorithm end
88
function AbstractMCMC.step(
99
rng::Random.AbstractRNG,
1010
model::DynamicPPL.Model,
11-
sampler::DynamicPPL.Sampler{<:Prior},
12-
state=nothing;
11+
sampler::DynamicPPL.Sampler{<:Prior};
1312
kwargs...,
1413
)
1514
vi = last(
@@ -19,7 +18,29 @@ function AbstractMCMC.step(
1918
SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
2019
),
2120
)
22-
return vi, nothing
21+
vi = DynamicPPL.typed_varinfo(vi)
22+
return vi, vi
23+
end
24+
25+
function AbstractMCMC.step(
26+
rng::Random.AbstractRNG,
27+
model::DynamicPPL.Model,
28+
sampler::DynamicPPL.Sampler{<:Prior},
29+
vi::AbstractVarInfo;
30+
kwargs...,
31+
)
32+
# TODO(DPPL0.38/penelopeysm): Use InitContext.
33+
for vn in keys(vi)
34+
DynamicPPL.set_flag!(vi, vn, "del")
35+
end
36+
vi = last(
37+
DynamicPPL.evaluate!!(
38+
model,
39+
vi,
40+
SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
41+
),
42+
)
43+
return vi, vi
2344
end
2445

2546
DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains

0 commit comments

Comments
 (0)