Skip to content

Commit 180458e

Browse files
authored
Merge pull request #89 from TuringLang/threads
Make accumulation of log probabilities thread-safe
2 parents 98e46ea + f4ab1ee commit 180458e

File tree

8 files changed

+122
-41
lines changed

8 files changed

+122
-41
lines changed

.github/workflows/DynamicPPL-CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010
jobs:
1111
test:
1212
runs-on: ${{ matrix.os }}
13-
continue-on-error: ${{ matrix.version == 'nightly' }}
13+
continue-on-error: ${{ matrix.version == 'nightly' }}
1414
strategy:
1515
matrix:
1616
version:

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1314

1415
[compat]
1516
AbstractMCMC = "1.0"
1617
Bijectors = "0.5.2, 0.6"
1718
Distributions = "0.22, 0.23"
1819
MacroTools = "0.5.1"
20+
StaticArrays = "0.12.2"
1921
ZygoteRules = "0.2"
2022
julia = "1"
2123

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ using Bijectors
66
using MacroTools
77

88
import AbstractMCMC
9-
import Random
9+
import StaticArrays
1010
import ZygoteRules
1111

12+
import Random
13+
1214
import Base: Symbol,
1315
==,
1416
hash,

src/compiler.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
22
"Distributions."
33

4-
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
4+
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_logps)
55

66
"""
77
isassumption(expr)
@@ -234,24 +234,25 @@ function generate_tilde(left, right, args)
234234
$isassumption = $(DynamicPPL.isassumption(left))
235235
if $isassumption
236236
$left = $(DynamicPPL.tilde_assume)(
237-
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo, _logps)
238238
else
239239
$(DynamicPPL.tilde_observe)(
240-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
241241
end
242242
end
243243
end
244244

245245
return quote
246246
$(top...)
247-
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
247+
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
248+
_varinfo, _logps)
248249
end
249250
end
250251

251252
# If the LHS is a literal, it is always an observation
252253
return quote
253254
$(top...)
254-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
255+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo, _logps)
255256
end
256257
end
257258

@@ -278,25 +279,26 @@ function generate_dot_tilde(left, right, args)
278279
$isassumption = $(DynamicPPL.isassumption(left))
279280
if $isassumption
280281
$left .= $(DynamicPPL.dot_tilde_assume)(
281-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
282+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
282283
else
283284
$(DynamicPPL.dot_tilde_observe)(
284-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
285+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
285286
end
286287
end
287288
end
288289

289290
return quote
290291
$(top...)
291292
$left .= $(DynamicPPL.dot_tilde_assume)(
292-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
293+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
293294
end
294295
end
295296

296297
# If the LHS is a literal, it is always an observation
297298
return quote
298299
$(top...)
299-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
300+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo,
301+
_logps)
300302
end
301303
end
302304

@@ -333,16 +335,29 @@ function build_output(model_info)
333335
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
334336
end
335337

336-
@gensym(evaluator, generator)
338+
@gensym(evaluator, innerevaluator, generator)
337339
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
338340
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
339341

340342
return quote
341343
function $evaluator(
344+
model::$(DynamicPPL.Model),
345+
varinfo::$(DynamicPPL.VarInfo),
346+
sampler::$(DynamicPPL.AbstractSampler),
347+
context::$(DynamicPPL.AbstractContext),
348+
)
349+
logps = $(DynamicPPL.initlogps)(varinfo)
350+
result = $innerevaluator(model, varinfo, sampler, context, logps)
351+
$(DynamicPPL.acclogp!)(varinfo, $(Base.sum)(logps))
352+
return result
353+
end
354+
355+
function $innerevaluator(
342356
_model::$(DynamicPPL.Model),
343357
_varinfo::$(DynamicPPL.VarInfo),
344358
_sampler::$(DynamicPPL.AbstractSampler),
345359
_context::$(DynamicPPL.AbstractContext),
360+
_logps,
346361
)
347362
$unwrap_data_expr
348363
$main_body

src/context_implementations.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3636
end
3737

3838
"""
39-
tilde_assume(ctx, sampler, right, vn, inds, vi)
39+
tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
4040
4141
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42-
accumulate the log probability, and return the sampled value.
42+
accumulate the log probability in `logps` (separately for each thread), and return the
43+
sampled value.
4344
4445
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4546
"""
46-
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47+
function tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
4748
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
48-
acclogp!(vi, logp)
49+
logps[Threads.threadid()] += logp
4950
return value
5051
end
5152

@@ -75,28 +76,29 @@ end
7576
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7677
7778
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
78-
accumulate the log probability, and return the observed value.
79+
accumulate the log probability in `logps` (separately for each thread), and return the
80+
observed value.
7981
8082
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
8183
and indices; if needed, these can be accessed through this function, though.
8284
"""
83-
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
85+
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
8486
logp = tilde(ctx, sampler, right, left, vi)
85-
acclogp!(vi, logp)
87+
logps[Threads.threadid()] += logp
8688
return left
8789
end
8890

