Skip to content

Commit 86fe1c6

Browse files
committed
added error-handling of invalid broadcasting statements
1 parent 8e7d164 commit 86fe1c6

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/compiler.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,17 @@ function unwrap_right_left_vns(
235235
left::AbstractArray,
236236
vn::VarName,
237237
)
238+
# Need to check that we don't end up double-counting log-probabilities.
239+
combined_axes = Broadcast.combine_axes(left, right)
240+
if prod(length, combined_axes) > length(left)
241+
throw(
242+
ArgumentError(
243+
"a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side",
244+
),
245+
)
246+
end
247+
248+
# Extract the sub-varnames.
238249
vns = map(CartesianIndices(left)) do i
239250
return Accessors.IndexLens(Tuple(i)) vn
240251
end

test/compiler.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,4 +729,13 @@ module Issue537 end
729729
res = model()
730730
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
731731
end
732+
733+
@testset "invalid .~ expressions" begin
734+
@model function demo_with_invalid_dot_tilde()
735+
m = Matrix{Float64}(undef, 1, 2)
736+
m .~ [Normal(); Normal()]
737+
end
738+
739+
@test_throws ArgumentError demo_with_invalid_dot_tilde()()
740+
end
732741
end

0 commit comments

Comments
 (0)