Skip to content

Commit b76bfaf

Browse files
authored
Merge pull request #31 from TuringLang/mt/zygote_ad
Zygote support - workaround setproperty! and getproperty
2 parents 177017f + 1b6e5fc commit b76bfaf

File tree

20 files changed

+313
-211
lines changed

20 files changed

+313
-211
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111

1212
[compat]
1313
AbstractMCMC = "~0.1"
14-
AdvancedHMC = "0.2.15"
15-
Bijectors = "0.4.0, 0.5"
16-
Distributions = "0.21.11, 0.22"
17-
DistributionsAD = "0.1.2"
14+
AdvancedHMC = "0.2.20"
15+
Bijectors = "0.5.2"
16+
Distributions = "0.22"
17+
DistributionsAD = "0.2"
1818
ForwardDiff = "0.10.3"
1919
Libtask = "0.3.1"
2020
LogDensityProblems = "^0.9, 0.10"

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ export VarName,
3333
setlogp!,
3434
acclogp!,
3535
resetlogp!,
36+
get_num_produce,
37+
set_num_produce!,
38+
reset_num_produce!,
39+
increment_num_produce!,
3640
set_retained_vns_del_by_spl!,
3741
is_flagged,
3842
unset_flag!,

src/compiler.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ function replace_logpdf!(model_info)
365365
vi = model_info[:main_body_names][:vi]
366366
ex = MacroTools.postwalk(ex) do x
367367
if @capture(x, @logpdf())
368-
:($vi.logp)
368+
:($vi.logp[])
369369
else
370370
x
371371
end
@@ -456,16 +456,22 @@ function tilde(left, right, model_info)
456456
$vn, $inds = $preprocessed
457457
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
458458
$left = $out[1]
459-
$vi.logp += $out[2]
459+
DynamicPPL.acclogp!($vi, $out[2])
460460
else
461-
$vi.logp += DynamicPPL.tilde($ctx, $sampler, $temp_right, $preprocessed, $vi)
461+
DynamicPPL.acclogp!(
462+
$vi,
463+
DynamicPPL.tilde($ctx, $sampler, $temp_right, $preprocessed, $vi),
464+
)
462465
end
463466
end
464467
else
465468
ex = quote
466469
$temp_right = $right
467470
$assert_ex
468-
$vi.logp += DynamicPPL.tilde($ctx, $sampler, $temp_right, $left, $vi)
471+
DynamicPPL.acclogp!(
472+
$vi,
473+
DynamicPPL.tilde($ctx, $sampler, $temp_right, $left, $vi),
474+
)
469475
end
470476
end
471477
return ex
@@ -500,18 +506,24 @@ function dot_tilde(left, right, model_info)
500506
$temp_left = $left
501507
$out = DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vn, $inds, $vi)
502508
$left .= $out[1]
503-
$vi.logp += $out[2]
509+
DynamicPPL.acclogp!($vi, $out[2])
504510
else
505511
$temp_left = $preprocessed
506-
$vi.logp += DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi)
512+
DynamicPPL.acclogp!(
513+
$vi,
514+
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
515+
)
507516
end
508517
end
509518
else
510519
ex = quote
511520
$temp_left = $left
512521
$temp_right = $right
513522
$assert_ex
514-
$vi.logp += DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi)
523+
DynamicPPL.acclogp!(
524+
$vi,
525+
DynamicPPL.dot_tilde($ctx, $sampler, $temp_right, $temp_left, $vi),
526+
)
515527
end
516528
end
517529
return ex
@@ -587,7 +599,7 @@ function build_output(model_info)
587599
$model
588600
)
589601
$unwrap_data_expr
590-
$vi.logp = 0
602+
DynamicPPL.resetlogp!($vi)
591603
$main_body
592604
end
593605
return DynamicPPL.Model($inner_function, $args_nt, $model_gen_constructor)

