diff --git a/Project.toml b/Project.toml index 51bac6df5..817692342 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.32.0" +version = "0.32.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/contexts.jl b/src/contexts.jl index 9eb3d5ccb..b337e4750 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -244,7 +244,7 @@ adds the `Prefix` to all parameters. This context is useful in nested models to ensure that the names of the parameters are unique. -See also: [`@submodel`](@ref) +See also: [`to_submodel`](@ref) """ struct PrefixContext{Prefix,C} <: AbstractContext context::C diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 47b969e6c..8c18163e3 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -146,6 +146,26 @@ function _pointwise_tilde_observe( end end +# Note on submodels (penelopeysm) +# +# We don't need to overload tilde_observe!! for Sampleables (yet), because it +# is currently not possible to evaluate a model with a Sampleable on the RHS +# of an observe statement. +# +# Note that calling tilde_assume!! on a Sampleable does not necessarily imply +# that there are no observe statements inside the Sampleable. There could well +# be likelihood terms in there, which must be included in the returned logp. +# See e.g. the `demo_dot_assume_observe_submodel` demo model. +# +# This is handled by passing the same context to rand_like!!, which figures out +# which terms to include using the context, and also mutates the context and vi +# appropriately. Thus, we don't need to check against _include_prior(context) +# here. +function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) + value, vi = DynamicPPL.rand_like!!(right, context, vi) + return value, vi +end + function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) value, logp, vi = tilde_assume(context.context, right, vn, vi) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 92a69d9ad..c506e1ba3 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -437,7 +437,8 @@ end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - @submodel s, m = _prior_dot_assume() + priors ~ to_submodel(_prior_dot_assume(), false) + s, m = priors 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) @@ -462,7 +463,7 @@ function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function _likelihood_mltivariate_observe(s, m, x) +@model function _likelihood_multivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) end @@ -475,7 +476,9 @@ end m .~ Normal.(0, sqrt.(s)) # Submodel likelihood - @submodel _likelihood_mltivariate_observe(s, m, x) + # With to_submodel, we have to have a left-hand side variable to + # capture the result, so we just use a dummy variable + _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end diff --git a/test/compiler.jl b/test/compiler.jl index 977c1156c..4dc9fcb24 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -382,34 +382,13 @@ module Issue537 end @test demo2()() == 42 end - @testset "@submodel is deprecated" begin - @model inner() = x ~ Normal() - @model outer() = @submodel x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer()() - ) - - @model outer_with_prefix() = @submodel prefix = "sub" x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer_with_prefix()() - ) - end - - @testset "submodel" begin + @testset "to_submodel" begin # No prefix, 1 level. @model function demo1(x) return x ~ Normal() end @model function demo2(x, y) - @submodel demo1(x) + _ignore ~ to_submodel(demo1(x), false) return y ~ Uniform() end # No observation. @@ -441,7 +420,7 @@ module Issue537 end # Check values makes sense. @model function demo3(x, y) - @submodel demo1(x) + _ignore ~ to_submodel(demo1(x), false) return y ~ Normal(x) end m = demo3(1000.0, missing) @@ -453,12 +432,10 @@ module Issue537 end x ~ Normal() return x end - @model function demo_useval(x, y) - @submodel prefix = "sub1" x1 = demo_return(x) - @submodel prefix = "sub2" x2 = demo_return(y) - - return z ~ Normal(x1 + x2 + 100, 1.0) + sub1 ~ to_submodel(demo_return(x)) + sub2 ~ to_submodel(demo_return(y)) + return z ~ Normal(sub1 + sub2 + 100, 1.0) end m = demo_useval(missing, missing) vi = VarInfo(m) @@ -472,13 +449,11 @@ module Issue537 end @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV} η ~ MvNormal(zeros(num_steps), I) δ = sqrt(1 - α^2) - x = TV(undef, num_steps) x[1] = η[1] @inbounds for t in 2:num_steps x[t] = @. α * x[t - 1] + δ * η[t] end - return @. μ + σ * x end @@ -486,7 +461,6 @@ module Issue537 end α ~ Uniform() μ ~ Normal() σ ~ truncated(Normal(), 0, Inf) - num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs @@ -613,14 +587,11 @@ module Issue537 end @model demo() = x ~ Normal() retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) - # Return-value when using `@submodel` + # Return-value when using `to_submodel` @model inner() = x ~ Normal() - # Without assignment. - @model outer() = @submodel inner() - @test outer()() isa Real - - # With assignment. - @model outer() = @submodel x = inner() + @model function outer() + return _ignore ~ to_submodel(inner()) + end @test outer()() isa Real # Edge-cases. @@ -720,8 +691,7 @@ module Issue537 end return (; x, y) end @model function demo_tracked_submodel() - @submodel (x, y) = demo_tracked() - return (; x, y) + return vals ~ to_submodel(demo_tracked(), false) end for model in [demo_tracked(), demo_tracked_submodel()] # Make sure it's runnable and `y` is present in the return-value. diff --git a/test/deprecated.jl b/test/deprecated.jl new file mode 100644 index 000000000..f12217983 --- /dev/null +++ b/test/deprecated.jl @@ -0,0 +1,57 @@ +@testset "deprecated" begin + @testset "@submodel" begin + @testset "is deprecated" begin + @model inner() = x ~ Normal() + @model outer() = @submodel x = inner() + @test_logs( + ( + :warn, + "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", + ), + outer()() + ) + + @model outer_with_prefix() = @submodel prefix = "sub" x = inner() + @test_logs( + ( + :warn, + "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", + ), + outer_with_prefix()() + ) + end + + @testset "prefixing still works correctly" begin + @model inner() = x ~ Normal() + @model function outer() + a = @submodel inner() + b = @submodel prefix = "sub" inner() + return a, b + end + @test outer()() isa Tuple{Float64,Float64} + vi = VarInfo(outer()) + @test @varname(x) in keys(vi) + @test @varname(var"sub.x") in keys(vi) + end + + @testset "logp is still accumulated properly" begin + @model inner_assume() = x ~ Normal() + @model inner_observe(x, y) = y ~ Normal(x) + @model function outer(b) + a = @submodel inner_assume() + @submodel inner_observe(a, b) + end + y_val = 1.0 + model = outer(y_val) + @test model() == y_val + + x_val = 1.5 + vi = VarInfo(outer(y_val)) + DynamicPPL.setindex!!(vi, x_val, @varname(x)) + @test logprior(model, vi) ≈ logpdf(Normal(), x_val) + @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) + @test logjoint(model, vi) ≈ + logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val) + end + end +end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 93b7c59be..5c0b2e090 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -13,7 +13,7 @@ loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( model, example_values... ) - logp_true = logprior(model, vi) + logprior_true = logprior(model, vi) # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(model, vi) @@ -30,18 +30,18 @@ lps_prior = pointwise_prior_logdensities(model, vi) @test :x ∉ DynamicPPL.getsym.(keys(lps_prior)) logp = sum(sum, values(lps_prior)) - @test logp ≈ logp_true + @test logp ≈ logprior_true # Compute both likelihood and logdensity of prior - # using the default DefaultContex + # using the default DefaultContext lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) - @test logp ≈ (logp_true + loglikelihood_true) + @test logp ≈ (logprior_true + loglikelihood_true) # Test that modifications of Setup are picked up lps = pointwise_logdensities(model, vi, mod_ctx2) logp = sum(sum, values(lps)) - @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 + @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 end end diff --git a/test/runtests.jl b/test/runtests.jl index fdd59c7b6..9f2d21990 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,7 @@ include("test_util.jl") include("serialization.jl") include("pointwise_logdensities.jl") include("lkj.jl") + include("deprecated.jl") end if GROUP == "All" || GROUP == "Group2"