Skip to content

Commit 9d8d048

Browse files
committed
Move src/test_utils and test/test_util to DynamicPPLTestExt
1 parent ba490bf commit 9d8d048

21 files changed

+285
-309
lines changed

Project.toml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3333
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3536

3637
[extensions]
@@ -39,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3940
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4041
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4142
DynamicPPLReverseDiffExt = ["ReverseDiff"]
43+
DynamicPPLTestExt = ["Test"]
4244
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4345

4446
[compat]
@@ -67,11 +69,3 @@ ReverseDiff = "1"
6769
Test = "1.6"
6870
ZygoteRules = "0.2"
6971
julia = "1.10"
70-
71-
[extras]
72-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
73-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
74-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
75-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
76-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
77-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLTestExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module DynamicPPLTestExt
2+
3+
using DynamicPPL: DynamicPPL
4+
using Test: @test, @testset, @test_throws, @test_broken
5+
6+
include("DynamicPPLTestExt/utils.jl")
7+
8+
end

src/test_utils.jl renamed to ext/DynamicPPLTestExt/utils.jl

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
module TestUtils
1+
module TestExtUtils
2+
3+
###################################################
4+
# These used to be in DPPL/src/test_utils.jl ######
5+
###################################################
26

37
using AbstractMCMC
48
using DynamicPPL
@@ -1097,4 +1101,121 @@ function DynamicPPL.dot_tilde_observe(
10971101
return logp * context.mod, vi
10981102
end
10991103

1104+
###################################################
1105+
# These used to be in DPPL/test/test_util.jl ######
1106+
###################################################
1107+
1108+
# default model
1109+
@model function gdemo_d()
1110+
s ~ InverseGamma(2, 3)
1111+
m ~ Normal(0, sqrt(s))
1112+
1.5 ~ Normal(m, sqrt(s))
1113+
2.0 ~ Normal(m, sqrt(s))
1114+
return s, m
1115+
end
1116+
const gdemo_default = gdemo_d()
1117+
1118+
function test_model_ad(model, logp_manual)
1119+
vi = VarInfo(model)
1120+
x = DynamicPPL.getall(vi)
1121+
1122+
# Log probabilities using the model.
1123+
= DynamicPPL.LogDensityFunction(model, vi)
1124+
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)
1125+
1126+
# Check that both functions return the same values.
1127+
lp = logp_manual(x)
1128+
@test logp_model(x) lp
1129+
1130+
# Gradients based on the manual implementation.
1131+
grad = ForwardDiff.gradient(logp_manual, x)
1132+
1133+
y, back = Tracker.forward(logp_manual, x)
1134+
@test Tracker.data(y) lp
1135+
@test Tracker.data(back(1)[1]) grad
1136+
1137+
y, back = Zygote.pullback(logp_manual, x)
1138+
@test y lp
1139+
@test back(1)[1] grad
1140+
1141+
# Gradients based on the model.
1142+
@test ForwardDiff.gradient(logp_model, x) grad
1143+
1144+
y, back = Tracker.forward(logp_model, x)
1145+
@test Tracker.data(y) lp
1146+
@test Tracker.data(back(1)[1]) grad
1147+
1148+
y, back = Zygote.pullback(logp_model, x)
1149+
@test y lp
1150+
@test back(1)[1] grad
1151+
end
1152+
1153+
"""
1154+
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
1155+
1156+
Test `setval!` on `model` and `chain`.
1157+
1158+
Worth noting that this only supports models containing symbols of the forms
1159+
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
1160+
"""
1161+
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
1162+
var_info = VarInfo(model)
1163+
spl = SampleFromPrior()
1164+
θ_old = var_info[spl]
1165+
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
1166+
θ_new = var_info[spl]
1167+
@test θ_old != θ_new
1168+
vals = DynamicPPL.values_as(var_info, OrderedDict)
1169+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
1170+
for (n, v) in mapreduce(collect, vcat, iters)
1171+
n = string(n)
1172+
if Symbol(n) keys(chain)
1173+
# Assume it's a group
1174+
chain_val = vec(
1175+
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
1176+
)
1177+
v_true = vec(v)
1178+
else
1179+
chain_val = chain[sample_idx, n, chain_idx]
1180+
v_true = v
1181+
end
1182+
1183+
@test v_true == chain_val
1184+
end
11001185
end
1186+
1187+
"""
1188+
short_varinfo_name(vi::AbstractVarInfo)
1189+
1190+
Return string representing a short description of `vi`.
1191+
"""
1192+
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
1193+
"threadsafe($(short_varinfo_name(vi.varinfo)))"
1194+
function short_varinfo_name(vi::TypedVarInfo)
1195+
DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
1196+
return "TypedVarInfo"
1197+
end
1198+
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
1199+
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
1200+
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
1201+
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
1202+
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
1203+
return "SimpleVarInfo{<:VarNamedVector}"
1204+
end
1205+
1206+
# convenient functions for testing model.jl
1207+
# function to modify the representation of values based on their length
1208+
function modify_value_representation(nt::NamedTuple)
1209+
modified_nt = NamedTuple()
1210+
for (key, value) in zip(keys(nt), values(nt))
1211+
if length(value) == 1 # Scalar value
1212+
modified_value = value[1]
1213+
else # Non-scalar value
1214+
modified_value = value
1215+
end
1216+
modified_nt = merge(modified_nt, (key => modified_value,))
1217+
end
1218+
return modified_nt
1219+
end
1220+
1221+
end # module TestExtUtils

