Skip to content

Commit d6671ba

Browse files
committed
more fixes
1 parent 4425d08 commit d6671ba

File tree

4 files changed

+20
-22
lines changed

4 files changed

+20
-22
lines changed

src/model.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ fixed(model::Model) = fixed(model.context)
795795

796796
"""
797797
(model::Model)()
798-
(model::Model)(rng[, varinfo, sampler, context])
798+
(model::Model)(rng[, varinfo])
799799
800800
Sample from the `model` using the `sampler` with random number generator `rng`
801801
and the `context`, and store the sample and log joint probability in `varinfo`.
@@ -806,13 +806,8 @@ If no arguments are provided, uses the default random number generator and
806806
samples from the prior.
807807
"""
808808
(model::Model)() = model(Random.default_rng())
809-
function (model::Model)(
810-
rng::Random.AbstractRNG,
811-
varinfo::AbstractVarInfo=VarInfo(),
812-
sampler::AbstractSampler=SampleFromPrior(),
813-
)
814-
spl_ctx = SamplingContext(rng, sampler, DefaultContext())
815-
return first(evaluate!!(model, varinfo, spl_ctx))
809+
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
810+
return first(sample!!(rng, model, varinfo))
816811
end
817812

818813
"""
@@ -1016,7 +1011,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10161011
Generate a sample of type `T` from the prior distribution of the `model`.
10171012
"""
10181013
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1019-
x = last(sample!!(model, SimpleVarInfo{Float64}(OrderedDict())))
1014+
x = last(sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
10201015
return values_as(x, T)
10211016
end
10221017

@@ -1087,7 +1082,7 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
10871082
LogPriorAccumulator()
10881083
end
10891084
varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,))
1090-
return getlogprior(last(evaluate!!(model, varinfo, DefaultContext())))
1085+
return getlogprior(last(evaluate!!(model, varinfo)))
10911086
end
10921087

10931088
"""
@@ -1141,7 +1136,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
11411136
LogLikelihoodAccumulator()
11421137
end
11431138
varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,))
1144-
return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext())))
1139+
return getloglikelihood(last(evaluate!!(model, varinfo)))
11451140
end
11461141

11471142
"""
@@ -1195,7 +1190,7 @@ function predict(
11951190
return map(chain) do params_varinfo
11961191
vi = deepcopy(varinfo)
11971192
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
1198-
model(rng, vi, SampleFromPrior())
1193+
model(rng, vi)
11991194
return vi
12001195
end
12011196
end

src/threadsafe.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,17 @@ end
116116
# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates
117117
# to define `getacc(vi)`.
118118
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
119-
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
119+
model = contextualize(
120+
model, setleafcontext(model.context, DynamicTransformationContext{false}())
121+
)
122+
return settrans!!(last(evaluate!!(model, vi)), t)
120123
end
121124

122125
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
123-
return settrans!!(
124-
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
125-
NoTransformation(),
126+
model = contextualize(
127+
model, setleafcontext(model.context, DynamicTransformationContext{true}())
126128
)
129+
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
127130
end
128131

129132
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)

src/transforming.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,17 @@ function _transform!!(
5151
vi::AbstractVarInfo,
5252
model::Model,
5353
)
54-
# To transform using DynamicTransformationContext, we evaluate the model, but we do not
55-
# need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of
56-
# the transformation).
54+
# To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context:
55+
model = contextualize(model, setleafcontext(model.context, ctx))
56+
# but we do not need to use any accumulators other than LogPriorAccumulator
57+
# (which is affected by the Jacobian of the transformation).
5758
accs = getaccs(vi)
5859
has_logprior = haskey(accs, Val(:LogPrior))
5960
if has_logprior
6061
old_logprior = getacc(accs, Val(:LogPrior))
6162
vi = setaccs!!(vi, (old_logprior,))
6263
end
63-
vi = settrans!!(last(evaluate!!(model, vi, ctx)), t)
64+
vi = settrans!!(last(evaluate!!(model, vi)), t)
6465
# Restore the accumulators.
6566
if has_logprior
6667
new_logprior = getacc(vi, Val(:LogPrior))

test/model.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
332332
@test logjoint(model, x) !=
333333
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
334334
# Ensure `varnames` is implemented.
335-
sampling_model = contextualize(model, SamplingContext(model.context))
336-
vi = last(DynamicPPL.evaluate!!(sampling_model, SimpleVarInfo(OrderedDict())))
335+
vi = last(DynamicPPL.sample!!(sampling_model, SimpleVarInfo(OrderedDict())))
337336
@test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model))
338337
# Ensure `posterior_mean` is implemented.
339338
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)

0 commit comments

Comments
 (0)