Skip to content

Commit e1a8111

Browse files
torfjeldeyebai
andcommitted
Reopening of previous PR (#309)
What we need before merge: 1. [x] Updated benchmarks. 2. [ ] Resolve comments. Co-authored-by: Hong Ge <[email protected]>
1 parent ad545be commit e1a8111

28 files changed

+1566
-260
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
matrix:
1818
version:
19-
- '1.3' # minimum supported version
19+
# - '1.3' # minimum supported version
2020
- '1' # current stable version
2121
os:
2222
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.16.2"
3+
version = "0.17.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[deps]
22
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
45
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
56

67
[compat]
78
Distributions = "0.25"
89
Documenter = "0.27"
10+
Setfield = "0.7.1, 0.8"
911
StableRNGs = "1"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ makedocs(;
88
sitename="DynamicPPL",
99
format=Documenter.HTML(),
1010
modules=[DynamicPPL],
11-
pages=["Home" => "index.md"],
11+
pages=["Home" => "index.md", "TestUtils" => "test_utils.md"],
1212
strict=true,
1313
checkdocs=:exports,
1414
doctestfilters=[

docs/src/test_utils.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# DynamicPPL.TestUtils
2+
3+
```@autodocs
4+
Modules = [DynamicPPL.TestUtils]
5+
```

src/DynamicPPL.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore
1010
using MacroTools: MacroTools
1111
using ZygoteRules: ZygoteRules
1212
using BangBang: BangBang
13+
using Setfield: Setfield
1314

1415
using Setfield: Setfield
1516
using BangBang: BangBang
@@ -31,15 +32,23 @@ import Base:
3132
keys,
3233
haskey
3334

35+
using BangBang: push!!, empty!!, setindex!!
36+
3437
# VarInfo
3538
export AbstractVarInfo,
3639
VarInfo,
3740
UntypedVarInfo,
3841
TypedVarInfo,
42+
SimpleVarInfo,
43+
push!!,
44+
empty!!,
3945
getlogp,
4046
setlogp!,
4147
acclogp!,
4248
resetlogp!,
49+
setlogp!!,
50+
acclogp!!,
51+
resetlogp!!,
4352
get_num_produce,
4453
set_num_produce!,
4554
reset_num_produce!,
@@ -139,13 +148,32 @@ include("distribution_wrappers.jl")
139148
include("contexts.jl")
140149
include("varinfo.jl")
141150
include("threadsafe.jl")
151+
include("simple_varinfo.jl")
142152
include("context_implementations.jl")
143153
include("compiler.jl")
144154
include("prob_macro.jl")
145155
include("compat/ad.jl")
146156
include("loglikelihoods.jl")
147157
include("submodel_macro.jl")
148-
149158
include("test_utils.jl")
150159

160+
# Deprecations
161+
@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo)
162+
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(
163+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution
164+
)
165+
@deprecate push!(
166+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler
167+
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler)
168+
@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(
169+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector
170+
)
171+
@deprecate push!(
172+
vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}
173+
) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector})
174+
175+
@deprecate setlogp!(vi, logp) setlogp!!(vi, logp)
176+
@deprecate acclogp!(vi, logp) acclogp!!(vi, logp)
177+
@deprecate resetlogp!(vi) resetlogp!!(vi)
178+
151179
end # module