test/ad.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
@testset "AD: ForwardDiff and ReverseDiff" begin
2-
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
2+
@testset "$(m.f)" for m in TU.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
4-
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
5-
vns = DynamicPPL.TestUtils.varnames(m)
6-
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
4+
rand_param_values = TU.rand_prior_true(m)
5+
vns = TU.varnames(m)
6+
varinfos = TU.setup_varinfos(m, rand_param_values, vns)
77

8-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
8+
@testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos
99
f = DynamicPPL.LogDensityFunction(m, varinfo)
1010

1111
# use ForwardDiff result as reference

test/compat/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logpdf(dist, 2.0)
1313
end
1414

15-
test_model_ad(gdemo_default, logp_gdemo_default)
15+
TU.test_model_ad(TU.gdemo_default, logp_gdemo_default)
1616

1717
@model function wishart_ad()
1818
return v ~ Wishart(7, [1 0.5; 0.5 1])
@@ -24,7 +24,7 @@
2424
return logpdf(dist, reshape(x, 2, 2))
2525
end
2626

27-
test_model_ad(wishart_ad(), logp_wishart_ad)
27+
TU.test_model_ad(wishart_ad(), logp_wishart_ad)
2828
end
2929

3030
# https://github.com/TuringLang/Turing.jl/issues/1595

test/contexts.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ end
167167
vn_without_prefix = remove_prefix(vn)
168168

169169
# Let's check elementwise.
170-
for vn_child in
171-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
170+
for vn_child in TU.varname_leaves(vn_without_prefix, val)
172171
if getoptic(vn_child)(val) === missing
173172
@test contextual_isassumption(context, vn_child)
174173
else
@@ -200,8 +199,7 @@ end
200199
# `ConditionContext` with the conditioned variable.
201200
vn_without_prefix = remove_prefix(vn)
202201

203-
for vn_child in
204-
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
202+
for vn_child in TU.varname_leaves(vn_without_prefix, val)
205203
# `vn_child` should be in `context`.
206204
@test hasconditioned_nested(context, vn_child)
207205
# Value should be the same as extracted above.
@@ -216,7 +214,7 @@ end
216214
@testset "Evaluation" begin
217215
@testset "$context" for context in contexts
218216
# Just making sure that we can actually sample with each of the contexts.
219-
@test (gdemo_default(SamplingContext(context)); true)
217+
@test (TU.gdemo_default(SamplingContext(context)); true)
220218
end
221219
end
222220

@@ -258,7 +256,7 @@ end
258256
end
259257

260258
@testset "FixedContext" begin
261-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
259+
@testset "$(model.f)" for model in TU.DEMO_MODELS
262260
retval = model()
263261
s, m = retval.s, retval.m
264262

