|
2 | 2 |
|
3 | 3 | from joblib import Parallel, delayed
|
4 | 4 | from numpy.random import randint, seed
|
5 |
| -from numpy import shape, asarray |
| 5 | +import numpy as np |
6 | 6 |
|
7 | 7 | import pymc3 as pm
|
8 | 8 | from .backends.base import merge_traces, BaseTrace, MultiTrace
|
@@ -276,18 +276,26 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
|
276 | 276 |
|
277 | 277 |
|
278 | 278 | def _make_parallel(arg, njobs):
|
279 |
| - if not shape(arg): |
| 279 | + if not np.shape(arg): |
280 | 280 | return [arg] * njobs
|
281 | 281 | return arg
|
282 | 282 |
|
283 | 283 |
|
| 284 | +def _parallel_random_seed(random_seed, njobs): |
| 285 | + if random_seed == -1 and njobs > 1: |
| 286 | + max_int = np.iinfo(np.int32).max |
| 287 | + return [randint(max_int) for _ in range(njobs)] |
| 288 | + else: |
| 289 | + return _make_parallel(random_seed, njobs) |
| 290 | + |
| 291 | + |
284 | 292 | def _mp_sample(**kwargs):
|
285 | 293 | njobs = kwargs.pop('njobs')
|
286 | 294 | chain = kwargs.pop('chain')
|
287 | 295 | random_seed = kwargs.pop('random_seed')
|
288 | 296 | start = kwargs.pop('start')
|
289 | 297 |
|
290 |
| - rseed = _make_parallel(random_seed, njobs) |
| 298 | + rseed = _parallel_random_seed(random_seed, njobs) |
291 | 299 | start_vals = _make_parallel(start, njobs)
|
292 | 300 |
|
293 | 301 | chains = list(range(chain, chain + njobs))
|
@@ -365,4 +373,4 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
|
365 | 373 | ppc[var.name].append(var.distribution.random(point=param,
|
366 | 374 | size=size))
|
367 | 375 |
|
368 |
| - return {k: asarray(v) for k, v in ppc.items()} |
| 376 | + return {k: np.asarray(v) for k, v in ppc.items()} |
0 commit comments