Skip to content

Commit 2e8adf4

Browse files
Fix for rand + replace overloads of rand with rand_prior_true for testing models (#541)
* preserve context from model in `rand` * replace rand overloads in TestUtils with definitions of rand_prior_true so we can properly test rand * removed NamedTuple from signature of TestUtils.rand_prior_true * updated references to previous overloads of rand to now use rand_prior_true * test rand for DEMO_MODELS * formatting * fixed tests for rand for DEMO_MODELS * fixed linkning tests * added missing impl of rand_prior_true for demo_static_transformation * formatting * fixed rand_prior_true for demo_static_transformation * bump minor version as this will be breaking * bump patch version * fixed old usage of rand * Update test/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed another usage of rand --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 12e7c27 commit 2e8adf4

File tree

9 files changed

+67
-47
lines changed

9 files changed

+67
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.20"
3+
version = "0.23.21"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
10371037
evaluate!!(
10381038
model,
10391039
SimpleVarInfo{Float64}(OrderedDict()),
1040-
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
1040+
SamplingContext(rng, SampleFromPrior(), model.context),
10411041
),
10421042
)
10431043
return values_as(x, T)

src/test_utils.jl

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
179179
"""
180180
function posterior_mean end
181181

182+
"""
183+
rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model)
184+
185+
Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`.
186+
"""
187+
function rand_prior_true(model::DynamicPPL.Model)
188+
return rand_prior_true(Random.default_rng(), model)
189+
end
190+
182191
"""
183192
demo_dynamic_constraint()
184193
@@ -263,10 +272,8 @@ function logprior_true_with_logabsdet_jacobian(
263272
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
264273
end
265274

266-
function Random.rand(
267-
rng::Random.AbstractRNG,
268-
::Type{NamedTuple},
269-
model::Model{typeof(demo_one_variable_multiple_constraints)},
275+
function rand_prior_true(
276+
rng::Random.AbstractRNG, model::Model{typeof(demo_one_variable_multiple_constraints)}
270277
)
271278
x = Vector{Float64}(undef, 5)
272279
x[1] = rand(rng, Normal())
@@ -310,9 +317,7 @@ function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)
310317
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
311318
end
312319

313-
function Random.rand(
314-
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
315-
)
320+
function rand_prior_true(rng::Random.AbstractRNG, model::Model{typeof(demo_lkjchol)})
316321
x = rand(rng, LKJCholesky(model.args.d, 1.0))
317322
return (x=x,)
318323
end
@@ -724,12 +729,6 @@ const DemoModels = Union{
724729
Model{typeof(demo_assume_matrix_dot_observe_matrix)},
725730
}
726731

727-
# We require demo models to have explict impleentations of `rand` since we want
728-
# these to be considered as ground truth.
729-
function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels)
730-
return error("demo models requires explicit implementation of rand")
731-
end
732-
733732
const UnivariateAssumeDemoModels = Union{
734733
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
735734
}
@@ -743,9 +742,7 @@ function posterior_optima(::UnivariateAssumeDemoModels)
743742
# TODO: Figure out exact for `s`.
744743
return (s=0.907407, m=7 / 6)
745744
end
746-
function Random.rand(
747-
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels
748-
)
745+
function rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels)
749746
s = rand(rng, InverseGamma(2, 3))
750747
m = rand(rng, Normal(0, sqrt(s)))
751748

@@ -766,7 +763,7 @@ const MultivariateAssumeDemoModels = Union{
766763
}
767764
function posterior_mean(model::MultivariateAssumeDemoModels)
768765
# Get some containers to fill.
769-
vals = Random.rand(model)
766+
vals = rand_prior_true(model)
770767

771768
vals.s[1] = 19 / 8
772769
vals.m[1] = 3 / 4
@@ -778,7 +775,7 @@ function posterior_mean(model::MultivariateAssumeDemoModels)
778775
end
779776
function likelihood_optima(model::MultivariateAssumeDemoModels)
780777
# Get some containers to fill.
781-
vals = Random.rand(model)
778+
vals = rand_prior_true(model)
782779

783780
# NOTE: These are "as close to zero as we can get".
784781
vals.s[1] = 1e-32
@@ -791,7 +788,7 @@ function likelihood_optima(model::MultivariateAssumeDemoModels)
791788
end
792789
function posterior_optima(model::MultivariateAssumeDemoModels)
793790
# Get some containers to fill.
794-
vals = Random.rand(model)
791+
vals = rand_prior_true(model)
795792

796793
# TODO: Figure out exact for `s[1]`.
797794
vals.s[1] = 0.890625
@@ -801,9 +798,7 @@ function posterior_optima(model::MultivariateAssumeDemoModels)
801798

802799
return vals
803800
end
804-
function Random.rand(
805-
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels
806-
)
801+
function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels)
807802
# Get template values from `model`.
808803
retval = model(rng)
809804
vals = (s=retval.s, m=retval.m)
@@ -821,7 +816,7 @@ const MatrixvariateAssumeDemoModels = Union{
821816
}
822817
function posterior_mean(model::MatrixvariateAssumeDemoModels)
823818
# Get some containers to fill.
824-
vals = Random.rand(model)
819+
vals = rand_prior_true(model)
825820

826821
vals.s[1, 1] = 19 / 8
827822
vals.m[1] = 3 / 4
@@ -833,7 +828,7 @@ function posterior_mean(model::MatrixvariateAssumeDemoModels)
833828
end
834829
function likelihood_optima(model::MatrixvariateAssumeDemoModels)
835830
# Get some containers to fill.
836-
vals = Random.rand(model)
831+
vals = rand_prior_true(model)
837832

838833
# NOTE: These are "as close to zero as we can get".
839834
vals.s[1, 1] = 1e-32
@@ -846,7 +841,7 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels)
846841
end
847842
function posterior_optima(model::MatrixvariateAssumeDemoModels)
848843
# Get some containers to fill.
849-
vals = Random.rand(model)
844+
vals = rand_prior_true(model)
850845

851846
# TODO: Figure out exact for `s[1]`.
852847
vals.s[1, 1] = 0.890625
@@ -856,9 +851,7 @@ function posterior_optima(model::MatrixvariateAssumeDemoModels)
856851

857852
return vals
858853
end
859-
function Base.rand(
860-
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MatrixvariateAssumeDemoModels
861-
)
854+
function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemoModels)
862855
# Get template values from `model`.
863856
retval = model(rng)
864857
vals = (s=retval.s, m=retval.m)
@@ -954,6 +947,14 @@ function logprior_true_with_logabsdet_jacobian(
954947
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
955948
end
956949

950+
function rand_prior_true(
951+
rng::Random.AbstractRNG, model::Model{typeof(demo_static_transformation)}
952+
)
953+
s = rand(rng, InverseGamma(2, 3))
954+
m = rand(rng, Normal(0, sqrt(s)))
955+
return (s=s, m=m)
956+
end
957+
957958
"""
958959
marginal_mean_of_samples(chain, varname)
959960

test/linking.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ end
7272
@model demo() = m ~ dist
7373
model = demo()
7474

75-
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
75+
example_values = rand(NamedTuple, model)
76+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
7677
@testset "$(short_varinfo_name(vi))" for vi in vis
7778
# Evaluate once to ensure we have `logp` value.
7879
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
@@ -105,7 +106,7 @@ end
105106
@testset "d=$d" for d in [2, 3, 5]
106107
model = demo_lkj(d)
107108
dist = LKJCholesky(d, 1.0, uplo)
108-
values_original = rand(model)
109+
values_original = rand(NamedTuple, model)
109110
vis = DynamicPPL.TestUtils.setup_varinfos(
110111
model, values_original, (@varname(x),)
111112
)
@@ -146,7 +147,8 @@ end
146147
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
147148
@testset "d=$d" for d in [2, 3, 5]
148149
model = demo_dirichlet(d)
149-
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
150+
example_values = rand(NamedTuple, model)
151+
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
150152
@testset "$(short_varinfo_name(vi))" for vi in vis
151153
lp = logpdf(Dirichlet(d, 1.0), vi[:])
152154
@test length(vi[:]) == d

test/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Test, DynamicPPL, LogDensityProblems
22

33
@testset "LogDensityFunction" begin
44
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
5-
example_values = rand(NamedTuple, model)
5+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
66
vns = DynamicPPL.TestUtils.varnames(model)
77
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
88

test/loglikelihoods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "loglikelihoods.jl" begin
22
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
3-
example_values = rand(NamedTuple, m)
3+
example_values = DynamicPPL.TestUtils.rand_prior_true(m)
44

55
# Instantiate a `VarInfo` with the example values.
66
vi = VarInfo(m)

test/model.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
222222
Random.seed!(1776)
223223
s, m = model()
224224
sample_namedtuple = (; s=s, m=m)
225-
sample_dict = Dict(@varname(s) => s, @varname(m) => m)
225+
sample_dict = OrderedDict(@varname(s) => s, @varname(m) => m)
226226

227227
# With explicit RNG
228228
@test rand(Random.seed!(1776), model) == sample_namedtuple
@@ -235,7 +235,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
235235
Random.seed!(1776)
236236
@test rand(NamedTuple, model) == sample_namedtuple
237237
Random.seed!(1776)
238-
@test rand(Dict, model) == sample_dict
238+
@test rand(OrderedDict, model) == sample_dict
239239
end
240240

241241
@testset "default arguments" begin
@@ -263,7 +263,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
263263

264264
@testset "TestUtils" begin
265265
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
266-
x = rand(model)
266+
x = DynamicPPL.TestUtils.rand_prior_true(model)
267+
# `rand_prior_true` should return a `NamedTuple`.
268+
@test x isa NamedTuple
269+
270+
# `rand` with a `AbstractDict` should have `varnames` as keys.
271+
x_rand_dict = rand(OrderedDict, model)
272+
for vn in DynamicPPL.TestUtils.varnames(model)
273+
@test haskey(x_rand_dict, vn)
274+
end
275+
# `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys.
276+
x_rand_nt = rand(NamedTuple, model)
277+
for vn in DynamicPPL.TestUtils.varnames(model)
278+
@test haskey(x_rand_nt, Symbol(vn))
279+
end
280+
267281
# Ensure log-probability computations are implemented.
268282
@test logprior(model, x) DynamicPPL.TestUtils.logprior_true(model, x...)
269283
@test loglikelihood(model, x)

