Skip to content

Commit 4f1a291

Browse files
authored
add parallel sampling to SMC (#3367)
* add parallel sampling to SMC * fix test * fix test, second attempt * set parallel=False
1 parent 5e1bc75 commit 4f1a291

File tree

3 files changed

+305
-251
lines changed

3 files changed

+305
-251
lines changed

pymc3/sampling.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
247247
number of draws.
248248
cores : int
249249
The number of chains to run in parallel. If `None`, set to the number of CPUs in the
250-
system, but at most 4 (for 'SMC' defaults to 1). Keep in mind that some chains might
251-
themselves be multithreaded via openmp or BLAS. In those cases it might be faster to set
252-
this to 1.
250+
system, but at most 4 (for 'SMC' ignored if `pm.SMC(parallel=False)`. Keep in mind that
251+
some chains might themselves be multithreaded via openmp or BLAS. In those cases it might
252+
be faster to set this to 1.
253253
tune : int
254254
Number of iterations to tune, defaults to 500. Ignored when using 'SMC'. Samplers adjust
255255
the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition
@@ -319,25 +319,27 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
319319
nuts_kwargs = kwargs.pop('nuts_kwargs', None)
320320
if nuts_kwargs is not None:
321321
warnings.warn("The nuts_kwargs argument has been deprecated. Pass step "
322-
"method arguments directly to sample instead",
323-
DeprecationWarning)
322+
"method arguments directly to sample instead",
323+
DeprecationWarning)
324324
kwargs.update(nuts_kwargs)
325325
step_kwargs = kwargs.pop('step_kwargs', None)
326326
if step_kwargs is not None:
327327
warnings.warn("The step_kwargs argument has been deprecated. Pass step "
328-
"method arguments directly to sample instead",
329-
DeprecationWarning)
328+
"method arguments directly to sample instead",
329+
DeprecationWarning)
330330
kwargs.update(step_kwargs)
331331

332+
if cores is None:
333+
cores = min(4, _cpu_count())
334+
332335
if isinstance(step, pm.step_methods.smc.SMC):
333336
trace = smc.sample_smc(draws=draws,
334337
step=step,
338+
cores=cores,
335339
progressbar=progressbar,
336340
model=model,
337341
random_seed=random_seed)
338342
else:
339-
if cores is None:
340-
cores = min(4, _cpu_count())
341343
if 'njobs' in kwargs:
342344
cores = kwargs['njobs']
343345
warnings.warn(
@@ -423,12 +425,12 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
423425
'random_seed': random_seed,
424426
'live_plot': live_plot,
425427
'live_plot_kwargs': live_plot_kwargs,
426-
'cores': cores,}
428+
'cores': cores, }
427429

428430
sample_args.update(kwargs)
429431

430432
has_population_samplers = np.any([isinstance(m, arraystep.PopulationArrayStepShared)
431-
for m in (step.methods if isinstance(step, CompoundStep) else [step])])
433+
for m in (step.methods if isinstance(step, CompoundStep) else [step])])
432434

433435
parallel = cores > 1 and chains > 1 and not has_population_samplers
434436
if parallel:
@@ -1127,7 +1129,7 @@ def sample_ppc(*args, **kwargs):
11271129

11281130

11291131
def sample_posterior_predictive_w(traces, samples=None, models=None, weights=None,
1130-
random_seed=None, progressbar=True):
1132+
random_seed=None, progressbar=True):
11311133
"""Generate weighted posterior predictive samples from a list of models and
11321134
a list of traces according to a set of weights.
11331135
@@ -1306,7 +1308,8 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
13061308
elif is_transformed_name(var_name):
13071309
untransformed = get_untransformed_name(var_name)
13081310
if untransformed in data:
1309-
prior[var_name] = model[untransformed].transformation.forward_val(data[untransformed])
1311+
prior[var_name] = model[untransformed].transformation.forward_val(
1312+
data[untransformed])
13101313
return prior
13111314

13121315

pymc3/step_methods/smc.py

Lines changed: 87 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import theano
66
import pymc3 as pm
77
from tqdm import tqdm
8+
import multiprocessing as mp
89

910
from .arraystep import metrop_select
1011
from .metropolis import MultivariateNormalProposal
@@ -43,6 +44,9 @@ class SMC:
4344
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
4445
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
4546
It should be between 0 and 1.
47+
parallel : bool
48+
Distribute computations across cores if the number of cores is larger than 1
49+
(see pm.sample() for details). Defaults to True.
4650
model : :class:`pymc3.Model`
4751
Optional model for sampling step. Defaults to None (taken from context).
4852
@@ -68,6 +72,7 @@ def __init__(
6872
tune_scaling=True,
6973
tune_steps=True,
7074
threshold=0.5,
75+
parallel=True,
7176
):
7277

7378
self.n_steps = n_steps
@@ -77,9 +82,10 @@ def __init__(
7782
self.tune_scaling = tune_scaling
7883
self.tune_steps = tune_steps
7984
self.threshold = threshold
85+
self.parallel = parallel
8086

8187

82-
def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed=-1):
88+
def sample_smc(draws=5000, step=None, cores=None, progressbar=False, model=None, random_seed=-1):
8389
"""
8490
Sequential Monte Carlo sampling
8591
@@ -90,6 +96,8 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
9096
independent Markov Chains. Defaults to 5000.
9197
step : :class:`SMC`
9298
SMC initialization object
99+
cores : int
100+
The number of chains to run in parallel.
93101
progressbar : bool
94102
Flag for displaying a progress bar
95103
model : pymc3 Model
@@ -102,9 +110,10 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
102110
if random_seed != -1:
103111
np.random.seed(random_seed)
104112

