Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.32.0"
version = "0.32.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ adds the `Prefix` to all parameters.
This context is useful in nested models to ensure that the names of the parameters are
unique.

See also: [`@submodel`](@ref)
See also: [`to_submodel`](@ref)
"""
struct PrefixContext{Prefix,C} <: AbstractContext
context::C
Expand Down
20 changes: 20 additions & 0 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,26 @@ function _pointwise_tilde_observe(
end
end

# Note on submodels (penelopeysm)
#
# We don't need to overload tilde_observe!! for Sampleables (yet), because it
# is currently not possible to evaluate a model with a Sampleable on the RHS
# of an observe statement.
#
# Note that calling tilde_assume!! on a Sampleable does not necessarily imply
# that there are no observe statements inside the Sampleable. There could well
# be likelihood terms in there, which must be included in the returned logp.
# See e.g. the `demo_dot_assume_observe_submodel` demo model.
#
# This is handled by passing the same context to rand_like!!, which figures out
# which terms to include using the context, and also mutates the context and vi
# appropriately. Thus, we don't need to check against _include_prior(context)
# here.
function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi)
value, vi = DynamicPPL.rand_like!!(right, context, vi)
return value, vi
end

function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
!_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi))
value, logp, vi = tilde_assume(context.context, right, vn, vi)
Expand Down
9 changes: 6 additions & 3 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ end

@model function demo_assume_submodel_observe_index_literal()
# Submodel prior
@submodel s, m = _prior_dot_assume()
priors ~ to_submodel(_prior_dot_assume(), false)
s, m = priors
1.5 ~ Normal(m[1], sqrt(s[1]))
2.0 ~ Normal(m[2], sqrt(s[2]))

Expand All @@ -462,7 +463,7 @@ function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
end

@model function _likelihood_mltivariate_observe(s, m, x)
@model function _likelihood_multivariate_observe(s, m, x)
return x ~ MvNormal(m, Diagonal(s))
end

Expand All @@ -475,7 +476,9 @@ end
m .~ Normal.(0, sqrt.(s))

# Submodel likelihood
@submodel _likelihood_mltivariate_observe(s, m, x)
# With to_submodel, we have to have a left-hand side variable to
# capture the result, so we just use a dummy variable
_ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))

return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
end
Expand Down
52 changes: 11 additions & 41 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,34 +382,13 @@ module Issue537 end
@test demo2()() == 42
end

@testset "@submodel is deprecated" begin
@model inner() = x ~ Normal()
@model outer() = @submodel x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer()()
)

@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer_with_prefix()()
)
end

@testset "submodel" begin
@testset "to_submodel" begin
# No prefix, 1 level.
@model function demo1(x)
return x ~ Normal()
end
@model function demo2(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Uniform()
end
# No observation.
Expand Down Expand Up @@ -441,7 +420,7 @@ module Issue537 end

# Check values makes sense.
@model function demo3(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Normal(x)
end
m = demo3(1000.0, missing)
Expand All @@ -453,12 +432,10 @@ module Issue537 end
x ~ Normal()
return x
end

@model function demo_useval(x, y)
@submodel prefix = "sub1" x1 = demo_return(x)
@submodel prefix = "sub2" x2 = demo_return(y)

return z ~ Normal(x1 + x2 + 100, 1.0)
sub1 ~ to_submodel(demo_return(x))
sub2 ~ to_submodel(demo_return(y))
return z ~ Normal(sub1 + sub2 + 100, 1.0)
end
m = demo_useval(missing, missing)
vi = VarInfo(m)
Expand All @@ -472,21 +449,18 @@ module Issue537 end
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
η ~ MvNormal(zeros(num_steps), I)
δ = sqrt(1 - α^2)

x = TV(undef, num_steps)
x[1] = η[1]
@inbounds for t in 2:num_steps
x[t] = @. α * x[t - 1] + δ * η[t]
end

return @. μ + σ * x
end

@model function demo(y)
α ~ Uniform()
μ ~ Normal()
σ ~ truncated(Normal(), 0, Inf)

num_steps = length(y[1])
num_obs = length(y)
@inbounds for i in 1:num_obs
Expand Down Expand Up @@ -613,14 +587,11 @@ module Issue537 end
@model demo() = x ~ Normal()
retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext())

# Return-value when using `@submodel`
# Return-value when using `to_submodel`
@model inner() = x ~ Normal()
# Without assignment.
@model outer() = @submodel inner()
@test outer()() isa Real

# With assignment.
@model outer() = @submodel x = inner()
@model function outer()
return _ignore ~ to_submodel(inner())
end
@test outer()() isa Real

# Edge-cases.
Expand Down Expand Up @@ -720,8 +691,7 @@ module Issue537 end
return (; x, y)
end
@model function demo_tracked_submodel()
@submodel (x, y) = demo_tracked()
return (; x, y)
return vals ~ to_submodel(demo_tracked(), false)
end
for model in [demo_tracked(), demo_tracked_submodel()]
# Make sure it's runnable and `y` is present in the return-value.
Expand Down
57 changes: 57 additions & 0 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
@testset "deprecated" begin
@testset "@submodel" begin
@testset "is deprecated" begin
@model inner() = x ~ Normal()
@model outer() = @submodel x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer()()
)

@model outer_with_prefix() = @submodel prefix = "sub" x = inner()
@test_logs(
(
:warn,
"`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.",
),
outer_with_prefix()()
)
end

@testset "prefixing still works correctly" begin
@model inner() = x ~ Normal()
@model function outer()
a = @submodel inner()
b = @submodel prefix = "sub" inner()
return a, b
end
@test outer()() isa Tuple{Float64,Float64}
vi = VarInfo(outer())
@test @varname(x) in keys(vi)
@test @varname(var"sub.x") in keys(vi)
end

@testset "logp is still accumulated properly" begin
@model inner_assume() = x ~ Normal()
@model inner_observe(x, y) = y ~ Normal(x)
@model function outer(b)
a = @submodel inner_assume()
@submodel inner_observe(a, b)
end
y_val = 1.0
model = outer(y_val)
@test model() == y_val

x_val = 1.5
vi = VarInfo(outer(y_val))
DynamicPPL.setindex!!(vi, x_val, @varname(x))
@test logprior(model, vi) ≈ logpdf(Normal(), x_val)
@test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val)
@test logjoint(model, vi) ≈
logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val)
end
end
end
10 changes: 5 additions & 5 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(
model, example_values...
)
logp_true = logprior(model, vi)
logprior_true = logprior(model, vi)

# Compute the pointwise loglikelihoods.
lls = pointwise_loglikelihoods(model, vi)
Expand All @@ -30,18 +30,18 @@
lps_prior = pointwise_prior_logdensities(model, vi)
@test :x ∉ DynamicPPL.getsym.(keys(lps_prior))
logp = sum(sum, values(lps_prior))
@test logp ≈ logp_true
@test logp ≈ logprior_true

# Compute both likelihood and logdensity of prior
# using the default DefaultContex
# using the default DefaultContext
lps = pointwise_logdensities(model, vi)
logp = sum(sum, values(lps))
@test logp ≈ (logp_true + loglikelihood_true)
@test logp ≈ (logprior_true + loglikelihood_true)

# Test that modifications of Setup are picked up
lps = pointwise_logdensities(model, vi, mod_ctx2)
logp = sum(sum, values(lps))
@test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4
@test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4
end
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ include("test_util.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
include("deprecated.jl")
end

if GROUP == "All" || GROUP == "Group2"
Expand Down
Loading