Skip to content

Commit 3a09b52

Browse files
committed
Add ThreadSafeVarInfo
1 parent 180458e commit 3a09b52

File tree

7 files changed

+177
-87
lines changed

7 files changed

+177
-87
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1312
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1413

1514
[compat]
1615
AbstractMCMC = "1.0"
1716
Bijectors = "0.5.2, 0.6"
1817
Distributions = "0.22, 0.23"
1918
MacroTools = "0.5.1"
20-
StaticArrays = "0.12.2"
2119
ZygoteRules = "0.2"
2220
julia = "1"
2321

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Bijectors
66
using MacroTools
77

88
import AbstractMCMC
9-
import StaticArrays
109
import ZygoteRules
1110

1211
import Random
@@ -111,6 +110,7 @@ include("varname.jl")
111110
include("distribution_wrappers.jl")
112111
include("contexts.jl")
113112
include("varinfo.jl")
113+
include("threadsafe.jl")
114114
include("context_implementations.jl")
115115
include("compiler.jl")
116116
include("prob_macro.jl")

src/compiler.jl

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
22
"Distributions."
33

4-
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_logps)
4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
55

66
"""
77
isassumption(expr)
@@ -234,25 +234,25 @@ function generate_tilde(left, right, args)
234234
$isassumption = $(DynamicPPL.isassumption(left))
235235
if $isassumption
236236
$left = $(DynamicPPL.tilde_assume)(
237-
_context, _sampler, $tmpright, $vn, $inds, _varinfo, _logps)
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
238238
else
239239
$(DynamicPPL.tilde_observe)(
240-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
241241
end
242242
end
243243
end
244244

245245
return quote
246246
$(top...)
247247
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
248-
_varinfo, _logps)
248+
_varinfo)
249249
end
250250
end
251251

252252
# If the LHS is a literal, it is always an observation
253253
return quote
254254
$(top...)
255-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo, _logps)
255+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
256256
end
257257
end
258258

@@ -279,26 +279,25 @@ function generate_dot_tilde(left, right, args)
279279
$isassumption = $(DynamicPPL.isassumption(left))
280280
if $isassumption
281281
$left .= $(DynamicPPL.dot_tilde_assume)(
282-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
282+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
283283
else
284284
$(DynamicPPL.dot_tilde_observe)(
285-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
285+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
286286
end
287287
end
288288
end
289289

290290
return quote
291291
$(top...)
292292
$left .= $(DynamicPPL.dot_tilde_assume)(
293-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
293+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
294294
end
295295
end
296296

297297
# If the LHS is a literal, it is always an observation
298298
return quote
299299
$(top...)
300-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo,
301-
_logps)
300+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
302301
end
303302
end
304303

@@ -335,29 +334,16 @@ function build_output(model_info)
335334
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
336335
end
337336

338-
@gensym(evaluator, innerevaluator, generator)
337+
@gensym(evaluator, generator)
339338
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
340339
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
341340

342341
return quote
343342
function $evaluator(
344-
model::$(DynamicPPL.Model),
345-
varinfo::$(DynamicPPL.VarInfo),
346-
sampler::$(DynamicPPL.AbstractSampler),
347-
context::$(DynamicPPL.AbstractContext),
348-
)
349-
logps = $(DynamicPPL.initlogps)(varinfo)
350-
result = $innerevaluator(model, varinfo, sampler, context, logps)
351-
$(DynamicPPL.acclogp!)(varinfo, $(Base.sum)(logps))
352-
return result
353-
end
354-
355-
function $innerevaluator(
356343
_model::$(DynamicPPL.Model),
357-
_varinfo::$(DynamicPPL.VarInfo),
344+
_varinfo::$(DynamicPPL.AbstractVarInfo),
358345
_sampler::$(DynamicPPL.AbstractSampler),
359346
_context::$(DynamicPPL.AbstractContext),
360-
_logps,
361347
)
362348
$unwrap_data_expr
363349
$main_body

src/context_implementations.jl

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,16 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3636
end
3737

3838
"""
39-
tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
39+
tilde_assume(ctx, sampler, right, vn, inds, vi)
4040
4141
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42-
accumulate the log probability in `logps` (separately for each thread), and return the
43-
sampled value.
42+
accumulate the log probability, and return the sampled value.
4443
4544
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4645
"""
47-
function tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
46+
function tilde_assume(ctx, sampler, right, vn, inds, vi)
4847
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
49-
logps[Threads.threadid()] += logp
48+
acclogp!(vi, logp)
5049
return value
5150
end
5251

@@ -76,29 +75,28 @@ end
7675
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7776
7877
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
79-
accumulate the log probability in `logps` (separately for each thread), and return the
80-
observed value.
78+
accumulate the log probability, and return the observed value.
8179
8280
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
8381
and indices; if needed, these can be accessed through this function, though.
8482
"""
85-
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
83+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
8684
logp = tilde(ctx, sampler, right, left, vi)
87-
logps[Threads.threadid()] += logp
85+
acclogp!(vi, logp)
8886
return left
8987
end
9088