src/compat/ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# See https://github.com/TuringLang/Turing.jl/issues/1199
2-
ChainRulesCore.@non_differentiable push!(
2+
ChainRulesCore.@non_differentiable push!!(
33
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
44
)
55

@@ -16,7 +16,7 @@ ZygoteRules.@adjoint function dot_observe(
1616
)
1717
function dot_observe_fallback(spl, dists, value, vi)
1818
increment_num_produce!(vi)
19-
return sum(map(Distributions.loglikelihood, dists, value))
19+
return sum(map(Distributions.loglikelihood, dists, value)), vi
2020
end
2121
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
2222
end

src/compiler.jl

Lines changed: 100 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,12 @@ end
355355

356356
function generate_tilde_literal(left, right)
357357
# If the LHS is a literal, it is always an observation
358+
@gensym value
358359
return quote
359-
$(DynamicPPL.tilde_observe!)(
360+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
360361
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
361362
)
363+
$value
362364
end
363365
end
364366

@@ -373,7 +375,7 @@ function generate_tilde(left, right)
373375

374376
# Otherwise it is determined by the model or its value,
375377
# if the LHS represents an observation
376-
@gensym vn isassumption
378+
@gensym vn isassumption value
377379

378380
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379381
# that in DynamicPPL we the entire function body. Instead we should be
@@ -389,32 +391,38 @@ function generate_tilde(left, right)
389391
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
390392
end
391393

392-
$(DynamicPPL.tilde_observe!)(
394+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
393395
__context__,
394396
$(DynamicPPL.check_tilde_rhs)($right),
395397
$(maybe_view(left)),
396398
$vn,
397399
__varinfo__,
398400
)
401+
$value
399402
end
400403
end
401404
end
402405

403406
function generate_tilde_assume(left, right, vn)
404-
expr = :(
405-
$left = $(DynamicPPL.tilde_assume!)(
407+
# HACK: Because the Setfield.jl macro does not support assignment
408+
# with multiple arguments on the LHS, we need to capture the return-values
409+
# and then update the LHS variables one by one.
410+
@gensym value
411+
expr = :($left = $value)
412+
if left isa Expr
413+
expr = AbstractPPL.drop_escape(
414+
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
415+
)
416+
end
417+
418+
return quote
419+
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
406420
__context__,
407421
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
408422
__varinfo__,
409423
)
410-
)
411-
412-
return if left isa Expr
413-
AbstractPPL.drop_escape(
414-
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
415-
)
416-
else
417-
return expr
424+
$expr
425+
$value
418426
end
419427
end
420428

@@ -428,7 +436,7 @@ function generate_dot_tilde(left, right)
428436

429437
# Otherwise it is determined by the model or its value,
430438
# if the LHS represents an observation
431-
@gensym vn isassumption
439+
@gensym vn isassumption value
432440
return quote
433441
$vn = $(AbstractPPL.drop_escape(varname(left)))
434442
$isassumption = $(DynamicPPL.isassumption(left))
@@ -440,13 +448,14 @@ function generate_dot_tilde(left, right)
440448
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
441449
end
442450

443-
$(DynamicPPL.dot_tilde_observe!)(
451+
$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
444452
__context__,
445453
$(DynamicPPL.check_tilde_rhs)($right),
446454
$(maybe_view(left)),
447455
$vn,
448456
__varinfo__,
449457
)
458+
$value
450459
end
451460
end
452461
end
@@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn)
455464
# We don't need to use `Setfield.@set` here since
456465
# `.=` is always going to be inplace + needs `left` to
457466
# be something that supports `.=`.
458-
return :(
459-
$left .= $(DynamicPPL.dot_tilde_assume!)(
467+
@gensym value
468+
return quote
469+
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
460470
__context__,
461471
$(DynamicPPL.unwrap_right_left_vns)(
462472
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
463473
)...,
464474
__varinfo__,
465475
)
466-
)
476+
$left .= $value
477+
$value
478+
end
479+
end
480+
481+
# Note that we cannot use `MacroTools.isdef` because
482+
# of https://github.com/FluxML/MacroTools.jl/issues/154.
483+
"""
484+
isfuncdef(expr)
485+
486+
Return `true` if `expr` is any form of function definition, and `false` otherwise.
487+
"""
488+
function isfuncdef(e::Expr)
489+
return if Meta.isexpr(e, :function)
490+
# Classic `function f(...)`
491+
true
492+
elseif Meta.isexpr(e, :->)
493+
# Anonymous functions/lambdas, e.g. `do` blocks or `->` defs.
494+
true
495+
elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call)
496+
# Short function defs, e.g. `f(args...) = ...`.
497+
true
498+
else
499+
false
500+
end
501+
end
502+
503+
"""
504+
replace_returns(expr)
505+
506+
Return `Expr` with all `return ...` statements replaced with
507+
`return ..., DynamicPPL.return_values(__varinfo__)`.
508+
509+
Note that this method will _not_ replace `return` statements within function
510+
definitions. This is checked using [`isfuncdef`](@ref).
511+
"""
512+
replace_returns(e) = e
513+
function replace_returns(e::Expr)
514+
if isfuncdef(e)
515+
return e
516+
end
517+
518+
if Meta.isexpr(e, :return)
519+
# NOTE: `return` always has an argument. In the case of
520+
# an empty `return`, the lowered expression will be `return nothing`.
521+
# Hence we don't need any special handling for empty returns.
522+
retval_expr = if length(e.args) > 1
523+
Expr(:tuple, e.args...)
524+
else
525+
e.args[1]
526+
end
527+
528+
return :(return ($retval_expr, __varinfo__))
529+
end
530+
531+
return Expr(e.head, map(replace_returns, e.args)...)
532+
end
533+
534+
# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`.
535+
make_returns_explicit!(body) = Expr(:return, body)
536+
function make_returns_explicit!(body::Expr)
537+
# If the last statement is a return-statement, we don't do anything.
538+
# Otherwise we replace the last statement with a `return` statement.
539+
if !Meta.isexpr(body.args[end], :return)
540+
body.args[end] = Expr(:return, body.args[end])
541+
end
542+
return body
467543
end
468544

469545
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
@@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode)
496572
# Replace the user-provided function body with the version created by DynamicPPL.
497573
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
498574
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
499-
# to the call site
575+
# to the call site.
576+
# NOTE: We need to replace statements of the form `return ...` with
577+
# `return (..., __varinfo__)` to ensure that the second
578+
# element in the returned value is always the most up-to-date `__varinfo__`.
579+
# See the docstrings of `replace_returns` for more info.
500580
evaluatordef[:body] = MacroTools.@q begin
501581
$(linenumbernode)
502-
$(modelinfo[:body])
582+
$(replace_returns(make_returns_explicit!(modelinfo[:body])))
503583
end
504584

505585
## Build the model function.

0 commit comments

Comments
 (0)