Skip to content

Commit b842c03

Browse files
committed
Merge branch 'dev' into phg/empty_varinfo_print
2 parents 14e5cb1 + a673228 commit b842c03

File tree

15 files changed

+194
-112
lines changed

15 files changed

+194
-112
lines changed

.travis.yml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,24 @@ branches:
1212
os:
1313
- linux
1414
- osx
15-
julia:
16-
- 1.0
17-
- 1
18-
- nightly
15+
matrix:
16+
include:
17+
- julia: 1.0
18+
- julia: 1
19+
env: JULIA_NUM_THREADS=1
20+
- julia: 1
21+
env: JULIA_NUM_THREADS=2
22+
- julia: nightly
23+
env: JULIA_NUM_THREADS=1
24+
- julia: nightly
25+
env: JULIA_NUM_THREADS=2
1926
matrix:
2027
allow_failures:
2128
- julia: nightly
2229
fast_finish: true
2330
notifications:
2431
email: false
2532
after_success:
26-
- if [[ $TRAVIS_JULIA_VERSION = 1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
33+
- if [[ $TRAVIS_JULIA_VERSION = 1 ]] && [[ $JULIA_NUM_THREADS = 1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
2734
julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())';
2835
fi

src/compiler.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
22
"Distributions."
33

4-
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
55

66
"""
77
isassumption(expr)
@@ -240,7 +240,7 @@ function generate_tilde(left, right, args)
240240
$isassumption = $(DynamicPPL.isassumption(left))
241241
if $isassumption
242242
$left = $(DynamicPPL.tilde_assume)(
243-
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
243+
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
244244
else
245245
$(DynamicPPL.tilde_observe)(
246246
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
@@ -250,8 +250,8 @@ function generate_tilde(left, right, args)
250250

251251
return quote
252252
$(top...)
253-
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
254-
_varinfo)
253+
$left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
254+
$inds, _varinfo)
255255
end
256256
end
257257

@@ -285,7 +285,7 @@ function generate_dot_tilde(left, right, args)
285285
$isassumption = $(DynamicPPL.isassumption(left))
286286
if $isassumption
287287
$left .= $(DynamicPPL.dot_tilde_assume)(
288-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
288+
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
289289
else
290290
$(DynamicPPL.dot_tilde_observe)(
291291
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
@@ -296,7 +296,7 @@ function generate_dot_tilde(left, right, args)
296296
return quote
297297
$(top...)
298298
$left .= $(DynamicPPL.dot_tilde_assume)(
299-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
299+
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
300300
end
301301
end
302302

@@ -346,6 +346,7 @@ function build_output(model_info)
346346

347347
return quote
348348
function $evaluator(
349+
_rng::$(Random.AbstractRNG),
349350
_model::$(DynamicPPL.Model),
350351
_varinfo::$(DynamicPPL.AbstractVarInfo),
351352
_sampler::$(DynamicPPL.AbstractSampler),

src/context_implementations.jl

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,47 +19,47 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
1919
_getindex(x, inds::Tuple{}) = x
2020

2121
# assume
22-
function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
23-
return _tilde(sampler, right, vn, vi)
22+
function tilde(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
23+
return _tilde(rng, sampler, right, vn, vi)
2424
end
25-
function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
25+
function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
2626
if ctx.vars !== nothing
2727
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
2828
settrans!(vi, false, vn)
2929
end
30-
return _tilde(sampler, right, vn, vi)
30+
return _tilde(rng, sampler, right, vn, vi)
3131
end
32-
function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
32+
function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
3333
if ctx.vars !== nothing
3434
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
3535
settrans!(vi, false, vn)
3636
end
37-
return _tilde(sampler, NoDist(right), vn, vi)
37+
return _tilde(rng, sampler, NoDist(right), vn, vi)
3838
end
39-
function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
40-
return tilde(ctx.ctx, sampler, right, left, inds, vi)
39+
function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
40+
return tilde(rng, ctx.ctx, sampler, right, left, inds, vi)
4141
end
4242

4343
"""
44-
tilde_assume(ctx, sampler, right, vn, inds, vi)
44+
tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
4545
4646
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
4747
accumulate the log probability, and return the sampled value.
4848
49-
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
49+
Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
5050
"""
51-
function tilde_assume(ctx, sampler, right, vn, inds, vi)
52-
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
51+
function tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
52+
value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi)
5353
acclogp!(vi, logp)
5454
return value
5555
end
5656

5757

58-
function _tilde(sampler, right, vn::VarName, vi)
59-
return assume(sampler, right, vn, vi)
58+
function _tilde(rng, sampler, right, vn::VarName, vi)
59+
return assume(rng, sampler, right, vn, vi)
6060
end
61-
function _tilde(sampler, right::NamedDist, vn::VarName, vi)
62-
return _tilde(sampler, right.dist, right.name, vi)
61+
function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi)
62+
return _tilde(rng, sampler, right.dist, right.name, vi)
6363
end
6464

6565
# observe
@@ -108,7 +108,7 @@ end
108108

109109
_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi)
110110

111-
function assume(spl::Sampler, dist)
111+
function assume(rng, spl::Sampler, dist)
112112
error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
113113
end
114114

@@ -117,6 +117,7 @@ function observe(spl::Sampler, weight)
117117
end
118118

119119
function assume(
120+
rng,
120121
spl::Union{SampleFromPrior,SampleFromUniform},
121122
dist::Distribution,
122123
vn::VarName,
@@ -126,15 +127,15 @@ function assume(
126127
# Always overwrite the parameters with new ones for `SampleFromUniform`.
127128
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
128129
unset_flag!(vi, vn, "del")
129-
r = init(dist, spl)
130+
r = init(rng, dist, spl)
130131
vi[vn] = vectorize(dist, r)
131132
settrans!(vi, false, vn)
132133
setorder!(vi, vn, get_num_produce(vi))
133134
else
134135
r = vi[vn]
135136
end
136137
else
137-
r = init(dist, spl)
138+
r = init(rng, dist, spl)
138139
push!(vi, vn, r, dist, spl)
139140
settrans!(vi, false, vn)
140141
end
@@ -154,11 +155,12 @@ end
154155
# .~ functions
155156

156157
# assume
157-
function dot_tilde(ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi)
158+
function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi)
158159
vns, dist = get_vns_and_dist(right, left, vn)
159-
return _dot_tilde(sampler, dist, left, vns, vi)
160+
return _dot_tilde(rng, sampler, dist, left, vns, vi)
160161
end
161162
function dot_tilde(
163+
rng,
162164
ctx::LikelihoodContext,
163165
sampler,
164166
right,
@@ -175,12 +177,13 @@ function dot_tilde(
175177
else
176178
vns, dist = get_vns_and_dist(right, left, vn)
177179
end
178-
return _dot_tilde(sampler, NoDist(dist), left, vns, vi)
180+
return _dot_tilde(rng, sampler, NoDist(dist), left, vns, vi)
179181
end
180-
function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi)
181-
return dot_tilde(ctx.ctx, sampler, right, left, vn, inds, vi)
182+
function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi)
183+
return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi)
182184
end
183185
function dot_tilde(
186+
rng,
184187
ctx::PriorContext,
185188
sampler,
186189
right,
@@ -197,19 +200,19 @@ function dot_tilde(
197200
else
198201
vns, dist = get_vns_and_dist(right, left, vn)
199202
end
200-
return _dot_tilde(sampler, dist, left, vns, vi)
203+
return _dot_tilde(rng, sampler, dist, left, vns, vi)
201204
end
202205

203206
"""
204-
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
207+
dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)
205208
206209
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
207210
model inputs), accumulate the log probability, and return the sampled value.
208211
209-
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
212+
Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`.
210213
"""
211-
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
212-
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
214+
function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)
215+
value, logp = dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)
213216
acclogp!(vi, logp)
214217
return value
215218
end
@@ -232,12 +235,13 @@ function get_vns_and_dist(
232235
return getvn.(CartesianIndices(var)), dist
233236
end
234237

235-
function _dot_tilde(sampler, right, left, vns::AbstractArray{<:VarName}, vi)
236-
return dot_assume(sampler, right, vns, left, vi)
238+
function _dot_tilde(rng, sampler, right, left, vns::AbstractArray{<:VarName}, vi)
239+
return dot_assume(rng, sampler, right, vns, left, vi)
237240
end
238241

239242
# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
240243
function _dot_tilde(
244+
rng,
241245
sampler::AbstractSampler,
242246
right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}},
243247
left::AbstractMatrix{>:AbstractVector},
@@ -248,32 +252,35 @@ function _dot_tilde(
248252
end
249253

250254
function dot_assume(
255+
rng,
251256
spl::Union{SampleFromPrior, SampleFromUniform},
252257
dist::MultivariateDistribution,
253258
vns::AbstractVector{<:VarName},
254259
var::AbstractMatrix,
255260
vi,
256261
)
257262
@assert length(dist) == size(var, 1)
258-
r = get_and_set_val!(vi, vns, dist, spl)
263+
r = get_and_set_val!(rng, vi, vns, dist, spl)
259264
lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1])))
260265
var .= r
261266
return var, lp
262267
end
263268
function dot_assume(
269+
rng,
264270
spl::Union{SampleFromPrior, SampleFromUniform},
265271
dists::Union{Distribution, AbstractArray{<:Distribution}},
266272
vns::AbstractArray{<:VarName},
267273
var::AbstractArray,
268274
vi,
269275
)
270-
r = get_and_set_val!(vi, vns, dists, spl)
276+
r = get_and_set_val!(rng, vi, vns, dists, spl)
271277
# Make sure `r` is not a matrix for multivariate distributions
272278
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
273279
var .= r
274280
return var, lp
275281
end
276282
function dot_assume(
283+
rng,
277284
spl::Sampler,
278285
::Any,
279286
::AbstractArray{<:VarName},
@@ -284,6 +291,7 @@ function dot_assume(
284291
end
285292

286293
function get_and_set_val!(
294+
rng,
287295
vi,
288296
vns::AbstractVector{<:VarName},
289297
dist::MultivariateDistribution,
@@ -294,7 +302,7 @@ function get_and_set_val!(
294302
# Always overwrite the parameters with new ones for `SampleFromUniform`.
295303
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
296304
unset_flag!(vi, vns[1], "del")
297-
r = init(dist, spl, n)
305+
r = init(rng, dist, spl, n)
298306
for i in 1:n
299307
vn = vns[i]
300308
vi[vn] = vectorize(dist, r[:, i])
@@ -305,7 +313,7 @@ function get_and_set_val!(
305313
r = vi[vns]
306314
end
307315
else
308-
r = init(dist, spl, n)
316+
r = init(rng, dist, spl, n)
309317
for i in 1:n
310318
vn = vns[i]
311319
push!(vi, vn, r[:,i], dist, spl)
@@ -316,6 +324,7 @@ function get_and_set_val!(
316324
end
317325

318326
function get_and_set_val!(
327+
rng,
319328
vi,
320329
vns::AbstractArray{<:VarName},
321330
dists::Union{Distribution, AbstractArray{<:Distribution}},
@@ -325,7 +334,7 @@ function get_and_set_val!(
325334
# Always overwrite the parameters with new ones for `SampleFromUniform`.
326335
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
327336
unset_flag!(vi, vns[1], "del")
328-
f = (vn, dist) -> init(dist, spl)
337+
f = (vn, dist) -> init(rng, dist, spl)
329338
r = f.(vns, dists)
330339
for i in eachindex(vns)
331340
vn = vns[i]
@@ -338,7 +347,7 @@ function get_and_set_val!(
338347
r = reshape(vi[vec(vns)], size(vns))
339348
end
340349
else
341-
f = (vn, dist) -> init(dist, spl)
350+
f = (vn, dist) -> init(rng, dist, spl)
342351
r = f.(vns, dists)
343352
push!.(Ref(vi), vns, r, dists, Ref(spl))
344353
settrans!.(Ref(vi), false, vns)

src/distribution_wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct NoDist{
3030
end
3131
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)
3232

33-
Distributions.rand(d::NoDist) = rand(d.dist)
33+
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
3434
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
3535
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
3636
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})

0 commit comments

Comments
 (0)