Skip to content

Commit 4425d08

Browse files
committed
some fixes
1 parent 15ac97b commit 4425d08

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

src/model.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -807,12 +807,12 @@ samples from the prior.
807807
"""
808808
(model::Model)() = model(Random.default_rng())
809809
function (model::Model)(
810-
rng::AbstractRNG,
810+
rng::Random.AbstractRNG,
811811
varinfo::AbstractVarInfo=VarInfo(),
812812
sampler::AbstractSampler=SampleFromPrior(),
813813
)
814814
spl_ctx = SamplingContext(rng, sampler, DefaultContext())
815-
return evaluate!!(model, varinfo, spl_ctx)
815+
return first(evaluate!!(model, varinfo, spl_ctx))
816816
end
817817

818818
"""
@@ -833,12 +833,15 @@ evaluation by wrapping the model's context in a `SamplingContext`.
833833
834834
Returns a tuple of the model's return value, plus the updated `varinfo` object.
835835
"""
836-
function sample!!(rng::AbstractRNG, model::Model, varinfo::AbstractVarInfo)
836+
function sample!!(rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo)
837837
sampling_model = contextualize(
838838
model, SamplingContext(rng, SampleFromPrior(), model.context)
839839
)
840840
return evaluate!!(sampling_model, varinfo)
841841
end
842+
function sample!!(model::Model, varinfo::AbstractVarInfo)
843+
return sample!!(Random.default_rng(), model, varinfo)
844+
end
842845

843846
"""
844847
evaluate!!(model::Model, varinfo)
@@ -978,7 +981,6 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
978981
# speeding up computation. See docs for `maybe_invlink_before_eval!!`
979982
# for more information.
980983
maybe_invlink_before_eval!!(varinfo, model),
981-
context_new,
982984
$(unwrap_args...),
983985
)
984986
kwargs = model.defaults
@@ -1014,15 +1016,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10141016
Generate a sample of type `T` from the prior distribution of the `model`.
10151017
"""
10161018
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1017-
x = last(
1018-
evaluate!!(
1019-
model,
1020-
SimpleVarInfo{Float64}(OrderedDict()),
1021-
# NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`,
1022-
# and b) avoid double-stacking the parent contexts.
1023-
SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)),
1024-
),
1025-
)
1019+
x = last(sample!!(model, SimpleVarInfo{Float64}(OrderedDict())))
10261020
return values_as(x, T)
10271021
end
10281022

@@ -1187,7 +1181,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
11871181
end
11881182

11891183
"""
1190-
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
1184+
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
11911185
11921186
Generate samples from the posterior predictive distribution by evaluating `model` at each set
11931187
of parameter values provided in `chain`. The number of posterior predictive samples matches

0 commit comments

Comments
 (0)