diff --git a/Project.toml b/Project.toml index 97969944d..60dbcdc81 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.32.2" +version = "0.33.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 06cde3bac..7fcbd6a7c 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -117,7 +117,7 @@ function DynamicPPL.predict( DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) model(rng, varinfo, DynamicPPL.SampleFromPrior()) - vals = DynamicPPL.values_as_in_model(model, varinfo) + vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, vcat, diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index c5003d53a..16556ee8c 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -22,22 +22,22 @@ $(TYPEDFIELDS) struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext "values that are extracted from the model" values::T + "whether to extract variables on the LHS of :=" + include_colon_eq::Bool "child context" context::C end - -ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext()) -function ValuesAsInModelContext(context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), context) +function ValuesAsInModelContext(include_colon_eq, context::AbstractContext) + return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context) end NodeTrait(::ValuesAsInModelContext) = IsParent() childcontext(context::ValuesAsInModelContext) = context.context function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, child) + return ValuesAsInModelContext(context.values, context.include_colon_eq, child) end -is_extracting_values(context::ValuesAsInModelContext) = true 
+is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq function is_extracting_values(context::AbstractContext) return is_extracting_values(NodeTrait(context), context) end @@ -114,8 +114,8 @@ function dot_tilde_assume( end """ - values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) - values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext]) Get the values of `varinfo` as they would be seen in the model. @@ -132,6 +132,7 @@ of additional model evaluations. # Arguments - `model::Model`: model to extract realizations from. +- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. - `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. @@ -183,24 +184,26 @@ false julia> # Approach 2: Extract realizations using `values_as_in_model`. # (✓) `values_as_in_model` will re-run the model and extract # the correct realization of `y` given the new values of `x`. 
- lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub + lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub true ``` """ function values_as_in_model( model::Model, + include_colon_eq::Bool, varinfo::AbstractVarInfo=VarInfo(), context::AbstractContext=DefaultContext(), ) - context = ValuesAsInModelContext(context) + context = ValuesAsInModelContext(include_colon_eq, context) evaluate!!(model, varinfo, context) return context.values end function values_as_in_model( rng::Random.AbstractRNG, model::Model, + include_colon_eq::Bool, varinfo::AbstractVarInfo=VarInfo(), context::AbstractContext=DefaultContext(), ) - return values_as_in_model(model, varinfo, SamplingContext(rng, context)) + return values_as_in_model(model, include_colon_eq, varinfo, SamplingContext(rng, context)) end diff --git a/test/compiler.jl b/test/compiler.jl index 4dc9fcb24..051eba618 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -702,10 +702,17 @@ module Issue537 end @test haskey(varinfo, @varname(x)) @test !haskey(varinfo, @varname(y)) - # While `values_as_in_model` should contain both `x` and `y`. - values = values_as_in_model(model, deepcopy(varinfo)) + # While `values_as_in_model` should contain both `x` and `y`, if + # include_colon_eq is set to `true`. + values = values_as_in_model(model, true, deepcopy(varinfo)) @test haskey(values, @varname(x)) @test haskey(values, @varname(y)) + + # And if include_colon_eq is set to `false`, then `values` should + # only contain `x`.
+ values = values_as_in_model(model, false, deepcopy(varinfo)) + @test haskey(values, @varname(x)) + @test !haskey(values, @varname(y)) end end diff --git a/test/model.jl b/test/model.jl index cb1dbc735..96c0f1560 100644 --- a/test/model.jl +++ b/test/model.jl @@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - realizations = values_as_in_model(model, varinfo) + # We can set the include_colon_eq arg to false because none of + # the demo models contain :=. The behaviour when + # include_colon_eq is true is tested in test/compiler.jl + realizations = values_as_in_model(model, false, varinfo) # Ensure that all variables are found. vns_found = collect(keys(realizations)) @test vns ∩ vns_found == vns ∪ vns_found @@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end + + @testset "check that sampling obeys rng if passed" begin + @model function f() + x ~ Normal(0) + return y ~ Normal(x) + end + model = f() + # Call values_as_in_model with the rng + values = values_as_in_model(Random.Xoshiro(43), model, false) + # Check that they match the values that would be used if vi was seeded + # with that seed instead + expected_vi = VarInfo(Random.Xoshiro(43), model) + for vn in keys(values) + @test values[vn] == expected_vi[vn] + end + end end @testset "Erroneous model call" begin @@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "predict" begin @testset "with MCMCChains.Chains" begin - DynamicPPL.Random.seed!(100) - @model function linear_reg(x, y, σ=0.1) β ~ Normal(0, 1) for i in eachindex(y) y[i] ~ Normal(β * x[i], σ) end + # Insert a := block to test that it is not included in predictions + return σ2 := σ^2 end - @model function 
linear_reg_vec(x, y, σ=0.1) - β ~ Normal(0, 1) - return y ~ MvNormal(β .* x, σ^2 * I) - end - + # Construct a chain with 'sampled values' of β ground_truth_β = 2 β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) predictions = DynamicPPL.predict(m_lin_reg_test, β_chain) - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) - @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 - - # Ensure that `rng` is respected - rng = MersenneTwister(42) - predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2]) - predictions2 = DynamicPPL.predict( - MersenneTwister(42), m_lin_reg_test, β_chain[1:2] - ) - @test all(Array(predictions1) .== Array(predictions2)) - - # Predict on two last indices for vectorized - m_lin_reg_test = linear_reg_vec(xs_test, missing) - predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain) - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) - - @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 - @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + # Also test a vectorized model + @model function linear_reg_vec(x, y, σ=0.1) + β ~ Normal(0, 1) + return y ~ MvNormal(β .* x, σ^2 * I) + end + m_lin_reg_test_vec = linear_reg_vec(xs_test, missing) - # Multiple chains - multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] - ) - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) - @test size(multiple_β_chain, 3) == size(predictions, 3) + @testset "variables in chain" begin + # Note that this also checks that variables on the lhs of :=, + # such as σ2, are not included in the resulting chain + @test 
Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")]) + end - for chain_idx in MCMCChains.chains(multiple_β_chain) - ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)) + @testset "accuracy" begin + ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 end - # Predict on two last indices for vectorized - m_lin_reg_test = linear_reg_vec(xs_test, missing) - predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) - - for chain_idx in MCMCChains.chains(multiple_β_chain) - ys_pred_vec = vec( - mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1) + @testset "ensure that rng is respected" begin + rng = MersenneTwister(42) + predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2]) + predictions2 = DynamicPPL.predict( + MersenneTwister(42), m_lin_reg_test, β_chain[1:2] ) + @test all(Array(predictions1) .== Array(predictions2)) + end + + @testset "accuracy on vectorized model" begin + predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain) + ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) + @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 end + + @testset "prediction from multiple chains" begin + # Normal linreg model + multiple_β_chain = MCMCChains.Chains( + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + ) + predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) + @test size(multiple_β_chain, 3) == size(predictions, 3) + + for chain_idx in MCMCChains.chains(multiple_β_chain) + ys_pred = vec( + mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1) + ) + @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + end + + # Vectorized linreg model + predictions_vec = 
DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain) + + for chain_idx in MCMCChains.chains(multiple_β_chain) + ys_pred_vec = vec( + mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1) + ) + @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + end + end end @testset "with AbstractVector{<:AbstractVarInfo}" begin