Skip to content

Commit bc74e8c

Browse files
committed
Don't include lhs of := in results of predict()
1 parent 6657441 commit bc74e8c

File tree

4 files changed

+91
-61
lines changed

4 files changed

+91
-61
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function DynamicPPL.predict(
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118118
model(rng, varinfo, DynamicPPL.SampleFromPrior())
119119

120-
vals = DynamicPPL.values_as_in_model(model, varinfo)
120+
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(
122122
collect,
123123
vcat,

src/values_as_in_model.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,26 @@ $(TYPEDFIELDS)
2222
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
2323
"values that are extracted from the model"
2424
values::T
25+
"whether to extract variables on the LHS of :="
26+
include_colon_eq::Bool
2527
"child context"
2628
context::C
2729
end
2830

29-
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
30-
function ValuesAsInModelContext(context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), context)
31+
# If child context is not passed
32+
ValuesAsInModelContext(values, include_colon_eq) = ValuesAsInModelContext(values, include_colon_eq, DefaultContext())
33+
# If values are not passed
34+
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
35+
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
3236
end
3337

3438
NodeTrait(::ValuesAsInModelContext) = IsParent()
3539
childcontext(context::ValuesAsInModelContext) = context.context
3640
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, child)
41+
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
3842
end
3943

40-
is_extracting_values(context::ValuesAsInModelContext) = true
44+
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
4145
function is_extracting_values(context::AbstractContext)
4246
return is_extracting_values(NodeTrait(context), context)
4347
end
@@ -114,8 +118,8 @@ function dot_tilde_assume(
114118
end
115119

116120
"""
117-
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
118-
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
121+
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
122+
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
119123
120124
Get the values of `varinfo` as they would be seen in the model.
121125
@@ -132,6 +136,7 @@ of additional model evaluations.
132136
133137
# Arguments
134138
- `model::Model`: model to extract realizations from.
139+
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
135140
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
136141
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
137142
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
@@ -183,24 +188,26 @@ false
183188
julia> # Approach 2: Extract realizations using `values_as_in_model`.
184189
# (✓) `values_as_in_model` will re-run the model and extract
185190
# the correct realization of `y` given the new values of `x`.
186-
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
191+
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
187192
true
188193
```
189194
"""
190195
function values_as_in_model(
191196
model::Model,
197+
include_colon_eq::Bool,
192198
varinfo::AbstractVarInfo=VarInfo(),
193199
context::AbstractContext=DefaultContext(),
194200
)
195-
context = ValuesAsInModelContext(context)
201+
context = ValuesAsInModelContext(include_colon_eq, context)
196202
evaluate!!(model, varinfo, context)
197203
return context.values
198204
end
199205
function values_as_in_model(
200206
rng::Random.AbstractRNG,
201207
model::Model,
208+
include_colon_eq::Bool,
202209
varinfo::AbstractVarInfo=VarInfo(),
203210
context::AbstractContext=DefaultContext(),
204211
)
205-
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
212+
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
206213
end

test/compiler.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,17 @@ module Issue537 end
702702
@test haskey(varinfo, @varname(x))
703703
@test !haskey(varinfo, @varname(y))
704704

705-
# While `values_as_in_model` should contain both `x` and `y`.
706-
values = values_as_in_model(model, deepcopy(varinfo))
705+
# While `values_as_in_model` should contain both `x` and `y`, if
706+
# include_colon_eq is set to `true`.
707+
values = values_as_in_model(model, true, deepcopy(varinfo))
707708
@test haskey(values, @varname(x))
708709
@test haskey(values, @varname(y))
710+
711+
# And if include_colon_eq is set to `false`, then `values` should
712+
# only contain `x`.
713+
values = values_as_in_model(model, false, deepcopy(varinfo))
714+
@test haskey(values, @varname(x))
715+
@test !haskey(values, @varname(y))
709716
end
710717
end
711718

test/model.jl

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
9090
samples = (; samples_dict...)
9191
samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
9292
@test logpriors[i]
93-
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
93+
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
9494
@test loglikelihoods[i] DynamicPPL.TestUtils.loglikelihood_true(
9595
model, samples[:s], samples[:m]
9696
)
9797
@test logjoints[i]
98-
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
98+
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
9999
end
100100
end
101101
end
@@ -283,10 +283,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
283283
# Ensure log-probability computations are implemented.
284284
@test logprior(model, x) DynamicPPL.TestUtils.logprior_true(model, x...)
285285
@test loglikelihood(model, x)
286-
DynamicPPL.TestUtils.loglikelihood_true(model, x...)
286+
DynamicPPL.TestUtils.loglikelihood_true(model, x...)
287287
@test logjoint(model, x) DynamicPPL.TestUtils.logjoint_true(model, x...)
288288
@test logjoint(model, x) !=
289-
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
289+
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
290290
# Ensure `varnames` is implemented.
291291
vi = last(
292292
DynamicPPL.evaluate!!(
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383383
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
384384
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
385385
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
386-
realizations = values_as_in_model(model, varinfo)
386+
# We can set the include_colon_eq arg to false because none of
387+
# the demo models contain :=. The behaviour when
388+
# include_colon_eq is true is tested in test/compiler.jl
389+
realizations = values_as_in_model(model, false, varinfo)
387390
# Ensure that all variables are found.
388391
vns_found = collect(keys(realizations))
389392
@test vns vns_found == vns vns_found
@@ -432,72 +435,85 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432435

433436
@testset "predict" begin
434437
@testset "with MCMCChains.Chains" begin
435-
DynamicPPL.Random.seed!(100)
436-
437438
@model function linear_reg(x, y, σ=0.1)
438439
β ~ Normal(0, 1)
439440
for i in eachindex(y)
440441
y[i] ~ Normal* x[i], σ)
441442
end
443+
# Insert a := block to test that it is not included in predictions
444+
σ2 := σ^2
442445
end
443446

444-
@model function linear_reg_vec(x, y, σ=0.1)
445-
β ~ Normal(0, 1)
446-
return y ~ MvNormal.* x, σ^2 * I)
447-
end
448-
447+
# Construct a chain with 'sampled values' of β
449448
ground_truth_β = 2
450449
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
451450

451+
# Generate predictions from that chain
452452
xs_test = [10 + 0.1, 10 + 2 * 0.1]
453453
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
454454
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
455455

456-
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
457-
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
458-
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
459-
460-
# Ensure that `rng` is respected
461-
rng = MersenneTwister(42)
462-
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
463-
predictions2 = DynamicPPL.predict(
464-
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
465-
)
466-
@test all(Array(predictions1) .== Array(predictions2))
467-
468-
# Predict on two last indices for vectorized
469-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
470-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
471-
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
472-
473-
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
474-
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
456+
# Also test a vectorized model
457+
@model function linear_reg_vec(x, y, σ=0.1)
458+
β ~ Normal(0, 1)
459+
return y ~ MvNormal.* x, σ^2 * I)
460+
end
461+
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)
475462