test/simple_varinfo.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
@testset "link!! & invlink!! on $(nameof(model))" for model in
6262
DynamicPPL.TestUtils.DEMO_MODELS
63-
values_constrained = rand(NamedTuple, model)
63+
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
6464
@testset "$(typeof(vi))" for vi in (
6565
SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model)
6666
)
@@ -112,7 +112,7 @@
112112

113113
# We might need to pre-allocate for the variable `m`, so we need
114114
# to see whether this is the case.
115-
svi_nt = SimpleVarInfo(rand(NamedTuple, model))
115+
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
116116
svi_dict = SimpleVarInfo(VarInfo(model), Dict)
117117

118118
@testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in (
@@ -121,7 +121,7 @@
121121
DynamicPPL.settrans!!(svi_nt, true),
122122
DynamicPPL.settrans!!(svi_dict, true),
123123
)
124-
# Random seed is set in each `@testset`, so we need to sample
124+
# RandOM seed is set in each `@testset`, so we need to sample
125125
# a new realization for `m` here.
126126
retval = model()
127127

@@ -138,7 +138,7 @@
138138
@test getlogp(svi_new) != 0
139139

140140
### Evaluation ###
141-
values_eval_constrained = rand(NamedTuple, model)
141+
values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
142142
if DynamicPPL.istrans(svi)
143143
_values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian(
144144
model, values_eval_constrained...
@@ -225,7 +225,7 @@
225225
model = DynamicPPL.TestUtils.demo_static_transformation()
226226

227227
varinfos = DynamicPPL.TestUtils.setup_varinfos(
228-
model, rand(NamedTuple, model), [@varname(s), @varname(m)]
228+
model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)]
229229
)
230230
@testset "$(short_varinfo_name(vi))" for vi in varinfos
231231
# Initialize varinfo and link.

