Skip to content

Commit 292e06b

Browse files
authored
Merge branch 'master' into hg/abstract-mcmc-1.0
2 parents bc99424 + ac84008 commit 292e06b

29 files changed

+1050
-847
lines changed

.travis.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ os:
88
- osx
99
julia:
1010
- 1.0
11-
- 1.1
12-
- 1.2
13-
- 1.3
11+
- 1
1412
- nightly
1513
matrix:
1614
allow_failures:
@@ -19,6 +17,6 @@ matrix:
1917
notifications:
2018
email: false
2119
after_success:
22-
- if [[ $TRAVIS_JULIA_VERSION = 1.3 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
20+
- if [[ $TRAVIS_JULIA_VERSION = 1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
2321
julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())';
2422
fi

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1112

1213
[compat]
1314
AbstractMCMC = "1.0"
1415
Bijectors = "0.5.2, 0.6"
1516
Distributions = "0.22, 0.23"
1617
MacroTools = "0.5.1"
18+
ZygoteRules = "0.2"
1719
julia = "1"
1820

1921
[extras]

src/DynamicPPL.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
66
using MacroTools
7+
import ZygoteRules
78

8-
import Base: string,
9-
Symbol,
9+
import Base: Symbol,
1010
==,
1111
hash,
12-
in,
1312
getindex,
1413
setindex!,
1514
push!,
@@ -22,8 +21,7 @@ import Base: string,
2221
haskey
2322

2423
# VarInfo
25-
export VarName,
26-
AbstractVarInfo,
24+
export AbstractVarInfo,
2725
VarInfo,
2826
UntypedVarInfo,
2927
getlogp,
@@ -44,13 +42,14 @@ export VarName,
4442
link!,
4543
invlink!,
4644
tonamedtuple,
45+
#VarName
46+
VarName,
47+
inspace,
48+
subsumes,
4749
# Compiler
4850
ModelGen,
4951
@model,
5052
@varname,
51-
@varinfo,
52-
@logpdf,
53-
@sampler,
5453
# Utilities
5554
vectorize,
5655
reconstruct,
@@ -61,9 +60,12 @@ export VarName,
6160
vectorize,
6261
set_resume!,
6362
# Model
63+
ModelGen,
6464
Model,
65-
getmissing,
66-
runmodel!,
65+
getmissings,
66+
getargnames,
67+
getdefaults,
68+
getgenerator,
6769
# Samplers
6870
Sampler,
6971
SampleFromPrior,
@@ -91,6 +93,11 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))
9193
# Used here and overloaded in Turing
9294
function getspace end
9395

96+
# Necessary forward declarations
97+
abstract type AbstractVarInfo end
98+
abstract type AbstractContext end
99+
100+
94101
include("utils.jl")
95102
include("selector.jl")
96103
include("model.jl")
@@ -102,5 +109,6 @@ include("varinfo.jl")
102109
include("context_implementations.jl")
103110
include("compiler.jl")
104111
include("prob_macro.jl")
112+
include("compat/ad.jl")
105113

106114
end # module

src/compat/ad.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Prevent Zygote from differentiating push!
2+
# See https://github.com/TuringLang/Turing.jl/issues/1199
3+
ZygoteRules.@adjoint function push!(
4+
vi::VarInfo,
5+
vn::VarName,
6+
r,
7+
dist::Distribution,
8+
gidset::Set{Selector}
9+
)
10+
return push!(vi, vn, r, dist, gidset), _ -> nothing
11+
end

0 commit comments

Comments
 (0)