Skip to content

Commit a8922bd

Browse files
committed
Don't include lhs of := in results of predict()
1 parent 6657441 commit a8922bd

File tree

4 files changed

+90
-56
lines changed

4 files changed

+90
-56
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function DynamicPPL.predict(
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118118
model(rng, varinfo, DynamicPPL.SampleFromPrior())
119119

120-
vals = DynamicPPL.values_as_in_model(model, varinfo)
120+
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(
122122
collect,
123123
vcat,

src/values_as_in_model.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,28 @@ $(TYPEDFIELDS)
2222
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
2323
"values that are extracted from the model"
2424
values::T
25+
"whether to extract variables on the LHS of :="
26+
include_colon_eq::Bool
2527
"child context"
2628
context::C
2729
end
2830

29-
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
30-
function ValuesAsInModelContext(context::AbstractContext)
31-
return ValuesAsInModelContext(OrderedDict(), context)
31+
# If child context is not passed
32+
function ValuesAsInModelContext(values, include_colon_eq)
33+
return ValuesAsInModelContext(values, include_colon_eq, DefaultContext())
34+
end
35+
# If values are not passed
36+
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
37+
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
3238
end
3339

3440
NodeTrait(::ValuesAsInModelContext) = IsParent()
3541
childcontext(context::ValuesAsInModelContext) = context.context
3642
function setchildcontext(context::ValuesAsInModelContext, child)
37-
return ValuesAsInModelContext(context.values, child)
43+
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
3844
end
3945

40-
is_extracting_values(context::ValuesAsInModelContext) = true
46+
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
4147
function is_extracting_values(context::AbstractContext)
4248
return is_extracting_values(NodeTrait(context), context)
4349
end
@@ -114,8 +120,8 @@ function dot_tilde_assume(
114120
end
115121

116122
"""
117-
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
118-
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
123+
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
124+
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
119125
120126
Get the values of `varinfo` as they would be seen in the model.
121127
@@ -132,6 +138,7 @@ of additional model evaluations.
132138
133139
# Arguments
134140
- `model::Model`: model to extract realizations from.
141+
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
135142
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
136143
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
137144
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
@@ -183,24 +190,26 @@ false
183190
julia> # Approach 2: Extract realizations using `values_as_in_model`.
184191
# (✓) `values_as_in_model` will re-run the model and extract
185192
# the correct realization of `y` given the new values of `x`.
186-
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
193+
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
187194
true
188195
```
189196
"""
190197
function values_as_in_model(
191198
model::Model,
199+
include_colon_eq::Bool,
192200
varinfo::AbstractVarInfo=VarInfo(),
193201
context::AbstractContext=DefaultContext(),
194202
)
195-
context = ValuesAsInModelContext(context)
203+
context = ValuesAsInModelContext(include_colon_eq, context)
196204
evaluate!!(model, varinfo, context)
197205
return context.values
198206
end
199207
function values_as_in_model(
200208
rng::Random.AbstractRNG,
201209
model::Model,
210+
include_colon_eq::Bool,
202211
varinfo::AbstractVarInfo=VarInfo(),
203212
context::AbstractContext=DefaultContext(),
204213
)
205-
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
214+
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
206215
end

test/compiler.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,17 @@ module Issue537 end
702702
@test haskey(varinfo, @varname(x))
703703
@test !haskey(varinfo, @varname(y))
704704

705-
# While `values_as_in_model` should contain both `x` and `y`.
706-
values = values_as_in_model(model, deepcopy(varinfo))
705+
# While `values_as_in_model` should contain both `x` and `y`, if
706+
# include_colon_eq is set to `true`.
707+
values = values_as_in_model(model, true, deepcopy(varinfo))
707708
@test haskey(values, @varname(x))
708709
@test haskey(values, @varname(y))
710+
711+
# And if include_colon_eq is set to `false`, then `values` should
712+
# only contain `x`.
713+
values = values_as_in_model(model, false, deepcopy(varinfo))
714+
@test haskey(values, @varname(x))
715+
@test !haskey(values, @varname(y))
709716
end
710717
end
711718

test/model.jl

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383383
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
384384
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
385385
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
386-
realizations = values_as_in_model(model, varinfo)
386+
# We can set the include_colon_eq arg to false because none of
387+
# the demo models contain :=. The behaviour when
388+
# include_colon_eq is true is tested in test/compiler.jl
389+
realizations = values_as_in_model(model, false, varinfo)
387390
# Ensure that all variables are found.
388391
vns_found = collect(keys(realizations))
389392
@test vns vns_found == vns vns_found
@@ -432,72 +435,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432435

433436
@testset "predict" begin
434437
@testset "with MCMCChains.Chains" begin
435-
DynamicPPL.Random.seed!(100)
436-
437438
@model function linear_reg(x, y, σ=0.1)
438439
β ~ Normal(0, 1)
439440
for i in eachindex(y)
440441
y[i] ~ Normal* x[i], σ)
441442
end
443+
# Insert a := block to test that it is not included in predictions
444+
return σ2 := σ^2
442445
end
443446

444-
@model function linear_reg_vec(x, y, σ=0.1)
445-
β ~ Normal(0, 1)
446-
return y ~ MvNormal.* x, σ^2 * I)
447-
end
448-
447+
# Construct a chain with 'sampled values' of β
449448
ground_truth_β = 2
450449
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
451450

451+
# Generate predictions from that chain
452452
xs_test = [10 + 0.1, 10 + 2 * 0.1]
453453
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
454454
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
455455

456-
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
457-
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
458-
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
459-
460-
# Ensure that `rng` is respected
461-
rng = MersenneTwister(42)
462-
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
463-
predictions2 = DynamicPPL.predict(
464-
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
465-
)
466-
@test all(Array(predictions1) .== Array(predictions2))
467-
468-
# Predict on two last indices for vectorized
469-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
470-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
471-
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
472-
473-
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
474-
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
456+
# Also test a vectorized model
457+
@model function linear_reg_vec(x, y, σ=0.1)
458+
β ~ Normal(0, 1)
459+
return y ~ MvNormal.* x, σ^2 * I)
460+
end
461+
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)
475462