8991
"""
90-
tilde_observe(ctx, sampler, right, left, vi)
92+
tilde_observe(ctx, sampler, right, left, vi, logps)
9193
92-
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93-
observed value.
94+
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95+
(separately for each thread), and return the observed value.
9496
9597
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9698
"""
97-
function tilde_observe(ctx, sampler, right, left, vi)
99+
function tilde_observe(ctx, sampler, right, left, vi, logps)
98100
logp = tilde(ctx, sampler, right, left, vi)
99-
acclogp!(vi, logp)
101+
logps[Threads.threadid()] += logp
100102
return left
101103
end
102104

@@ -199,13 +201,14 @@ end
199201
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200202
201203
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
202-
model inputs), accumulate the log probability, and return the sampled value.
204+
model inputs), accumulate the log probability in `logps` (separately for each thread), and
205+
return the sampled value.
203206
204207
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
205208
"""
206-
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
209+
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi, logps)
207210
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
208-
acclogp!(vi, logp)
211+
logps[Threads.threadid()] += logp
209212
return value
210213
end
211214

@@ -381,31 +384,32 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
381384
end
382385

383386
"""
384-
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
387+
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
385388
386389
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
387-
accumulate the log probability, and return the observed value.
390+
accumulate the log probability in `logps` (separately for each thread), and return the
391+
observed value.
388392
389393
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
390394
name and indices; if needed, these can be accessed through this function, though.
391395
"""
392-
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
396+
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi, logps)
393397
logp = dot_tilde(ctx, sampler, right, left, vi)
394-
acclogp!(vi, logp)
398+
logps[Threads.threadid()] += logp
395399
return left
396400
end
397401

398402
"""
399-
dot_tilde_observe(ctx, sampler, right, left, vi)
403+
dot_tilde_observe(ctx, sampler, right, left, vi, logps)
400404
401405
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
402-
probability, and return the observed value.
406+
probability in `logps` (separately for each thread), and return the observed value.
403407
404408
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
405409
"""
406-
function dot_tilde_observe(ctx, sampler, right, left, vi)
410+
function dot_tilde_observe(ctx, sampler, right, left, vi, logps)
407411
logp = dot_tilde(ctx, sampler, right, left, vi)
408-
acclogp!(vi, logp)
412+
logps[Threads.threadid()] += logp
409413
return left
410414
end
411415

src/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ function getargs_tilde(expr::Expr)
3434
return
3535
end
3636

37+
"""
38+
initlogps(varinfo)
39+
40+
Return an `MVector` of length `Threads.nthreads()` filled with `zero(getlogp(varinfo))`.
41+
42+
It is used for accumulating the log probability in the model evaluation in a thread-safe
43+
way.
44+
"""
45+
function initlogps(varinfo)
46+
T = typeof(getlogp(varinfo))
47+
return zeros(StaticArrays.MVector{Threads.nthreads(),T})
48+
end
49+
3750
############################################
3851
# Julia 1.2 temporary fix - Julia PR 33303 #
3952
############################################

test/compiler.jl

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,13 @@ end
210210

211211
# Test use of internal names
212212
@model testmodel(x) = begin
213-
x[1] ~ Bernoulli(0.5)
213+
x[1] ~  Bernoulli(0.5)
214214
global varinfo_ = _varinfo
215215
global sampler_ = _sampler
216216
global model_ = _model
217217
global context_ = _context
218-
global lp = getlogp(_varinfo)
218+
global logps_ = _logps
219+
global lp = sum(_logps)
219220
return x
220221
end
221222
model = testmodel([1.0])
@@ -226,6 +227,15 @@ end
226227
@test model_ === model
227228
@test sampler_ === SampleFromPrior()
228229
@test context_ === DefaultContext()
230+
@test length(logps_) == Threads.nthreads()
231+
@test sum(logps_) == lp
232+
for i in 1:length(logps_)
233+
if i == Threads.threadid()
234+
@test logps_[i] == lp
235+
else
236+
@test iszero(logps_[i])
237+
end
238+
end
229239

230240
# test DPPL#61
231241
@model testmodel(z) = begin
@@ -240,7 +250,7 @@ end
240250
function makemodel(p)
241251
@model testmodel(x) = begin
242252
x[1] ~ Bernoulli(p)
243-
global lp = getlogp(_varinfo)
253+
global lp = sum(_logps)
244254
return x
245255
end
246256
return testmodel
@@ -580,4 +590,39 @@ end
580590
model = demo()
581591
@test all(iszero(model()) for _ in 1:1000)
582592
end
593+
@testset "threading" begin
594+
@info "Peforming threading tests with $(Threads.nthreads()) threads"
595+
596+
x = rand(10_000)
597+
598+
@model function wthreads(x)
599+
x[1] ~ Normal(0, 1)
600+
Threads.@threads for i in 2:length(x)
601+
x[i] ~ Normal(x[i-1], 1)
602+
end
603+
end
604+
605+
vi = VarInfo()
606+
wthreads(x)(vi)
607+
lp_w_threads = getlogp(vi)
608+
609+
println("With threading:")
610+
@time wthreads(x)(vi)
611+
612+
@model function wothreads(x)
613+
x[1] ~ Normal(0, 1)
614+
for i in 2:length(x)
615+
x[i] ~ Normal(x[i-1], 1)
616+
end
617+
end
618+
619+
vi = VarInfo()
620+
wothreads(x)(vi)
621+
lp_wo_threads = getlogp(vi)
622+
623+
println("Without threading:")
624+
@time wothreads(x)(vi)
625+
626+
@test lp_w_threads lp_wo_threads
627+
end
583628
end

test/varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,18 +471,18 @@ include(dir*"/test/test_utils/AllUtils.jl")
471471
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
472472
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()]
473473

474-
@inferred g_demo_f(vi1, hmc)
474+
@test_broken @inferred g_demo_f(vi1, hmc)
475475
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
476476
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])]
477477

478478
g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f)
479479
pg, hmc = g.state.samplers
480480
vi = empty!(TypedVarInfo(vi))
481-
@inferred g_demo_f(vi, SampleFromPrior())
481+
@test_broken @inferred g_demo_f(vi, SampleFromPrior())
482482
pg.state.vi = vi
483483
step!(Random.GLOBAL_RNG, g_demo_f, pg, 1)
484484
vi = pg.state.vi
485-
@inferred g_demo_f(vi, hmc)
485+
@test_broken @inferred g_demo_f(vi, hmc)
486486
@test vi.metadata.x.gids[1] == Set([pg.selector])
487487
@test vi.metadata.y.gids[1] == Set([pg.selector])
488488
@test vi.metadata.z.gids[1] == Set([pg.selector])

0 commit comments

Comments
 (0)