Skip to content

Commit c0fd677

Browse files
committed
Split TestUtils module up
1 parent ba490bf commit c0fd677

File tree

7 files changed

+407
-381
lines changed

7 files changed

+407
-381
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ include("context_implementations.jl")
186186
include("compiler.jl")
187187
include("pointwise_logdensities.jl")
188188
include("submodel_macro.jl")
189-
include("test_utils.jl")
189+
include("test_utils/main.jl")
190190
include("transforming.jl")
191191
include("logdensityfunction.jl")
192192
include("model_utils.jl")

src/test_utils/contexts.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# contexts.jl
2+
# -----------
3+
#
4+
# Utilities for testing contexts.
5+
6+
"""
7+
test_context_interface(context)
8+
9+
Test that `context` implements the `AbstractContext` interface.
10+
"""
11+
function test_context_interface(context)
12+
# Is a subtype of `AbstractContext`.
13+
@test context isa DynamicPPL.AbstractContext
14+
# Should implement `NodeTrait.`
15+
@test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf}
16+
# If it's a parent.
17+
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent
18+
# Should implement `childcontext` and `setchildcontext`
19+
@test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) ==
20+
context
21+
end
22+
end
23+
24+
"""
25+
Context that multiplies each log-prior by mod
26+
used to test whether varwise_logpriors respects child-context.
27+
"""
28+
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
29+
mod::T
30+
context::Ctx
31+
end
32+
function TestLogModifyingChildContext(
33+
mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
34+
)
35+
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
36+
end
37+
38+
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
39+
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
40+
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
41+
return TestLogModifyingChildContext(context.mod, child)
42+
end
43+
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
44+
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
45+
return value, logp * context.mod, vi
46+
end
47+
function DynamicPPL.dot_tilde_assume(
48+
context::TestLogModifyingChildContext, right, left, vn, vi
49+
)
50+
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
51+
return value, logp * context.mod, vi
52+
end
53+
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
54+
logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
55+
return logp * context.mod, vi
56+
end
57+
function DynamicPPL.dot_tilde_observe(
58+
context::TestLogModifyingChildContext, right, left, vi
59+
)
60+
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
61+
return logp * context.mod, vi
62+
end

src/test_utils/main.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module TestUtils
2+
3+
using AbstractMCMC
4+
using DynamicPPL
5+
using LinearAlgebra
6+
using Distributions
7+
using Test
8+
9+
using Random: Random
10+
using Bijectors: Bijectors
11+
using Accessors: Accessors
12+
13+
# For backwards compat.
14+
using DynamicPPL: varname_leaves, update_values!!
15+
16+
include("model_interface.jl")
17+
include("models.jl")
18+
include("contexts.jl")
19+
include("varinfo.jl")
20+
include("sampler.jl")
21+
22+
end

src/test_utils/model_interface.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# model_interface.jl
2+
# ------------------
3+
#
4+
# This file contains the functions that the models inside models.jl should
5+
# implement.
6+
7+
"""
8+
logprior_true(model, args...)
9+
10+
Return the `logprior` of `model` for `args`.
11+
12+
This should generally be implemented by hand for every specific `model`.
13+
14+
See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref).
15+
"""
16+
function logprior_true end
17+
18+
"""
19+
loglikelihood_true(model, args...)
20+
21+
Return the `loglikelihood` of `model` for `args`.
22+
23+
This should generally be implemented by hand for every specific `model`.
24+
25+
See also: [`logjoint_true`](@ref), [`logprior_true`](@ref).
26+
"""
27+
function loglikelihood_true end
28+
29+
"""
30+
logjoint_true(model, args...)
31+
32+
Return the `logjoint` of `model` for `args`.
33+
34+
Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`.
35+
36+
This should generally be implemented by hand for every specific `model`
37+
so that the returned value can be used as a ground-truth for testing things like:
38+
39+
1. Validity of evaluation of `model` using a particular implementation of `AbstractVarInfo`.
40+
2. Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions, e.g. `logjoint_true`, and once targeting `model`.
41+
42+
And more.
43+
44+
See also: [`logprior_true`](@ref), [`loglikelihood_true`](@ref).
45+
"""
46+
function logjoint_true(model::Model, args...)
47+
return logprior_true(model, args...) + loglikelihood_true(model, args...)
48+
end
49+
50+
"""
51+
logjoint_true_with_logabsdet_jacobian(model::Model, args...)
52+
53+
Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`.
54+
55+
Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the
56+
log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables.
57+
58+
Note that `args` are assumed be in the support of `model`, while `args_unconstrained`
59+
will be unconstrained.
60+
61+
This should generally not be implemented directly, instead one should implement
62+
[`logprior_true_with_logabsdet_jacobian`](@ref) for a given `model`.
63+
64+
See also: [`logjoint_true`](@ref), [`logprior_true_with_logabsdet_jacobian`](@ref).
65+
"""
66+
function logjoint_true_with_logabsdet_jacobian(model::Model, args...)
67+
args_unconstrained, lp = logprior_true_with_logabsdet_jacobian(model, args...)
68+
return args_unconstrained, lp + loglikelihood_true(model, args...)
69+
end
70+
71+
"""
72+
logprior_true_with_logabsdet_jacobian(model::Model, args...)
73+
74+
Return a tuple `(args_unconstrained, logprior_unconstrained)` of `model` for `args...`.
75+
76+
Unlike [`logprior_true`](@ref), the returned logprior computation includes the
77+
log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables.
78+
79+
Note that `args` are assumed be in the support of `model`, while `args_unconstrained`
80+
will be unconstrained.
81+
82+
See also: [`logprior_true`](@ref).
83+
"""
84+
function logprior_true_with_logabsdet_jacobian end
85+
86+
"""
87+
varnames(model::Model)
88+
89+
Return a collection of `VarName` as they are expected to appear in the model.
90+
91+
Even though it is recommended to implement this by hand for a particular `Model`,
92+
a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
93+
"""
94+
function varnames(model::Model)
95+
return collect(
96+
keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext())))
97+
)
98+
end
99+
100+
"""
101+
posterior_mean(model::Model)
102+
103+
Return a `NamedTuple` compatible with `varnames(model)` where the values represent
104+
the posterior mean under `model`.
105+
106+
"Compatible" means that a `varname` from `varnames(model)` can be used to extract the
107+
corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
108+
"""
109+
function posterior_mean end
110+
111+
"""
112+
rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model)
113+
114+
Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`.
115+
"""
116+
function rand_prior_true(model::DynamicPPL.Model)
117+
return rand_prior_true(Random.default_rng(), model)
118+
end

0 commit comments

Comments
 (0)