test/debug_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
@testset "check_model" begin
22
@testset "context interface" begin
33
# HACK: Require a model to instantiate it, so let's just grab one.
4-
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
4+
model = first(TU.DEMO_MODELS)
55
context = DynamicPPL.DebugUtils.DebugContext(model)
6-
DynamicPPL.TestUtils.test_context_interface(context)
6+
TU.test_context_interface(context)
77
end
88

9-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
9+
@testset "$(model.f)" for model in TU.DEMO_MODELS
1010
issuccess, trace = check_model_and_trace(model)
1111
# These models should all work.
1212
@test issuccess
1313

1414
# Check that the trace contains all the variables in the model.
1515
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
16-
for vn in DynamicPPL.TestUtils.varnames(model)
16+
for vn in TU.varnames(model)
1717
@test vn in varnames_in_trace
1818
end
1919

@@ -156,7 +156,7 @@
156156
end
157157

158158
@testset "comparing multiple traces" begin
159-
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
159+
model = TU.demo_dynamic_constraint()
160160
issuccess_1, trace_1 = check_model_and_trace(model)
161161
issuccess_2, trace_2 = check_model_and_trace(model)
162162
@test issuccess_1 && issuccess_2

test/linking.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ end
7575
model = demo()
7676

7777
example_values = rand(NamedTuple, model)
78-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
79-
@testset "$(short_varinfo_name(vi))" for vi in vis
78+
vis = TU.setup_varinfos(model, example_values, (@varname(m),))
79+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
8080
# Evaluate once to ensure we have `logp` value.
8181
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
8282
vi_linked = if mutable
@@ -109,10 +109,8 @@ end
109109
model = demo_lkj(d)
110110
dist = LKJCholesky(d, 1.0, uplo)
111111
values_original = rand(NamedTuple, model)
112-
vis = DynamicPPL.TestUtils.setup_varinfos(
113-
model, values_original, (@varname(x),)
114-
)
115-
@testset "$(short_varinfo_name(vi))" for vi in vis
112+
vis = TU.setup_varinfos(model, values_original, (@varname(x),))
113+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
116114
val = vi[@varname(x), dist]
117115
# Ensure that `reconstruct` works as intended.
118116
@test val isa Cholesky
@@ -150,8 +148,8 @@ end
150148
@testset "d=$d" for d in [2, 3, 5]
151149
model = demo_dirichlet(d)
152150
example_values = rand(NamedTuple, model)
153-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
154-
@testset "$(short_varinfo_name(vi))" for vi in vis
151+
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
152+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
155153
lp = logpdf(Dirichlet(d, 1.0), vi[:])
156154
@test length(vi[:]) == d
157155
lp_model = logjoint(model, vi)
@@ -189,8 +187,8 @@ end
189187
]
190188
model = demo_highdim_dirichlet(ns...)
191189
example_values = rand(NamedTuple, model)
192-
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
193-
@testset "$(short_varinfo_name(vi))" for vi in vis
190+
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
191+
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
194192
# Linked.
195193
vi_linked = if mutable
196194
DynamicPPL.link!!(deepcopy(vi), model)

test/logdensityfunction.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
22

33
@testset "`getmodel` and `setmodel`" begin
4-
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
5-
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
4+
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
5+
model = TU.DEMO_MODELS[1]
66
= DynamicPPL.LogDensityFunction(model)
77
@test DynamicPPL.getmodel(ℓ) == model
88
@test DynamicPPL.setmodel(ℓ, model).model == model
@@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever
2121
end
2222

2323
@testset "LogDensityFunction" begin
24-
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
25-
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
26-
vns = DynamicPPL.TestUtils.varnames(model)
27-
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
24+
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
25+
example_values = TU.rand_prior_true(model)
26+
vns = TU.varnames(model)
27+
varinfos = TU.setup_varinfos(model, example_values, vns)
2828

2929
@testset "$(varinfo)" for varinfo in varinfos
3030
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)

0 commit comments

Comments
 (0)