Skip to content

Commit 24f88b1

Browse files
authored
Merge pull request #95 from TuringLang/dw/threadid
Remove use of `threadid`
2 parents fe972e8 + bb7ced2 commit 24f88b1

File tree

3 files changed

+29
-33
lines changed

3 files changed

+29
-33
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
version:
16-
- '1.0'
16+
- '1.3'
1717
- '1'
1818
- nightly
1919
os:
@@ -31,7 +31,7 @@ jobs:
3131
arch: x86
3232
- os: macOS-latest
3333
arch: x86
34-
- version: '1.0'
34+
- version: '1.3'
3535
num_threads: 2
3636
include:
3737
- version: '1'
@@ -45,16 +45,7 @@ jobs:
4545
with:
4646
version: ${{ matrix.version }}
4747
arch: ${{ matrix.arch }}
48-
- uses: actions/cache@v1
49-
env:
50-
cache-name: cache-artifacts
51-
with:
52-
path: ~/.julia/artifacts
53-
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
54-
restore-keys: |
55-
${{ runner.os }}-test-${{ env.cache-name }}-
56-
${{ runner.os }}-test-
57-
${{ runner.os }}-
48+
- uses: julia-actions/cache@v1
5849
- uses: julia-actions/julia-buildpkg@latest
5950
- uses: julia-actions/julia-runtest@latest
6051
env:

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "3.2.1"
6+
version = "3.2.2"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -25,7 +25,7 @@ ProgressLogging = "0.1"
2525
StatsBase = "0.32, 0.33"
2626
TerminalLoggers = "0.1"
2727
Transducers = "0.4.30"
28-
julia = "1"
28+
julia = "1.3"
2929

3030
[extras]
3131
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"

src/sample.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -298,16 +298,15 @@ function mcmcsample(
298298
end
299299

300300
# Copy the random number generator, model, and sample for each thread
301-
# NOTE: As of May 17, 2020, this relies on Julia's thread scheduling functionality
302-
# that distributes a for loop into equal-sized blocks and allocates them
303-
# to each thread. If this changes, we may need to rethink things here.
301+
nchunks = min(nchains, Threads.nthreads())
302+
chunksize = cld(nchains, nchunks)
304303
interval = 1:min(nchains, Threads.nthreads())
305304
rngs = [deepcopy(rng) for _ in interval]
306305
models = [deepcopy(model) for _ in interval]
307306
samplers = [deepcopy(sampler) for _ in interval]
308307

309-
# Create a seed for each chain using the provided random number generator.
310-
seeds = rand(rng, UInt, nchains)
308+
# Create a seed for each chunk using the provided random number generator.
309+
seeds = rand(rng, UInt, nchunks)
311310

312311
# Set up a chains vector.
313312
chains = Vector{Any}(undef, nchains)
@@ -340,20 +339,26 @@ function mcmcsample(
340339

341340
Distributed.@async begin
342341
try
343-
Threads.@threads for i in 1:nchains
344-
# Obtain the ID of the current thread.
345-
id = Threads.threadid()
346-
347-
# Seed the thread-specific random number generator with the pre-made seed.
348-
subrng = rngs[id]
349-
Random.seed!(subrng, seeds[i])
350-
351-
# Sample a chain and save it to the vector.
352-
chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
353-
progress = false, kwargs...)
354-
355-
# Update the progress bar.
356-
progress && put!(channel, true)
342+
Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers)
343+
Threads.@spawn begin
344+
# Seed the chunk-specific random number generator with the pre-made seed.
345+
Random.seed!(_rng, seed)
346+
347+
chainidxs = if i == nchunks
348+
((i - 1) * chunksize + 1):nchains
349+
else
350+
((i - 1) * chunksize + 1):(i * chunksize)
351+
end
352+
353+
for chainidx in chainidxs
354+
# Sample a chain and save it to the vector.
355+
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
356+
progress = false, kwargs...)
357+
358+
# Update the progress bar.
359+
progress && put!(channel, true)
360+
end
361+
end
357362
end
358363
finally
359364
# Stop updating the progress bar.

0 commit comments

Comments
 (0)