src/prob_macro.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function logprior(
142142
@assert n in keys(left) "Variable $n is not defined."
143143
end
144144
model(vi, SampleFromPrior(), PriorContext(left))
145-
return vi.logp
145+
return getlogp(vi)
146146
end
147147
@generated function get_prior_model_args(
148148
left::NamedTuple{namesl},
@@ -205,14 +205,14 @@ function loglikelihood(
205205
c = chain[i]
206206
_setval!(vi, c)
207207
model(vi, SampleFromPrior(), ctx)
208-
vi.logp
208+
return getlogp(vi)
209209
end
210210
else
211211
# Likelihood without chain
212212
# Rhs values are used in the context
213213
ctx = LikelihoodContext(right)
214214
model(vi, SampleFromPrior(), ctx)
215-
return vi.logp
215+
return getlogp(vi)
216216
end
217217
end
218218
@generated function get_like_model_args(

src/varinfo.jl

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
147147
end
148148
function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
149149
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
150-
VarInfo(md, Base.RefValue{eltype(x)}(old_vi.logp), Ref(old_vi.num_produce))
150+
VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)))
151151
end
152152
@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space}
153153
exprs = []
@@ -654,36 +654,25 @@ function TypedVarInfo(vi::UntypedVarInfo)
654654
sym_dists, sym_gids, sym_orders, sym_flags)
655655
)
656656
end
657-
logp = vi.logp
658-
num_produce = vi.num_produce
657+
logp = getlogp(vi)
658+
num_produce = get_num_produce(vi)
659659
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
660660
return VarInfo(nt, Ref(logp), Ref(num_produce))
661661
end
662662
TypedVarInfo(vi::TypedVarInfo) = vi
663663

