Skip to content

Commit f7531ba

Browse files
committed
Fix use of arrays of Distributions (#245)
This PR fixes #28 (comment) and allows to use arbitrary arrays of `Distribution`s. This was already allowed in the context implementations but prevented by a check in the code generated by the `@model` macro. Additionally, the PR replaces the hard-coded check with a `check_tilde_rhs` function which, IMO, makes the code a bit simpler and easier to read. Moreover, a bug in the `dot_assume` implementation for arrays of Distributions is fixed. Co-authored-by: David Widmann <[email protected]>
1 parent 4c17629 commit f7531ba

File tree

6 files changed

+128
-57
lines changed

6 files changed

+128
-57
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.18"
3+
version = "0.10.19"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 87 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
2-
"Distributions."
3-
41
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
52
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
63

@@ -38,6 +35,20 @@ end
3835
# failsafe: a literal is never an assumption
3936
isassumption(expr) = :(false)
4037

38+
"""
39+
check_tilde_rhs(x)
40+
41+
Check if the right-hand side `x` of a `~` is a `Distribution` or an array of
42+
`Distributions`, then return `x`.
43+
"""
44+
function check_tilde_rhs(@nospecialize(x))
45+
return throw(ArgumentError(
46+
"the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s"
47+
))
48+
end
49+
check_tilde_rhs(x::Distribution) = x
50+
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
51+
4152
#################
4253
# Main Compiler #
4354
#################
@@ -225,34 +236,47 @@ Generate an `observe` expression for data variables and `assume` expression for
225236
variables.
226237
"""
227238
function generate_tilde(left, right)
228-
@gensym tmpright
229-
top = [:($tmpright = $right),
230-
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
231-
|| throw(ArgumentError($DISTMSG)))]
232-
233-
if left isa Symbol || left isa Expr
234-
@gensym out vn inds isassumption
235-
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
236-
239+
# If the LHS is a literal, it is always an observation
240+
if !(left isa Symbol || left isa Expr)
237241
return quote
238-
$(top...)
239-
$isassumption = $(DynamicPPL.isassumption(left))
240-
if $isassumption
241-
$left = $(DynamicPPL.tilde_assume)(
242-
__rng__, __context__, __sampler__, $tmpright, $vn, $inds, __varinfo__
243-
)
244-
else
245-
$(DynamicPPL.tilde_observe)(
246-
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
247-
)
248-
end
242+
$(DynamicPPL.tilde_observe)(
243+
__context__,
244+
__sampler__,
245+
$(DynamicPPL.check_tilde_rhs)($right),
246+
$left,
247+
__varinfo__,
248+
)
249249
end
250250
end
251251

252-
# If the LHS is a literal, it is always an observation
252+
# Otherwise it is determined by the model or its value,
253+
# if the LHS represents an observation
254+
@gensym vn inds isassumption
253255
return quote
254-
$(top...)
255-
$(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
256+
$vn = $(varname(left))
257+
$inds = $(vinds(left))
258+
$isassumption = $(DynamicPPL.isassumption(left))
259+
if $isassumption
260+
$left = $(DynamicPPL.tilde_assume)(
261+
__rng__,
262+
__context__,
263+
__sampler__,
264+
$(DynamicPPL.check_tilde_rhs)($right),
265+
$vn,
266+
$inds,
267+
__varinfo__,
268+
)
269+
else
270+
$(DynamicPPL.tilde_observe)(
271+
__context__,
272+
__sampler__,
273+
$(DynamicPPL.check_tilde_rhs)($right),
274+
$left,
275+
$vn,
276+
$inds,
277+
__varinfo__,
278+
)
279+
end
256280
end
257281
end
258282

@@ -262,34 +286,48 @@ end
262286
Generate the expression that replaces `left .~ right` in the model body.
263287
"""
264288
function generate_dot_tilde(left, right)
265-
@gensym tmpright
266-
top = [:($tmpright = $right),
267-
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
268-
|| throw(ArgumentError($DISTMSG)))]
269-
270-
if left isa Symbol || left isa Expr
271-
@gensym out vn inds isassumption
272-
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
273-
289+
# If the LHS is a literal, it is always an observation
290+
if !(left isa Symbol || left isa Expr)
274291
return quote
275-
$(top...)
276-
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
277-
if $isassumption
278-
$left .= $(DynamicPPL.dot_tilde_assume)(
279-
__rng__, __context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
280-
)
281-
else
282-
$(DynamicPPL.dot_tilde_observe)(
283-
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
284-
)
285-
end
292+
$(DynamicPPL.dot_tilde_observe)(
293+
__context__,
294+
__sampler__,
295+
$(DynamicPPL.check_tilde_rhs)($right),
296+
$left,
297+
__varinfo__,
298+
)
286299
end
287300
end
288301

