Skip to content
Closed
11 changes: 11 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,17 @@ function unwrap_right_left_vns(
left::AbstractArray,
vn::VarName,
)
# Need to check that we don't end up double-counting log-probabilities.
combined_axes = Broadcast.combine_axes(left, right)
if prod(length, combined_axes) > length(left)
throw(
ArgumentError(
"a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side",
),
)
end

# Extract the sub-varnames.
vns = map(CartesianIndices(left)) do i
return Accessors.IndexLens(Tuple(i)) ∘ vn
end
Expand Down
173 changes: 70 additions & 103 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,6 @@ function tilde_assume(
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
end
Expand Down Expand Up @@ -328,37 +290,6 @@ function dot_tilde_assume(
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,47 +299,83 @@ function dot_tilde_assume(
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
end

function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
return dot_tilde_assume(
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
# `FixedContext`
function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return tilde_assume(childcontext(context), right, left, vns, vi)
end

# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# We _might_ also have some of the variables fixed, but not all.
logp = 0
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(
childcontext(context), right_bc[I_right...], vn, vi
)
logp += logp_inner
end
end
end

return left, logp, vi
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
)
function dot_tilde_assume(
rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi
)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi)
end
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# So we need to check each of the vns.
logp = 0
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(
rng, childcontext(context), sampler, right_bc[I_right...], vn, vi
)
logp += logp_inner
end
end
end

return left, logp, vi
end

"""
Expand Down
36 changes: 15 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext

The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext

The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down Expand Up @@ -514,6 +501,13 @@ NodeTrait(::FixedContext) = IsParent()
childcontext(context::FixedContext) = context.context
setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)

has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn)

has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn)
@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names,sym}
return sym in names
end

"""
hasfixed(context::AbstractContext, vn::VarName)

Expand Down
9 changes: 9 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,4 +729,13 @@ module Issue537 end
res = model()
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
end

@testset "invalid .~ expressions" begin
@model function demo_with_invalid_dot_tilde()
m = Matrix{Float64}(undef, 1, 2)
return m .~ [Normal(); Normal()]
end

@test_throws ArgumentError demo_with_invalid_dot_tilde()()
end
end
Loading