Skip to content

Commit 7527e2f

Browse files
committed
Excise SamplingContext tests
1 parent 0a670fd commit 7527e2f

File tree

8 files changed

+25
-70
lines changed

8 files changed

+25
-70
lines changed

docs/src/api.md

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -456,33 +456,24 @@ AbstractPPL.evaluate!!
456456

457457
This method mutates the `varinfo` used for execution.
458458
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
459-
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
460-
461-
```@docs
462-
DynamicPPL.evaluate_and_sample!!
463-
```
464459

465460
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
466461
Contexts are subtypes of `AbstractPPL.AbstractContext`.
467462

468463
```@docs
469-
SamplingContext
470464
DefaultContext
471465
PrefixContext
472466
ConditionContext
467+
InitContext
473468
```
474469

475-
### Samplers
470+
### VarInfo initialisation
476471

477-
In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
478-
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
472+
TODO
479473

480-
```@docs
481-
SampleFromPrior
482-
SampleFromUniform
483-
```
474+
### Samplers
484475

485-
Additionally, a generic sampler for inference is implemented.
476+
In DynamicPPL a generic sampler for inference is implemented.
486477

487478
```@docs
488479
Sampler
@@ -493,7 +484,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
493484
```@docs
494485
DynamicPPL.initialstep
495486
DynamicPPL.loadstate
496-
DynamicPPL.initialsampler
487+
DynamicPPL.init_strategy
497488
```
498489

499490
Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ else
88
using ..EnzymeCore
99
end
1010

11-
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true
12-
1311
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
1412
# only checks whether such a method exists, and never runs it.
1513
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =

ext/DynamicPPLJETExt.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,12 @@ end
2121
function DynamicPPL.Experimental._determine_varinfo_jet(
2222
model::DynamicPPL.Model; only_ddpl::Bool=true
2323
)
24-
# Use SamplingContext to test type stability.
25-
sampling_model = DynamicPPL.contextualize(
26-
model, DynamicPPL.SamplingContext(model.context)
27-
)
28-
2924
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(sampling_model)
25+
varinfo = DynamicPPL.typed_varinfo(model)
3126

3227
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3328
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
sampling_model, varinfo; only_ddpl
29+
model, varinfo; only_ddpl
3530
)
3631

3732
if !issuccess
@@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4641
else
4742
# Warn the user that we can't use the type stable one.
4843
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(sampling_model)
44+
DynamicPPL.untyped_varinfo(model)
5045
end
5146
end
5247

test/ad.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
112112
# Compiling the ReverseDiff tape used to fail here
113113
spl = Sampler(MyEmptyAlg())
114114
vi = VarInfo(model)
115-
sampling_model = contextualize(model, SamplingContext(model.context))
116-
ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true))
115+
ldf = LogDensityFunction(model, vi; adtype=AutoReverseDiff(; compile=true))
117116
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
118117
end
119118

test/compiler.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,8 @@ module Issue537 end
193193
varinfo = VarInfo(model)
194194
@test getlogjoint(varinfo) == lp
195195
@test varinfo_ isa AbstractVarInfo
196-
# During the model evaluation, its context is wrapped in a
197-
# SamplingContext, so `model_` is not going to be equal to `model`.
198-
# We can still check equality of `f` though.
199196
@test model_.f === model.f
200-
@test model_.context isa SamplingContext
197+
@test model_.context isa InitContext
201198
@test model_.context.rng isa Random.AbstractRNG
202199

203200
# disable warnings

test/contexts.jl

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
4949
contexts = Dict(
5050
:default => DefaultContext(),
5151
:testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()),
52-
:sampling => SamplingContext(),
5352
:prefix => PrefixContext(@varname(x)),
5453
:condition1 => ConditionContext((x=1.0,)),
5554
:condition2 => ConditionContext(
@@ -150,11 +149,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
150149
vn = @varname(x[1])
151150
ctx1 = PrefixContext(@varname(a))
152151
@test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1])
153-
ctx2 = SamplingContext(ctx1)
152+
ctx2 = ConditionContext(Dict(), ctx1)
154153
@test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1])
155154
ctx3 = PrefixContext(@varname(b), ctx2)
156155
@test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1])
157-
ctx4 = DynamicPPL.SamplingContext(ctx3)
156+
ctx4 = FixedContext(Dict(), ctx3)
158157
@test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1])
159158
end
160159

@@ -165,29 +164,28 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
165164
@test new_vn == @varname(a.x[1])
166165
@test new_ctx == DefaultContext()
167166

168-
ctx2 = SamplingContext(PrefixContext(@varname(a)))
167+
ctx2 = FixedContext((b=4,), PrefixContext(@varname(a)))
169168
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn)
170169
@test new_vn == @varname(a.x[1])
171-
@test new_ctx == SamplingContext()
170+
@test new_ctx == FixedContext((b=4,))
172171

173172
ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,)))
174173
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn)
175174
@test new_vn == @varname(a.x[1])
176175
@test new_ctx == ConditionContext((a=1,))
177176

178-
ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,))))
177+
ctx4 = FixedContext((b=4,)PrefixContext(@varname(a), ConditionContext((a=1,))))
179178
new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn)
180179
@test new_vn == @varname(a.x[1])
181-
@test new_ctx == SamplingContext(ConditionContext((a=1,)))
180+
@test new_ctx == FixedContext((b=4,)ConditionContext((a=1,)))
182181
end
183182

184183
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
185184
prefix_vn = @varname(my_prefix)
186-
context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext())
187-
sampling_model = contextualize(model, context)
188-
# Sample with the context.
189-
varinfo = DynamicPPL.VarInfo()
190-
DynamicPPL.evaluate!!(sampling_model, varinfo)
185+
context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext())
186+
new_model = contextualize(model, context)
187+
# Initialize a new varinfo with the prefixed model
188+
DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
191189
# Extract the resulting varnames
192190
vns_actual = Set(keys(varinfo))
193191

@@ -202,22 +200,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
202200
end
203201
end
204202

205-
@testset "SamplingContext" begin
206-
context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext())
207-
@test context isa SamplingContext
208-
209-
# convenience constructors
210-
@test SamplingContext() == context
211-
@test SamplingContext(Random.default_rng()) == context
212-
@test SamplingContext(SampleFromPrior()) == context
213-
@test SamplingContext(DefaultContext()) == context
214-
@test SamplingContext(Random.default_rng(), SampleFromPrior()) == context
215-
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
216-
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
217-
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
218-
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
219-
end
220-
221203
@testset "ConditionContext" begin
222204
@testset "Nesting" begin
223205
@testset "NamedTuple" begin

test/ext/DynamicPPLJETExt.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,16 @@
6262

6363
@testset "demo models" begin
6464
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
65-
sampling_model = contextualize(model, SamplingContext(model.context))
6665
# Use debug logging below.
6766
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
68-
# Check that the inferred varinfo is indeed suitable for evaluation and sampling
67+
# Check that the inferred varinfo is indeed suitable for evaluation and initialisation
6968
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
7069
model, varinfo
7170
)
7271
JET.test_call(f_eval, argtypes_eval)
7372

7473
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
75-
sampling_model, varinfo
74+
init_model, varinfo
7675
)
7776
JET.test_call(f_sample, argtypes_sample)
7877
# For our demo models, they should all result in typed.
@@ -85,10 +84,6 @@
8584
model, typed_vi
8685
)
8786
JET.test_call(f_eval, argtypes_eval)
88-
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
89-
sampling_model, typed_vi
90-
)
91-
JET.test_call(f_sample, argtypes_sample)
9287
end
9388
end
9489
end

test/threadsafe.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@
6868
@time model(vi)
6969

7070
# Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements.
71-
sampling_model = contextualize(model, SamplingContext(model.context))
72-
DynamicPPL.evaluate_threadsafe!!(sampling_model, vi)
71+
DynamicPPL.evaluate_threadsafe!!(model, vi)
7372
@test getlogjoint(vi) lp_w_threads
7473
# check that it's wrapped during the model evaluation
7574
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
@@ -104,8 +103,7 @@
104103
@test lp_w_threads lp_wo_threads
105104

106105
# Ensure that we use `VarInfo`.
107-
sampling_model = contextualize(model, SamplingContext(model.context))
108-
DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi)
106+
DynamicPPL.evaluate_threadunsafe!!(model, vi)
109107
@test getlogjoint(vi) lp_w_threads
110108
@test vi_ isa VarInfo
111109
@test vi isa VarInfo

0 commit comments

Comments
 (0)