Skip to content

Commit 5e2e466

Browse files
ColCarrollspringcoil
authored andcommitted
Fix parallel sampling (#1481)
* Fix parallel sampling * Everything is good again
1 parent 954b01d commit 5e2e466

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

pymc3/sampling.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from joblib import Parallel, delayed
44
from numpy.random import randint, seed
5-
from numpy import shape, asarray
5+
import numpy as np
66

77
import pymc3 as pm
88
from .backends.base import merge_traces, BaseTrace, MultiTrace
@@ -276,18 +276,26 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
276276

277277

278278
def _make_parallel(arg, njobs):
279-
if not shape(arg):
279+
if not np.shape(arg):
280280
return [arg] * njobs
281281
return arg
282282

283283

284+
def _parallel_random_seed(random_seed, njobs):
285+
if random_seed == -1 and njobs > 1:
286+
max_int = np.iinfo(np.int32).max
287+
return [randint(max_int) for _ in range(njobs)]
288+
else:
289+
return _make_parallel(random_seed, njobs)
290+
291+
284292
def _mp_sample(**kwargs):
285293
njobs = kwargs.pop('njobs')
286294
chain = kwargs.pop('chain')
287295
random_seed = kwargs.pop('random_seed')
288296
start = kwargs.pop('start')
289297

290-
rseed = _make_parallel(random_seed, njobs)
298+
rseed = _parallel_random_seed(random_seed, njobs)
291299
start_vals = _make_parallel(start, njobs)
292300

293301
chains = list(range(chain, chain + njobs))
@@ -365,4 +373,4 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
365373
ppc[var.name].append(var.distribution.random(point=param,
366374
size=size))
367375

368-
return {k: asarray(v) for k, v in ppc.items()}
376+
return {k: np.asarray(v) for k, v in ppc.items()}

pymc3/tests/test_sampling.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from itertools import combinations
12
import numpy as np
23
try:
34
import unittest.mock as mock # py3
@@ -33,6 +34,26 @@ def test_sample_does_not_set_seed(self):
3334
random_numbers.append(np.random.random())
3435
self.assertEqual(random_numbers[0], random_numbers[1])
3536

37+
def test_parallel_sample_does_not_reuse_seed(self):
38+
njobs = 4
39+
random_numbers = []
40+
draws = []
41+
for _ in range(2):
42+
np.random.seed(1) # seeds in other processes don't effect main process
43+
with self.model:
44+
trace = pm.sample(100, njobs=njobs)
45+
# numpy thread mentioned race condition. might as well check none are equal
46+
for first, second in combinations(range(njobs), 2):
47+
first_chain = trace.get_values('x', chains=first)
48+
second_chain = trace.get_values('x', chains=second)
49+
self.assertFalse((first_chain == second_chain).all())
50+
draws.append(trace.get_values('x'))
51+
random_numbers.append(np.random.random())
52+
53+
# Make sure future random processes aren't effected by this
54+
self.assertEqual(*random_numbers)
55+
self.assertTrue((draws[0] == draws[1]).all())
56+
3657
def test_sample(self):
3758
test_njobs = [1]
3859
with self.model:

0 commit comments

Comments
 (0)