476-
# Multiple chains
477-
multiple_β_chain = MCMCChains.Chains(
478-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
479-
)
480-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
481-
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
482-
@test size(multiple_β_chain, 3) == size(predictions, 3)
463+
@testset "variables in chain" begin
464+
# Note that this also checks that variables on the lhs of :=,
465+
# such as σ2, are not included in the resulting chain
466+
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
467+
end
483468

484-
for chain_idx in MCMCChains.chains(multiple_β_chain)
485-
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
469+
@testset "accuracy" begin
470+
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
486471
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
487472
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
488473
end
489474

490-
# Predict on two last indices for vectorized
491-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
492-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
493-
494-
for chain_idx in MCMCChains.chains(multiple_β_chain)
495-
ys_pred_vec = vec(
496-
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
475+
@testset "ensure that rng is respected" begin
476+
rng = MersenneTwister(42)
477+
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
478+
predictions2 = DynamicPPL.predict(
479+
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
497480
)
481+
@test all(Array(predictions1) .== Array(predictions2))
482+
end
483+
484+
@testset "accuracy on vectorized model" begin
485+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
486+
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
487+
498488
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
499489
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
500490
end
491+
492+
@testset "prediction from multiple chains" begin
493+
# Normal linreg model
494+
multiple_β_chain = MCMCChains.Chains(
495+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
496+
)
497+
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
498+
@test size(multiple_β_chain, 3) == size(predictions, 3)
499+
500+
for chain_idx in MCMCChains.chains(multiple_β_chain)
501+
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
502+
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
503+
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
504+
end
505+
506+
# Vectorized linreg model
507+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)
508+
509+
for chain_idx in MCMCChains.chains(multiple_β_chain)
510+
ys_pred_vec = vec(
511+
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
512+
)
513+
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
514+
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
515+
end
516+
end
501517
end
502518

503519
@testset "with AbstractVector{<:AbstractVarInfo}" begin
@@ -524,7 +540,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
524540

525541
@test size(predicted_vis) == size(chain)
526542
@test Set(keys(predicted_vis[1])) ==
527-
Set([@varname(β), @varname(y[1]), @varname(y[2])])
543+
Set([@varname(β), @varname(y[1]), @varname(y[2])])
528544
# because β samples are from the prior, the std will be larger
529545
@test mean([
530546
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)

0 commit comments

Comments
 (0)