664-
function getproperty(vi::VarInfo, f::Symbol)
665-
f === :logp && return getfield(vi, :logp)[]
666-
f === :num_produce && return getfield(vi, :num_produce)[]
667-
return getfield(vi, f)
668-
end
669-
function setproperty!(vi::VarInfo, f::Symbol, x)
670-
f === :logp && return getfield(vi, :logp)[] = x
671-
f === :num_produce && return getfield(vi, :num_produce)[] = x
672-
return setfield!(vi, f, x)
673-
end
674-
675664
"""
676665
empty!(vi::VarInfo)
677666
678-
Empty the fields of `vi.metadata` and reset `vi.logp` and `vi.num_produce` to
667+
Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to
679668
zeros.
680669
681670
This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`.
682671
"""
683672
function empty!(vi::VarInfo)
684673
_empty!(vi.metadata)
685-
vi.logp = 0
686-
vi.num_produce = 0
674+
resetlogp!(vi)
675+
reset_num_produce!(vi)
687676
return vi
688677
end
689678
@inline _empty!(metadata::Metadata) = empty!(metadata)
@@ -724,31 +713,61 @@ istrans(vi::AbstractVarInfo, vn::VarName) = is_flagged(vi, vn, "trans")
724713
Return the log of the joint probability of the observed data and parameters sampled in
725714
`vi`.
726715
"""
727-
getlogp(vi::AbstractVarInfo) = vi.logp
716+
getlogp(vi::AbstractVarInfo) = vi.logp[]
728717

729718
"""
730719
setlogp!(vi::VarInfo, logp::Real)
731720
732721
Set the log of the joint probability of the observed data and parameters sampled in
733722
`vi` to `logp`.
734723
"""
735-
setlogp!(vi::AbstractVarInfo, logp::Real) = vi.logp = logp
724+
setlogp!(vi::AbstractVarInfo, logp::Real) = vi.logp[] = logp
736725

737726
"""
738727
acclogp!(vi::VarInfo, logp::Real)
739728
740729
Add `logp` to the value of the log of the joint probability of the observed data and
741730
parameters sampled in `vi`.
742731
"""
743-
acclogp!(vi::AbstractVarInfo, logp::Real) = vi.logp += logp
732+
acclogp!(vi::AbstractVarInfo, logp::Real) = vi.logp[] += logp
744733

745734
"""
746735
resetlogp!(vi::VarInfo)
747736
748737
Reset the value of the log of the joint probability of the observed data and parameters
749738
sampled in `vi` to 0.
750739
"""
751-
resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, 0.0)
740+
resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, 0)
741+
742+
743+
"""
744+
get_num_produce(vi::VarInfo)
745+
746+
Return the `num_produce` of `vi`.
747+
"""
748+
get_num_produce(vi::AbstractVarInfo) = vi.num_produce[]
749+
750+
"""
751+
set_num_produce!(vi::VarInfo, n::Int)
752+
753+
Set the `num_produce` field of `vi` to `n`.
754+
"""
755+
set_num_produce!(vi::AbstractVarInfo, n::Int) = vi.num_produce[] = n
756+
757+
"""
758+
increment_num_produce!(vi::VarInfo)
759+
760+
Add 1 to `num_produce` in `vi`.
761+
"""
762+
increment_num_produce!(vi::AbstractVarInfo) = vi.num_produce[] += 1
763+
764+
"""
765+
reset_num_produce!(vi::VarInfo)
766+
767+
Reset the value of `num_produce` the log of the joint probability of the observed data
768+
and parameters sampled in `vi` to 0.
769+
"""
770+
reset_num_produce!(vi::AbstractVarInfo) = set_num_produce!(vi, 0)
752771

753772
"""
754773
isempty(vi::VarInfo)
@@ -1035,8 +1054,8 @@ function show(io::IO, vi::UntypedVarInfo)
10351054
| Vals : $(vi.metadata.vals)
10361055
| GIDs : $(vi.metadata.gids)
10371056
| Orders : $(vi.metadata.orders)
1038-
| Logp : $(vi.logp)
1039-
| #produce : $(vi.num_produce)
1057+
| Logp : $(getlogp(vi))
1058+
| #produce : $(get_num_produce(vi))
10401059
| flags : $(vi.metadata.flags)
10411060
\\=======================================================================
10421061
"""
@@ -1103,7 +1122,7 @@ function push!(
11031122
append!(meta.vals, val)
11041123
push!(meta.dists, dist)
11051124
push!(meta.gids, gidset)
1106-
push!(meta.orders, vi.num_produce)
1125+
push!(meta.orders, get_num_produce(vi))
11071126
push!(meta.flags["del"], false)
11081127
push!(meta.flags["trans"], false)
11091128

@@ -1149,18 +1168,18 @@ end
11491168
"""
11501169
set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
11511170
1152-
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce` to `true`.
1171+
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`.
11531172
"""
11541173
function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler)
11551174
# Get the indices of `vns` that belong to `spl` as a vector
11561175
gidcs = _getidcs(vi, spl)
1157-
if vi.num_produce == 0
1176+
if get_num_produce(vi) == 0
11581177
for i = length(gidcs):-1:1
11591178
vi.metadata.flags["del"][gidcs[i]] = true
11601179
end
11611180
else
11621181
for i in 1:length(vi.orders)
1163-
if i in gidcs && vi.orders[i] > vi.num_produce
1182+
if i in gidcs && vi.orders[i] > get_num_produce(vi)
11641183
vi.metadata.flags["del"][i] = true
11651184
end
11661185
end
@@ -1170,7 +1189,7 @@ end
11701189
function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler)
11711190
# Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol
11721191
gidcs = _getidcs(vi, spl)
1173-
return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, vi.num_produce)
1192+
return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi))
11741193
end
11751194
@generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names}
11761195
expr = Expr(:block)

test/Turing/contrib/inference/AdvancedSMCExtensions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:
255255

256256
# Run SMC & CSMC nodes
257257
for j in 1:spl.alg.n_nodes
258-
VarInfos[j].num_produce = 0
258+
reset_num_produce!(VarInfos[j])
259259
VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1]
260260
log_zs[j] = spl.info[:samplers][j].info[:logevidence][end]
261261
end

test/Turing/core/Core.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Core
22

3+
using Bijectors
34
using MacroTools, Libtask, ForwardDiff, Random
45
using Distributions, LinearAlgebra
56
using ..Utilities, Reexport
@@ -13,9 +14,14 @@ import Bijectors: link, invlink
1314
using DistributionsAD
1415
using StatsFuns: logsumexp, softmax
1516
@reexport using DynamicPPL
17+
using Requires
1618

1719
include("container.jl")
1820
include("ad.jl")
21+
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
22+
include("compat/zygote.jl")
23+
export ZygoteAD
24+
end
1925

2026
export @model,
2127
@varname,

0 commit comments

Comments
 (0)