test/varinfo.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
338338

339339
@testset "values_as" begin
340340
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
341-
example_values = rand(NamedTuple, model)
341+
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
342342
vns = DynamicPPL.TestUtils.varnames(model)
343343

344344
# Set up the different instances of `AbstractVarInfo` with the desired values.
@@ -385,7 +385,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
385385
DynamicPPL.TestUtils.demo_lkjchol(),
386386
]
387387
@testset "mutating=$mutating" for mutating in [false, true]
388-
value_true = rand(model)
388+
value_true = DynamicPPL.TestUtils.rand_prior_true(model)
389389
varnames = DynamicPPL.TestUtils.varnames(model)
390390
varinfos = DynamicPPL.TestUtils.setup_varinfos(
391391
model, value_true, varnames; include_threadsafe=true
@@ -541,7 +541,10 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
541541
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
542542
vns = DynamicPPL.TestUtils.varnames(model)
543543
varinfos = DynamicPPL.TestUtils.setup_varinfos(
544-
model, rand(model), vns; include_threadsafe=true
544+
model,
545+
DynamicPPL.TestUtils.rand_prior_true(model),
546+
vns;
547+
include_threadsafe=true,
545548
)
546549
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
547550
@testset "with itself" begin
@@ -581,7 +584,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
581584
end
582585

583586
@testset "with different value" begin
584-
x = DynamicPPL.TestUtils.rand(model)
587+
x = DynamicPPL.TestUtils.rand_prior_true(model)
585588
varinfo_changed = DynamicPPL.TestUtils.update_values!!(
586589
deepcopy(varinfo), x, vns
587590
)

0 commit comments

Comments
 (0)