105-
beta = 0.
113+
beta = 0.0
106114
stage = 0
107-
acc_rate = 1.
115+
accepted = 0
116+
acc_rate = 1.0
108117
proposed = draws * step.n_steps
109118
model.marginal_likelihood = 1
110119
variables = inputvars(model.vars)
@@ -138,53 +147,95 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
138147
if step.tune_scaling:
139148
step.scaling = _tune(acc_rate)
140149
if step.tune_steps:
141-
acc_rate = max(1. / proposed, acc_rate)
150+
acc_rate = max(1.0 / proposed, acc_rate)
142151
step.n_steps = min(
143152
step.max_steps, 1 + int(np.log(step.p_acc_rate) / np.log(1 - acc_rate))
144153
)
145154

146-
pm._log.info(
147-
"Stage: {:d} Beta: {:f} Steps: {:d}".format(stage, beta, step.n_steps, acc_rate)
148-
)
155+
pm._log.info("Stage: {:d} Beta: {:.3f} Steps: {:d}".format(stage, beta, step.n_steps))
149156
# Apply Metropolis kernel (mutation)
150157
proposed = draws * step.n_steps
151-
accepted = 0.
152158
priors = np.array([prior_logp(sample) for sample in posterior]).squeeze()
153-
tempered_post = priors + likelihoods * beta
154-
for draw in tqdm(range(draws), disable=not progressbar):
155-
old_tempered_post = tempered_post[draw]
156-
q_old = posterior[draw]
157-
deltas = np.squeeze(proposal(step.n_steps) * step.scaling)
158-
for n_step in range(step.n_steps):
159-
delta = deltas[n_step]
160-
161-
if any_discrete:
162-
if all_discrete:
163-
delta = np.round(delta, 0).astype("int64")
164-
q_old = q_old.astype("int64")
165-
q_new = (q_old + delta).astype("int64")
166-
else:
167-
delta[discrete] = np.round(delta[discrete], 0)
168-
q_new = floatX(q_old + delta)
169-
else:
170-
q_new = floatX(q_old + delta)
171-
172-
new_tempered_post = prior_logp(q_new) + likelihood_logp(q_new)[0] * beta
173-
174-
q_old, accept = metrop_select(new_tempered_post - old_tempered_post, q_new, q_old)
175-
if accept:
176-
accepted += 1
177-
posterior[draw] = q_old
178-
old_tempered_post = new_tempered_post
179-
180-
acc_rate = accepted / proposed
159+
tempered_logp = priors + likelihoods * beta
160+
deltas = np.squeeze(proposal(step.n_steps) * step.scaling)
161+
162+
parameters = (
163+
proposal,
164+
step.scaling,
165+
accepted,
166+
any_discrete,
167+
all_discrete,
168+
discrete,
169+
step.n_steps,
170+
prior_logp,
171+
likelihood_logp,
172+
beta,
173+
)
174+
175+
if step.parallel and cores > 1:
176+
pool = mp.Pool(processes=cores)
177+
results = pool.starmap(
178+
_metrop_kernel,
179+
[(posterior[draw], tempered_logp[draw], *parameters) for draw in range(draws)],
180+
)
181+
else:
182+
results = [
183+
_metrop_kernel(posterior[draw], tempered_logp[draw], *parameters)
184+
for draw in tqdm(range(draws), disable=not progressbar)
185+
]
186+
187+
posterior, acc_list = zip(*results)
188+
posterior = np.array(posterior)
189+
acc_rate = sum(acc_list) / proposed
181190
stage += 1
182191

183192
trace = _posterior_to_trace(posterior, variables, model, var_info)
184193

185194
return trace
186195

187196

197+
def _metrop_kernel(
198+
q_old,
199+
old_tempered_logp,
200+
proposal,
201+
scaling,
202+
accepted,
203+
any_discrete,
204+
all_discrete,
205+
discrete,
206+
n_steps,
207+
prior_logp,
208+
likelihood_logp,
209+
beta,
210+
):
211+
"""
212+
Metropolis kernel
213+
"""
214+
deltas = np.squeeze(proposal(n_steps) * scaling)
215+
for n_step in range(n_steps):
216+
delta = deltas[n_step]
217+
218+
if any_discrete:
219+
if all_discrete:
220+
delta = np.round(delta, 0).astype("int64")
221+
q_old = q_old.astype("int64")
222+
q_new = (q_old + delta).astype("int64")
223+
else:
224+
delta[discrete] = np.round(delta[discrete], 0)
225+
q_new = floatX(q_old + delta)
226+
else:
227+
q_new = floatX(q_old + delta)
228+
229+
new_tempered_logp = prior_logp(q_new) + likelihood_logp(q_new)[0] * beta
230+
231+
q_old, accept = metrop_select(new_tempered_logp - old_tempered_logp, q_new, q_old)
232+
if accept:
233+
accepted += 1
234+
old_tempered_logp = new_tempered_logp
235+
236+
return q_old, accepted
237+
238+
188239
def _initial_population(draws, model, variables):
189240
"""
190241
Create an initial population from the prior

0 commit comments

Comments
 (0)