Skip to content

Commit 9b38be4

Browse files
committed
Move most of test_utils into TestExt
1 parent 5bc980a commit 9b38be4

File tree

6 files changed

+57
-46
lines changed

6 files changed

+57
-46
lines changed

Project.toml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2222
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2323
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2424
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
25-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2625
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2726

2827
[weakdeps]
@@ -31,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3130
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3231
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3332
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
33+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3434
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3535

3636
[extensions]
@@ -39,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3939
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4040
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4141
DynamicPPLReverseDiffExt = ["ReverseDiff"]
42+
DynamicPPLTestExt = ["Test"]
4243
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4344

4445
[compat]
@@ -67,11 +68,3 @@ ReverseDiff = "1"
6768
Test = "1.6"
6869
ZygoteRules = "0.2"
6970
julia = "1.10"
70-
71-
[extras]
72-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
73-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
74-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
76-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
77-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLTestExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module DynamicPPLTestExt
2+
3+
using DynamicPPL
4+
using AbstractMCMC
5+
using Test
6+
7+
include("DynamicPPLTestExt/contexts.jl")
8+
include("DynamicPPLTestExt/varinfo.jl")
9+
include("DynamicPPLTestExt/sampler.jl")
10+
11+
end

src/test_utils/contexts.jl renamed to ext/DynamicPPLTestExt/contexts.jl

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
Test that `context` implements the `AbstractContext` interface.
1010
"""
11-
function test_context_interface(context)
11+
function DynamicPPL.TestUtils.test_context_interface(context)
1212
# Is a subtype of `AbstractContext`.
1313
@test context isa DynamicPPL.AbstractContext
1414
# Should implement `NodeTrait.`
@@ -21,41 +21,33 @@ function test_context_interface(context)
2121
end
2222
end
2323

24-
"""
25-
Context that multiplies each log-prior by mod
26-
used to test whether varwise_logpriors respects child-context.
27-
"""
28-
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
29-
mod::T
30-
context::Ctx
31-
end
32-
function TestLogModifyingChildContext(
24+
function DynamicPPL.TestUtils.TestLogModifyingChildContext(
3325
mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
3426
)
35-
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
27+
return DynamicPPL.TestUtils.TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
3628
end
3729

38-
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
39-
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
40-
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
41-
return TestLogModifyingChildContext(context.mod, child)
30+
DynamicPPL.NodeTrait(::DynamicPPL.TestUtils.TestLogModifyingChildContext) = DynamicPPL.IsParent()
31+
DynamicPPL.childcontext(context::DynamicPPL.TestUtils.TestLogModifyingChildContext) = context.context
32+
function DynamicPPL.setchildcontext(context::DynamicPPL.TestUtils.TestLogModifyingChildContext, child)
33+
return DynamicPPL.TestUtils.TestLogModifyingChildContext(context.mod, child)
4234
end
43-
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
35+
function DynamicPPL.tilde_assume(context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, vn, vi)
4436
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
4537
return value, logp * context.mod, vi
4638
end
4739
function DynamicPPL.dot_tilde_assume(
48-
context::TestLogModifyingChildContext, right, left, vn, vi
40+
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vn, vi
4941
)
5042
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
5143
return value, logp * context.mod, vi
5244
end
53-
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
45+
function DynamicPPL.tilde_observe(context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi)
5446
logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
5547
return logp * context.mod, vi
5648
end
5749
function DynamicPPL.dot_tilde_observe(
58-
context::TestLogModifyingChildContext, right, left, vi
50+
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi
5951
)
6052
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
6153
return logp * context.mod, vi

src/test_utils/sampler.jl renamed to ext/DynamicPPLTestExt/sampler.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
Return the mean of variable represented by `varname` in `chain`.
1010
"""
11-
marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)]))
11+
DynamicPPL.TestUtils.marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)]))
1212