9189
"""
92-
tilde_observe(ctx, sampler, right, left, vi, logps)
90+
tilde_observe(ctx, sampler, right, left, vi)
9391
94-
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95-
(separately for each thread), and return the observed value.
92+
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and
93+
return the observed value.
9694
9795
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9896
"""
99-
function tilde_observe(ctx, sampler, right, left, vi, logps)
97+
function tilde_observe(ctx, sampler, right, left, vi)
10098
logp = tilde(ctx, sampler, right, left, vi)
101-
logps[Threads.threadid()] += logp
99+
acclogp!(vi, logp)
102100
return left
103101
end
104102

@@ -117,7 +115,7 @@ function assume(
117115
spl::Union{SampleFromPrior,SampleFromUniform},
118116
dist::Distribution,
119117
vn::VarName,
120-
vi::VarInfo,
118+
vi,
121119
)
122120
if haskey(vi, vn)
123121
# Always overwrite the parameters with new ones for `SampleFromUniform`.
@@ -142,7 +140,7 @@ function observe(
142140
spl::Union{SampleFromPrior, SampleFromUniform},
143141
dist::Distribution,
144142
value,
145-
vi::VarInfo,
143+
vi,
146144
)
147145
increment_num_produce!(vi)
148146
return Distributions.logpdf(dist, value)
@@ -201,14 +199,13 @@ end
201199
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
202200
203201
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
204-
model inputs), accumulate the log probability in `logps` (separately for each thread), and
205-
return the sampled value.
202+
model inputs), accumulate the log probability, and return the sampled value.
206203
207204
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
208205
"""
209-
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi, logps)
206+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
210207
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
211-
logps[Threads.threadid()] += logp
208+
acclogp!(vi, logp)
212209
return value
213210
end
214211

@@ -240,7 +237,7 @@ function _dot_tilde(
240237
right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}},
241238
left::AbstractMatrix{>:AbstractVector},
242239
vn::AbstractVector{<:VarName},
243-
vi::VarInfo,
240+
vi,
244241
)
245242
throw(ambiguity_error_msg())
246243
end
@@ -250,7 +247,7 @@ function dot_assume(
250247
dist::MultivariateDistribution,
251248
vns::AbstractVector{<:VarName},
252249
var::AbstractMatrix,
253-
vi::VarInfo,
250+
vi,
254251
)
255252
@assert length(dist) == size(var, 1)
256253
r = get_and_set_val!(vi, vns, dist, spl)
@@ -263,7 +260,7 @@ function dot_assume(
263260
dists::Union{Distribution, AbstractArray{<:Distribution}},
264261
vns::AbstractArray{<:VarName},
265262
var::AbstractArray,
266-
vi::VarInfo,
263+
vi,
267264
)
268265
r = get_and_set_val!(vi, vns, dists, spl)
269266
# Make sure `r` is not a matrix for multivariate distributions
@@ -276,13 +273,13 @@ function dot_assume(
276273
::Any,
277274
::AbstractArray{<:VarName},
278275
::Any,
279-
::VarInfo
276+
::Any,
280277
)
281278
error("[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement")
282279
end
283280

284281
function get_and_set_val!(
285-
vi::VarInfo,
282+
vi,
286283
vns::AbstractVector{<:VarName},
287284
dist::MultivariateDistribution,
288285
spl::Union{SampleFromPrior,SampleFromUniform},
@@ -313,7 +310,7 @@ function get_and_set_val!(
313310
end
314311

315312
function get_and_set_val!(
316-
vi::VarInfo,
313+
vi,
317314
vns::AbstractArray{<:VarName},
318315
dists::Union{Distribution, AbstractArray{<:Distribution}},
319316
spl::Union{SampleFromPrior,SampleFromUniform},
@@ -344,7 +341,7 @@ function get_and_set_val!(
344341
end
345342

346343
function set_val!(
347-
vi::VarInfo,
344+
vi,
348345
vns::AbstractVector{<:VarName},
349346
dist::MultivariateDistribution,
350347
val::AbstractMatrix,
@@ -356,7 +353,7 @@ function set_val!(
356353
return val
357354
end
358355
function set_val!(
359-
vi::VarInfo,
356+
vi,
360357
vns::AbstractArray{<:VarName},
361358
dists::Union{Distribution, AbstractArray{<:Distribution}},
362359
val::AbstractArray,
@@ -384,36 +381,34 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
384381
end
385382

386383
"""
387-
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
384+
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
388385
389386
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
390-
accumulate the log probability in `logps` (separately for each thread), and return the
391-
observed value.
387+
accumulate the log probability, and return the observed value.
392388
393389
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
394390
name and indices; if needed, these can be accessed through this function, though.
395391
"""
396-
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi, logps)
392+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
397393
logp = dot_tilde(ctx, sampler, right, left, vi)
398-
logps[Threads.threadid()] += logp
394+
acclogp!(vi, logp)
399395
return left
400396
end
401397

402398
"""
403-
dot_tilde_observe(ctx, sampler, right, left, vi, logps)
399+
dot_tilde_observe(ctx, sampler, right, left, vi)
404400
405401
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
406-
probability in `logps` (separately for each thread), and return the observed value.
402+
probability, and return the observed value.
407403
408404
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
409405
"""
410-
function dot_tilde_observe(ctx, sampler, right, left, vi, logps)
406+
function dot_tilde_observe(ctx, sampler, right, left, vi)
411407
logp = dot_tilde(ctx, sampler, right, left, vi)
412-
logps[Threads.threadid()] += logp
408+
acclogp!(vi, logp)
413409
return left
414410
end
415411

416-
417412
function _dot_tilde(sampler, right, left::AbstractArray, vi)
418413
return dot_observe(sampler, right, left, vi)
419414
end
@@ -422,7 +417,7 @@ function _dot_tilde(
422417
sampler::AbstractSampler,
423418
right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}},
424419
left::AbstractMatrix{>:AbstractVector},
425-
vi::VarInfo,
420+
vi,
426421
)
427422
throw(ambiguity_error_msg())
428423
end
@@ -431,7 +426,7 @@ function dot_observe(
431426
spl::Union{SampleFromPrior, SampleFromUniform},
432427
dist::MultivariateDistribution,
433428
value::AbstractMatrix,
434-
vi::VarInfo,
429+
vi,
435430
)
436431
increment_num_produce!(vi)
437432
DynamicPPL.DEBUG && @debug "dist = $dist"
@@ -442,7 +437,7 @@ function dot_observe(
442437
spl::Union{SampleFromPrior, SampleFromUniform},
443438
dists::Union{Distribution, AbstractArray{<:Distribution}},
444439
value::AbstractArray,
445-
vi::VarInfo,
440+
vi,
446441
)
447442
increment_num_produce!(vi)
448443
DynamicPPL.DEBUG && @debug "dists = $dists"
@@ -453,7 +448,7 @@ function dot_observe(
453448
spl::Sampler,
454449
::Any,
455450
::Any,
456-
::VarInfo,
451+
::Any,
457452
)
458453
error("[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement")
459454
end

0 commit comments

Comments
 (0)