476-
# Multiple chains
477-
multiple_β_chain = MCMCChains.Chains(
478-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
479-
)
480-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
481-
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
482-
@test size(multiple_β_chain, 3) == size(predictions, 3)
463+
@testset "variables in chain" begin
464+
# Note that this also checks that variables on the lhs of :=,
465+
# such as σ2, are not included in the resulting chain
466+
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
467+
end
483468

484-
for chain_idx in MCMCChains.chains(multiple_β_chain)
485-
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
469+
@testset "accuracy" begin
470+
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
486471
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
487472
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
488473
end
489474

490-
# Predict on two last indices for vectorized
491-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
492-
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
493-
494-
for chain_idx in MCMCChains.chains(multiple_β_chain)
495-
ys_pred_vec = vec(
496-
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
475+
@testset "ensure that rng is respected" begin
476+
rng = MersenneTwister(42)
477+
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
478+
predictions2 = DynamicPPL.predict(
479+
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
497480
)
481+
@test all(Array(predictions1) .== Array(predictions2))
482+
end
483+
484+
@testset "accuracy on vectorized model" begin
485+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
486+
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
487+
498488
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
499489
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
500490
end
491+
492+
@testset "prediction from multiple chains" begin
493+
# Normal linreg model
494+
multiple_β_chain = MCMCChains.Chains(
495+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
496+
)
497+
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
498+
@test size(multiple_β_chain, 3) == size(predictions, 3)
499+
500+
for chain_idx in MCMCChains.chains(multiple_β_chain)
501+
ys_pred = vec(
502+
mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)
503+
)
504+
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
505+
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
506+
end
507+
508+
# Vectorized linreg model
509+
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)
510+
511+
for chain_idx in MCMCChains.chains(multiple_β_chain)
512+
ys_pred_vec = vec(
513+
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
514+
)
515+
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
516+
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
517+
end
518+
end
501519
end
502520

503521
@testset "with AbstractVector{<:AbstractVarInfo}" begin

0 commit comments

Comments
 (0)