289-
# If the LHS is a literal, it is always an observation
302+
# Otherwise it is determined by the model or its value,
303+
# if the LHS represents an observation
304+
@gensym vn inds isassumption
290305
return quote
291-
$(top...)
292-
$(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
306+
$vn = $(varname(left))
307+
$inds = $(vinds(left))
308+
$isassumption = $(DynamicPPL.isassumption(left))
309+
if $isassumption
310+
$left .= $(DynamicPPL.dot_tilde_assume)(
311+
__rng__,
312+
__context__,
313+
__sampler__,
314+
$(DynamicPPL.check_tilde_rhs)($right),
315+
$left,
316+
$vn,
317+
$inds,
318+
__varinfo__,
319+
)
320+
else
321+
$(DynamicPPL.dot_tilde_observe)(
322+
__context__,
323+
__sampler__,
324+
$(DynamicPPL.check_tilde_rhs)($right),
325+
$left,
326+
$vn,
327+
$inds,
328+
__varinfo__,
329+
)
330+
end
293331
end
294332
end
295333

src/context_implementations.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,9 +468,7 @@ function dot_observe(
468468
increment_num_produce!(vi)
469469
@debug "dists = $dists"
470470
@debug "value = $value"
471-
return sum(zip(dists, value)) do (d, v)
472-
Distributions.loglikelihood(d, v)
473-
end
471+
return sum(Distributions.loglikelihood.(dists, value))
474472
end
475473
function dot_observe(
476474
spl::Sampler,

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1919

2020
[compat]
2121
AbstractMCMC = "2.1, 3.0"
22-
AbstractPPL = "0.1.2"
22+
AbstractPPL = "0.1.3"
2323
Bijectors = "0.8.2, 0.9"
2424
Distributions = "0.24, 0.25"
2525
DistributionsAD = "0.6.3"

test/compiler.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,11 @@ end
253253
vi2 = VarInfo(f2())
254254
vi3 = VarInfo(f3())
255255
@test haskey(vi1.metadata, :y)
256-
@test vi1.metadata.y.vns[1] == VarName(:y)
256+
@test vi1.metadata.y.vns[1] == VarName{:y}()
257257
@test haskey(vi2.metadata, :y)
258-
@test vi2.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1)))
258+
@test vi2.metadata.y.vns[1] == VarName{:y}(((2,), (Colon(), 1)))
259259
@test haskey(vi3.metadata, :y)
260-
@test vi3.metadata.y.vns[1] == VarName(:y, ((1,),))
260+
@test vi3.metadata.y.vns[1] == VarName{:y}(((1,),))
261261
end
262262
@testset "custom tilde" begin
263263
@model demo() = begin
@@ -313,4 +313,14 @@ end
313313
end
314314
@test demo2()() == 42
315315
end
316+
317+
@testset "check_tilde_rhs" begin
318+
@test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())
319+
320+
x = Normal()
321+
@test DynamicPPL.check_tilde_rhs(x) === x
322+
323+
x = [Laplace(), Normal(), MvNormal(3, 1.0)]
324+
@test DynamicPPL.check_tilde_rhs(x) === x
325+
end
316326
end

test/context_implementations.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,29 @@
1313

1414
test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext())
1515
end
16+
17+
# https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577
18+
@testset "arrays of distributions" begin
19+
@model function test(x, y)
20+
y .~ Normal.(x)
21+
end
22+
23+
for ysize in ((2,), (2, 3), (2, 3, 4))
24+
# drop trailing dimensions
25+
for xsize in ntuple(i -> ysize[1:i], length(ysize))
26+
x = randn(xsize)
27+
y = randn(ysize)
28+
z = logjoint(test(x, y), VarInfo())
29+
@test z sum(logpdf.(Normal.(x), y))
30+
end
31+
32+
# singleton dimensions
33+
for xsize in ntuple(i -> (ysize[1:(i-1)]..., 1, ysize[(i+1):end]...), length(ysize))
34+
x = randn(xsize)
35+
y = randn(ysize)
36+
z = logjoint(test(x, y), VarInfo())
37+
@test z sum(logpdf.(Normal.(x), y))
38+
end
39+
end
40+
end
1641
end

0 commit comments

Comments
 (0)