Skip to content

Commit 6727849

Browse files
committed
Make accumulation of log probabilities thread-safe
1 parent 83f8f98 commit 6727849

File tree

6 files changed

+105
-34
lines changed

6 files changed

+105
-34
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1314

1415
[compat]
1516
AbstractMCMC = "1.0"
1617
Bijectors = "0.5.2, 0.6"
1718
Distributions = "0.22, 0.23"
1819
MacroTools = "0.5.1"
20+
StaticArrays = "0.12.2"
1921
ZygoteRules = "0.2"
2022
julia = "1"
2123

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ using Bijectors
66
using MacroTools
77

88
import AbstractMCMC
9-
import Random
9+
import StaticArrays
1010
import ZygoteRules
1111

12+
import Random
13+
1214
import Base: Symbol,
1315
==,
1416
hash,

src/compiler.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
22
"Distributions."
33

4-
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_logps)
55

66
"""
77
isassumption(expr)
@@ -234,24 +234,25 @@ function generate_tilde(left, right, args)
234234
$isassumption = $(DynamicPPL.isassumption(left))
235235
if $isassumption
236236
$left = $(DynamicPPL.tilde_assume)(
237-
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo, _logps)
238238
else
239239
$(DynamicPPL.tilde_observe)(
240-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
241241
end
242242
end
243243
end
244244

245245
return quote
246246
$(top...)
247-
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
247+
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
248+
_varinfo, _logps)
248249
end
249250
end
250251

251252
# If the LHS is a literal, it is always an observation
252253
return quote
253254
$(top...)
254-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
255+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo, _logps)
255256
end
256257
end
257258

@@ -278,25 +279,26 @@ function generate_dot_tilde(left, right, args)
278279
$isassumption = $(DynamicPPL.isassumption(left))
279280
if $isassumption
280281
$left .= $(DynamicPPL.dot_tilde_assume)(
281-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
282+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
282283
else
283284
$(DynamicPPL.dot_tilde_observe)(
284-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
285+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
285286
end
286287
end
287288
end
288289

289290
return quote
290291
$(top...)
291292
$left .= $(DynamicPPL.dot_tilde_assume)(
292-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
293+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
293294
end
294295
end
295296

296297
# If the LHS is a literal, it is always an observation
297298
return quote
298299
$(top...)
299-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
300+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo,
301+
_logps)
300302
end
301303
end
302304

@@ -333,16 +335,29 @@ function build_output(model_info)
333335
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
334336
end
335337

336-
@gensym(evaluator, generator)
338+
@gensym(evaluator, innerevaluator, generator)
337339
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
338340
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
339341

340342
return quote
341343
function $evaluator(
344+
model::$(DynamicPPL.Model),
345+
varinfo::$(DynamicPPL.VarInfo),
346+
sampler::$(DynamicPPL.AbstractSampler),
347+
context::$(DynamicPPL.AbstractContext),
348+
)
349+
logps = $(DynamicPPL.initlogps)(varinfo)
350+
result = $innerevaluator(model, varinfo, sampler, context, logps)
351+
$(DynamicPPL.acclogp!)(varinfo, $(Base.sum)(logps))
352+
return result
353+
end
354+
355+
function $innerevaluator(
342356
_model::$(DynamicPPL.Model),
343357
_varinfo::$(DynamicPPL.VarInfo),
344358
_sampler::$(DynamicPPL.AbstractSampler),
345359
_context::$(DynamicPPL.AbstractContext),
360+
_logps
346361
)
347362
$unwrap_data_expr
348363
$main_body

src/context_implementations.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3636
end
3737

3838
"""
39-
tilde_assume(ctx, sampler, right, vn, inds, vi)
39+
tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
4040
4141
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42-
accumulate the log probability, and return the sampled value.
42+
accumulate the log probability in `logps` (separately for each thread), and return the
43+
sampled value.
4344
4445
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4546
"""
46-
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47+
function tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
4748
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
48-
acclogp!(vi, logp)
49+
logps[Threads.threadid()] += logp
4950
return value
5051
end
5152

@@ -75,28 +76,29 @@ end
7576
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7677
7778
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
78-
accumulate the log probability, and return the observed value.
79+
accumulate the log probability in `logps` (separately for each thread), and return the
80+
observed value.
7981
8082
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
8183
and indices; if needed, these can be accessed through this function, though.
8284
"""
83-
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
85+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
8486
logp = tilde(ctx, sampler, right, left, vi)
85-
acclogp!(vi, logp)
87+
logps[Threads.threadid()] += logp
8688
return left
8789
end
8890