1313
"""
1414
test_sampler(models, sampler, args...; kwargs...)
@@ -35,7 +35,7 @@ To change how comparison is done for a particular `chain` type, one can overload
3535
- `rtol=1e-3`: Relative tolerance used in `@test`.
3636
- `kwargs...`: Keyword arguments forwarded to `sample`.
3737
"""
38-
function test_sampler(
38+
function DynamicPPL.TestUtils.test_sampler(
3939
models,
4040
sampler::AbstractMCMC.AbstractSampler,
4141
args...;
@@ -51,7 +51,7 @@ function test_sampler(
5151
for vn in filter(varnames_filter, varnames(model))
5252
# We want to compare elementwise which can be achieved by
5353
# extracting the leaves of the `VarName` and the corresponding value.
54-
for vn_leaf in varname_leaves(vn, get(target_values, vn))
54+
for vn_leaf in DynamicPPL.varname_leaves(vn, get(target_values, vn))
5555
target_value = get(target_values, vn_leaf)
5656
chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
5757
@test chain_mean_value target_value atol = atol rtol = rtol
@@ -67,10 +67,10 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref).
6767
6868
This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`.
6969
"""
70-
function test_sampler_on_demo_models(
70+
function DynamicPPL.TestUtils.test_sampler_on_demo_models(
7171
sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
7272
)
73-
return test_sampler(DEMO_MODELS, sampler, args...; kwargs...)
73+
return test_sampler(DynamicPPL.TestUtils.DEMO_MODELS, sampler, args...; kwargs...)
7474
end
7575

7676
"""
@@ -80,6 +80,6 @@ Test that `sampler` produces the correct marginal posterior means on all models
8080
8181
As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref).
8282
"""
83-
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
83+
function DynamicPPL.TestUtils.test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
8484
return test_sampler_on_demo_models(sampler, args...; kwargs...)
8585
end

src/test_utils/varinfo.jl renamed to ext/DynamicPPLTestExt/varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
1010
"""
11-
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
11+
function DynamicPPL.TestUtils.test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
1212
for vn in vns
1313
@test compare(vi[vn], get(vals, vn); kwargs...)
1414
end
@@ -23,7 +23,7 @@ each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` i
2323
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
2424
of the varinfo instances.
2525
"""
26-
function setup_varinfos(
26+
function DynamicPPL.TestUtils.setup_varinfos(
2727
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
2828
)
2929
# VarInfo
@@ -58,7 +58,7 @@ function setup_varinfos(
5858
svi_vnv_ref,
5959
)) do vi
6060
# Set them all to the same values.
61-
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
61+
DynamicPPL.setlogp!!(DynamicPPL.update_values!!(vi, example_values, varnames), lp)
6262
end
6363

6464
if include_threadsafe

src/test_utils.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,37 @@
11
module TestUtils
22

3-
using AbstractMCMC
43
using DynamicPPL
54
using LinearAlgebra
65
using Distributions
7-
using Test
86

97
using Random: Random
108
using Bijectors: Bijectors
11-
using Accessors: Accessors
12-
13-
# For backwards compat.
14-
using DynamicPPL: varname_leaves, update_values!!
159

1610
include("test_utils/model_interface.jl")
1711
include("test_utils/models.jl")
18-
include("test_utils/contexts.jl")
19-
include("test_utils/varinfo.jl")
20-
include("test_utils/sampler.jl")
12+
13+
14+
##############################################################
15+
# The remainder of this file contains skeleton implementations for
16+
# DynamicPPLTestExt
17+
##############################################################
18+
19+
function test_context_interface end
20+
21+
"""
22+
Context that multiplies each log-prior by mod
23+
used to test whether varwise_logpriors respects child-context.
24+
"""
25+
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
26+
mod::T
27+
context::Ctx
28+
end
29+
30+
function marginal_mean_of_samples end
31+
function test_sampler end
32+
function test_sampler_on_demo_models end
33+
function test_sampler_continuous end
34+
function test_values end
35+
function setup_varinfos end
2136

2237
end

0 commit comments

Comments
 (0)