Skip to content
Closed
11 changes: 11 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ function unwrap_right_left_vns(
left::AbstractArray,
vn::VarName,
)
# Need to check that we don't end up double-counting log-probabilities.
combined_axes = Broadcast.combine_axes(left, right)
if prod(length, combined_axes) > length(left)
throw(
ArgumentError(
"a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side",
),
)
end

# Extract the sub-varnames.
vns = map(CartesianIndices(left)) do i
return Accessors.IndexLens(Tuple(i)) ∘ vn
end
Expand Down
66 changes: 66 additions & 0 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,72 @@
)
end

# `FixedContext`
function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return dot_tilde_assume(childcontext(context), right, left, vns, vi)
end

# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# We _might_ also have some of the variables fixed, but not all.
logp = 0

Check warning on line 343 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L343

Added line #L343 was not covered by tests
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)

Check warning on line 353 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L347-L353

Added lines #L347 - L353 were not covered by tests
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(

Check warning on line 356 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L356

Added line #L356 was not covered by tests
childcontext(context), right_bc[I_right...], vn, vi
)
logp += logp_inner

Check warning on line 359 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L359

Added line #L359 was not covered by tests
end
end
end

Check warning on line 362 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L361-L362

Added lines #L361 - L362 were not covered by tests

return left, logp, vi

Check warning on line 364 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L364

Added line #L364 was not covered by tests
end

function dot_tilde_assume(
rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi
)
if !has_fixed_symbol(context, first(vns))
# Defer to `childcontext`.
return dot_tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi)
end
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
# So we need to check each of the vns.
logp = 0

Check warning on line 376 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L376

Added line #L376 was not covered by tests
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
# then be compiled away in cases where the `Symbol` is not present.
left_bc = Broadcast.broadcastable(left)
right_bc = Broadcast.broadcastable(right)
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
vn = vns[I_left...]
if hasfixed(context, vn)
left[I_left...] = getfixed(context, vn)

Check warning on line 386 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L380-L386

Added lines #L380 - L386 were not covered by tests
else
# Defer to `tilde_assume`.
left[I_left...], logp_inner, vi = tilde_assume(

Check warning on line 389 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L389

Added line #L389 was not covered by tests
rng, childcontext(context), sampler, right_bc[I_right...], vn, vi
)
logp += logp_inner

Check warning on line 392 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L392

Added line #L392 was not covered by tests
end
end
end

Check warning on line 395 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L394-L395

Added lines #L394 - L395 were not covered by tests

return left, logp, vi

Check warning on line 397 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L397

Added line #L397 was not covered by tests
end

"""
dot_tilde_assume!!(context, right, left, vn, vi)

Expand Down
7 changes: 7 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,13 @@ NodeTrait(::FixedContext) = IsParent()
childcontext(context::FixedContext) = context.context
setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)

has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn)

has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn)
@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names,sym}
return sym in names
end

"""
hasfixed(context::AbstractContext, vn::VarName)

Expand Down
9 changes: 9 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,4 +720,13 @@ module Issue537 end
res = model()
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
end

@testset "invalid .~ expressions" begin
@model function demo_with_invalid_dot_tilde()
m = Matrix{Float64}(undef, 1, 2)
return m .~ [Normal(); Normal()]
end

@test_throws ArgumentError demo_with_invalid_dot_tilde()()
end
end
Loading