8991
"""
90-
tilde_observe(ctx, sampler, right, left, vi)
92+
tilde_observe(ctx, sampler, right, left, vi, logps)
9193
92-
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93-
observed value.
94+
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95+
(separately for each thread), and return the observed value.
9496
9597
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9698
"""
97-
function tilde_observe(ctx, sampler, right, left, vi)
99+
function tilde_observe(ctx, sampler, right, left, vi, logps)
98100
logp = tilde(ctx, sampler, right, left, vi)
99-
acclogp!(vi, logp)
101+
logps[Threads.threadid()] += logp
100102
return left
101103
end
102104

@@ -199,13 +201,14 @@ end
199201
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200202
201203
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
202-
model inputs), accumulate the log probability, and return the sampled value.
204+
model inputs), accumulate the log probability in `logps` (separately for each thread), and
205+
return the sampled value.
203206
204207
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
205208
"""
206-
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
209+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi, logps)
207210
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
208-
acclogp!(vi, logp)
211+
logps[Threads.threadid()] += logp
209212
return value
210213
end
211214

@@ -381,31 +384,32 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
381384
end
382385

383386
"""
384-
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
387+
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
385388
386389
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
387-
accumulate the log probability, and return the observed value.
390+
accumulate the log probability in `logps` (separately for each thread), and return the
391+
observed value.
388392
389393
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
390394
name and indices; if needed, these can be accessed through this function, though.
391395
"""
392-
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
396+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi, logps)
393397
logp = dot_tilde(ctx, sampler, right, left, vi)
394-
acclogp!(vi, logp)
398+
logps[Threads.threadid()] += logp
395399
return left
396400
end
397401

398402
"""
399-
dot_tilde_observe(ctx, sampler, right, left, vi)
403+
dot_tilde_observe(ctx, sampler, right, left, vi, logps)
400404
401405
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
402-
probability, and return the observed value.
406+
probability in `logps` (separately for each thread), and return the observed value.
403407
404408
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
405409
"""
406-
function dot_tilde_observe(ctx, sampler, right, left, vi)
410+
function dot_tilde_observe(ctx, sampler, right, left, vi, logps)
407411
logp = dot_tilde(ctx, sampler, right, left, vi)
408-
acclogp!(vi, logp)
412+
logps[Threads.threadid()] += logp
409413
return left
410414
end
411415

src/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ function getargs_tilde(expr::Expr)
3434
return
3535
end
3636

37+
"""
38+
initlogps(varinfo)
39+
40+
Return an `MVector` of length `Threads.nthreads()` filled with `zero(getlogp(varinfo))`.
41+
42+
It is used for accumulating the log probability in the model evaluation in a thread-safe
43+
way.
44+
"""
45+
function initlogps(varinfo)
46+
T = typeof(getlogp(varinfo))
47+
return zeros(StaticArrays.MVector{Threads.nthreads(),T})
48+
end
49+
3750
############################################
3851
# Julia 1.2 temporary fix - Julia PR 33303 #
3952
############################################

test/compiler.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,4 +580,39 @@ end
580580
model = demo()
581581
@test all(iszero(model()) for _ in 1:1000)
582582
end
583+
@testset "threading" begin
584+
@info "Peforming threading tests with $(Threads.nthreads()) threads"
585+
586+
x = rand(10_000)
587+
588+
@model function wthreads(x)
589+
x[1] ~ Normal(0, 1)
590+
Threads.@threads for i in 2:length(x)
591+
x[i] ~ Normal(x[i-1], 1)
592+
end
593+
end
594+
595+
vi = VarInfo()
596+
wthreads(x)(vi)
597+
lp_w_threads = getlogp(vi)
598+
599+
println("With threading:")
600+
@time wthreads(x)(vi)
601+
602+
@model function wothreads(x)
603+
x[1] ~ Normal(0, 1)
604+
for i in 2:length(x)
605+
x[i] ~ Normal(x[i-1], 1)
606+
end
607+
end
608+
609+
vi = VarInfo()
610+
wothreads(x)(vi)
611+
lp_wo_threads = getlogp(vi)
612+
613+
println("Without threading:")
614+
@time wothreads(x)(vi)
615+
616+
@test lp_w_threads lp_wo_threads
617+
end
583618
end

0 commit comments

Comments
 (0)