Commit 8fa3908

more fixes
1 parent d6671ba commit 8fa3908

File tree: 8 files changed (+37 -44 lines changed)


ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
 
     # Let's make sure that both evaluation and sampling doesn't result in type errors.
     issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
-        model, varinfo; only_ddpl
+        sampling_model, varinfo; only_ddpl
     )
 
     if !issuccess
@@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
     else
         # Warn the user that we can't use the type stable one.
         @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
-        DynamicPPL.untyped_varinfo(model)
+        DynamicPPL.untyped_varinfo(sampling_model)
     end
 end

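The only change in this extension is that `is_suitable_varinfo` and the untyped fallback now receive the sampling-wrapped model rather than the raw one. The `sampling_model` itself is presumably built earlier in `_determine_varinfo_jet` (not shown in these hunks) using the contextualize-plus-SamplingContext pattern this commit introduces in `src/model.jl`; a minimal sketch of that pattern, with illustrative variable names:

    # Sketch only: given some `model::DynamicPPL.Model`, wrap it so that
    # evaluating it performs sampling. `SamplingContext(model.context)` matches
    # the constructor used in the updated JET tests further down in this commit.
    sampling_model = DynamicPPL.contextualize(
        model, DynamicPPL.SamplingContext(model.context)
    )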

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ function DynamicPPL.predict(
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     predictive_samples = map(iters) do (sample_idx, chain_idx)
         DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
-        model(rng, varinfo, DynamicPPL.SampleFromPrior())
+        varinfo = last(DynamicPPL.sample!!(rng, model, varinfo))
 
         vals = DynamicPPL.values_as_in_model(model, false, varinfo)
         varname_vals = mapreduce(

src/logdensityfunction.jl

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true
 """
     LogDensityFunction(
         model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model),
+        varinfo::AbstractVarInfo=VarInfo(model);
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
     )
 
@@ -106,7 +106,7 @@ struct LogDensityFunction{
 
     function LogDensityFunction(
         model::Model,
-        varinfo::AbstractVarInfo=VarInfo(model),
+        varinfo::AbstractVarInfo=VarInfo(model);
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
         if adtype === nothing
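
In both the docstring and the inner constructor, the comma after `varinfo` becomes a semicolon, making `adtype` keyword-only while `varinfo` stays positional with its `VarInfo(model)` default. A minimal usage sketch under that reading (the `demo` model is hypothetical, and ForwardDiff is just one possible backend):

    using DynamicPPL, Distributions
    using ADTypes: AutoForwardDiff
    import ForwardDiff   # the chosen AD backend has to be loaded

    @model function demo()   # hypothetical model, for illustration only
        x ~ Normal()
    end
    model = demo()

    # `varinfo` is still positional (defaulting to VarInfo(model));
    # `adtype` must now be passed as a keyword.
    ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model); adtype=AutoForwardDiff())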

src/model.jl

Lines changed: 21 additions & 12 deletions
@@ -794,8 +794,7 @@ julia> # Now `a.x` will be sampled.
 fixed(model::Model) = fixed(model.context)
 
 """
-    (model::Model)()
-    (model::Model)(rng[, varinfo])
+    (model::Model)([rng[, varinfo]])
 
 Sample from the `model` using the `sampler` with random number generator `rng`
 and the `context`, and store the sample and log joint probability in `varinfo`.
@@ -805,10 +804,12 @@ Returns the model's return value.
 If no arguments are provided, uses the default random number generator and
 samples from the prior.
 """
-(model::Model)() = model(Random.default_rng())
 function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
     return first(sample!!(rng, model, varinfo))
 end
+function (model::Model)(varinfo::AbstractVarInfo=VarInfo())
+    return model(Random.default_rng(), varinfo)
+end
 
 """
     use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
@@ -821,21 +822,29 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
 end
 
 """
-    sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo)
+    sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
 
 Evaluate the `model` with the given `varinfo`, but perform sampling during the
-evaluation by wrapping the model's context in a `SamplingContext`.
+evaluation using the given `sampler` by wrapping the model's context in a
+`SamplingContext`.
+
+If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).
 
 Returns a tuple of the model's return value, plus the updated `varinfo` object.
 """
-function sample!!(rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo)
-    sampling_model = contextualize(
-        model, SamplingContext(rng, SampleFromPrior(), model.context)
-    )
+function sample!!(
+    rng::Random.AbstractRNG,
+    model::Model,
+    varinfo::AbstractVarInfo,
+    sampler::AbstractSampler=SampleFromPrior(),
+)
+    sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
     return evaluate!!(sampling_model, varinfo)
 end
-function sample!!(model::Model, varinfo::AbstractVarInfo)
-    return sample!!(Random.default_rng(), model, varinfo)
+function sample!!(
+    model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
+)
+    return sample!!(Random.default_rng(), model, varinfo, sampler)
 end
 
 """
@@ -1028,7 +1037,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m
 See [`logprior`](@ref) and [`loglikelihood`](@ref).
 """
 function logjoint(model::Model, varinfo::AbstractVarInfo)
-    return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext())))
+    return getlogjoint(last(evaluate!!(model, varinfo)))
 end
 
 """

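Taken together, the two docstrings above describe the new entry points: calling the model samples from the prior with optional `rng` and `varinfo`, while `sample!!` additionally accepts an optional sampler and returns a `(return value, varinfo)` tuple. A short usage sketch based only on the signatures in this diff (the `demo` model is hypothetical; `sample!!` is qualified in case it is not exported):

    using DynamicPPL, Distributions, Random

    @model function demo()   # hypothetical model, for illustration only
        x ~ Normal()
    end
    model = demo()

    # Both rng and varinfo are optional when calling the model directly.
    retval = model()
    retval = model(Random.default_rng(), DynamicPPL.VarInfo())

    # sample!! returns (return value, updated varinfo); the trailing sampler
    # argument is optional and defaults to SampleFromPrior().
    retval, vi = DynamicPPL.sample!!(Random.default_rng(), model, DynamicPPL.VarInfo())
    retval, vi = DynamicPPL.sample!!(model, DynamicPPL.VarInfo(), DynamicPPL.SampleFromPrior())
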
src/test_utils/ad.jl

Lines changed: 5 additions & 16 deletions
@@ -60,8 +60,6 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     model::Model
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
-    "The evaluation context that was used"
-    context::AbstractContext
     "The values at which the model was evaluated"
     params::Vector{Tparams}
     "The AD backend that was tested"
@@ -92,7 +90,6 @@ end
         grad_atol=1e-6,
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        context::AbstractContext=DefaultContext(),
         reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
         expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
@@ -146,13 +143,7 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the evaluation context._
-
-   A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
-   argument to control the evaluation context. This defaults to
-   `DefaultContext()`.
-
-4. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._ (Only if `test=true`.)
 
    Once logp and its gradient has been calculated with the specified `adtype`,
    it must be tested for correctness.
@@ -167,12 +158,12 @@ Everything else is optional, and can be categorised into several groups:
    The default reference backend is ForwardDiff. If none of these parameters are
    specified, ForwardDiff will be used to calculate the ground truth.
 
-5. _How to specify the tolerances._ (Only if `test=true`.)
+4. _How to specify the tolerances._ (Only if `test=true`.)
 
    The tolerances for the value and gradient can be set using `value_atol` and
    `grad_atol`. These default to 1e-6.
 
-6. _Whether to output extra logging information._
+5. _Whether to output extra logging information._
 
    By default, this function prints messages when it runs. To silence it, set
    `verbose=false`.
@@ -195,7 +186,6 @@ function run_ad(
     grad_atol::AbstractFloat=1e-6,
     varinfo::AbstractVarInfo=link(VarInfo(model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    context::AbstractContext=DefaultContext(),
     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
     expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
@@ -207,7 +197,7 @@
 
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    verbose && println(" params : $(params)")
-    ldf = LogDensityFunction(model, varinfo, context; adtype=adtype)
+    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
    value, grad = logdensity_and_gradient(ldf, params)
    grad = collect(grad)
@@ -216,7 +206,7 @@
     if test
         # Calculate ground truth to compare against
         value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype)
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
             logdensity_and_gradient(ldf_reference, params)
         else
             expected_value_and_grad
@@ -245,7 +235,6 @@ function run_ad(
     return ADResult(
         model,
         varinfo,
-        context,
         params,
         adtype,
         value_atol,
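
With the `context` field and keyword gone, `run_ad` is configured only through the options documented above. A hedged sketch of a call, assuming the positional arguments (not shown in these hunks) are still the model and the AD type, and that `run_ad` lives under `DynamicPPL.TestUtils.AD`:

    using DynamicPPL, Distributions
    using DynamicPPL.TestUtils.AD: run_ad   # assumed module path
    using ADTypes: AutoForwardDiff
    import ForwardDiff   # backend needed by AutoForwardDiff

    @model function demo()   # hypothetical model, for illustration only
        x ~ Normal()
    end

    # No `context` keyword any more; test, tolerances, varinfo, params,
    # reference_adtype, expected_value_and_grad and verbose remain available.
    result = run_ad(demo(), AutoForwardDiff(); test=true, verbose=false)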

test/compiler.jl

Lines changed: 2 additions & 8 deletions
@@ -185,10 +185,7 @@ module Issue537 end
     @model function testmodel_missing3(x)
         x[1] ~ Bernoulli(0.5)
         global varinfo_ = __varinfo__
-        global sampler_ = __context__.sampler
         global model_ = __model__
-        global context_ = __context__
-        global rng_ = __context__.rng
         global lp = getlogjoint(__varinfo__)
         return x
     end
@@ -197,17 +194,14 @@ module Issue537 end
     @test getlogjoint(varinfo) == lp
     @test varinfo_ isa AbstractVarInfo
     @test model_ === model
-    @test context_ isa SamplingContext
-    @test rng_ isa Random.AbstractRNG
+    @test model_.context isa SamplingContext
+    @test model_.context.rng isa Random.AbstractRNG
 
     # disable warnings
     @model function testmodel_missing4(x)
         x[1] ~ Bernoulli(0.5)
         global varinfo_ = __varinfo__
-        global sampler_ = __context__.sampler
         global model_ = __model__
-        global context_ = __context__
-        global rng_ = __context__.rng
         global lp = getlogjoint(__varinfo__)
         return x
     end false

test/ext/DynamicPPLJETExt.jl

Lines changed: 3 additions & 2 deletions
@@ -62,6 +62,7 @@
 
     @testset "demo models" begin
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+            sampling_model = contextualize(model, SamplingContext(model.context))
             # Use debug logging below.
             varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
             # Check that the inferred varinfo is indeed suitable for evaluation and sampling
@@ -71,7 +72,7 @@
             JET.test_call(f_eval, argtypes_eval)
 
             f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
-                model, varinfo, DynamicPPL.SamplingContext()
+                sampling_model, varinfo
             )
             JET.test_call(f_sample, argtypes_sample)
             # For our demo models, they should all result in typed.
@@ -85,7 +86,7 @@
             )
             JET.test_call(f_eval, argtypes_eval)
             f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
-                model, typed_vi, DynamicPPL.SamplingContext()
+                sampling_model, typed_vi
             )
             JET.test_call(f_sample, argtypes_sample)
         end

test/varinfo.jl

Lines changed: 1 addition & 1 deletion
@@ -491,7 +491,7 @@ end
         # Check that instantiating the model does not perform linking
         vi = VarInfo()
         meta = vi.metadata
-        model(vi, SampleFromUniform())
+        model(vi)
         @test all(x -> !istrans(vi, x), meta.vns)
 
         # Check that linking and invlinking set the `trans` flag accordingly
