Skip to content

Commit 6f12082

Browse files
committed
fix
1 parent 3b77f89 commit 6f12082

File tree

3 files changed

+9
-19
lines changed

3 files changed

+9
-19
lines changed

src/pointwise_logdensities.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
44
An accumulator that stores the log-probabilities of each variable in a model.
55
6-
Internally this context stores the log-probabilities in a dictionary, where the keys are
7-
the variable names and the values are vectors of log-probabilities. Each element in a vector
8-
corresponds to one execution of the model.
6+
Internally this accumulator stores the log-probabilities in a dictionary, where
7+
the keys are the variable names and the values are vectors of
8+
log-probabilities. Each element in a vector corresponds to one execution of the
9+
model.
910
1011
`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies
1112
which log-probabilities to store in the accumulator. `KeyType` is the type by which variable
@@ -258,10 +259,8 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String)
258259
return pointwise_logdensities(model, chain, T, Val(:likelihood))
259260
end
260261

261-
function pointwise_loglikelihoods(
262-
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()
263-
)
264-
return pointwise_logdensities(model, varinfo, context, Val(:likelihood))
262+
function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
263+
return pointwise_logdensities(model, varinfo, Val(:likelihood))
265264
end
266265

267266
"""

src/sampler.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,30 +63,21 @@ function AbstractMCMC.step(
6363
end
6464

6565
"""
66-
default_varinfo(rng, model, sampler[, context])
66+
default_varinfo(rng, model, sampler)
6767
6868
Return a default varinfo object for the given `model` and `sampler`.
6969
7070
# Arguments
7171
- `rng::Random.AbstractRNG`: Random number generator.
7272
- `model::Model`: Model for which we want to create a varinfo object.
7373
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
74-
- `context::AbstractContext`: Context in which the model is evaluated.
7574
7675
# Returns
7776
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
7877
"""
7978
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
80-
return default_varinfo(rng, model, sampler, DefaultContext())
81-
end
82-
function default_varinfo(
83-
rng::Random.AbstractRNG,
84-
model::Model,
85-
sampler::AbstractSampler,
86-
context::AbstractContext,
87-
)
8879
init_sampler = initialsampler(sampler)
89-
return typed_varinfo(rng, model, init_sampler, context)
80+
return typed_varinfo(rng, model, init_sampler)
9081
end
9182

9283
function AbstractMCMC.sample(

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
591591
xs_train = 1:0.1:10
592592
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
593593
m_lin_reg = linear_reg(xs_train, ys_train)
594-
chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]
594+
chain = [last(evaluate!!(m_lin_reg, VarInfo())) for _ in 1:10000]
595595

596596
# chain is generated from the prior
597597
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

0 commit comments

Comments
 (0)