Skip to content

Commit d9a7fe6

Browse files
authored
Merge pull request #181 from biocircuits/parallel
Parallel support for bioscrape inference
2 parents 8c14123 + b319583 commit d9a7fe6

14 files changed

+619
-1230
lines changed

bioscrape/inference.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ cdef class StochasticStatesLikelihood(ModelLikelihood):
578578
def py_inference(Model = None, params_to_estimate = None, exp_data = None, initial_conditions = None,
579579
parameter_conditions = None, measurements = None, time_column = None, nwalkers = None,
580580
nsteps = None, init_seed = None, prior = None, sim_type = None, inference_type = 'emcee',
581-
method = 'mcmc', plot_show = True, **kwargs):
581+
method = 'mcmc', plot_show = True, parallel = None, **kwargs):
582582
"""
583583
User level interface for running bioscrape inference module.
584584
Args:
@@ -619,6 +619,10 @@ def py_inference(Model = None, params_to_estimate = None, exp_data = None, initi
619619
https://lmfit.github.io/lmfit-py/fitting.html#choosing-different-fitting-methods
620620
plot_show (bool): If set to `True`, bioscrape will try to display the generated plots from the inference run.
621621
If set to `False`, no plots will be shown.
622+
parallel (bool): If set to `True`, bioscrape will create a multiprocessing.Pool object
623+
and pass it to emcee.EnsembleSampler for parallel
624+
processing. If set to `False`, multiprocessing will not be used.
625+
kwargs: Additional keyword arguments that are passed into the inference setup.
622626
Returns:
623627
for inference_type = "emcee":
624628
sampler, pid: A tuple consisting of the emcee.EnsembleSampler and the bioscrape pid object
@@ -649,12 +653,14 @@ def py_inference(Model = None, params_to_estimate = None, exp_data = None, initi
649653
pid.set_nsteps(nsteps)
650654
if sim_type is not None:
651655
pid.set_sim_type(sim_type)
656+
if parallel is not None:
657+
pid.set_parallel(parallel)
652658
if params_to_estimate is not None:
653659
pid.set_params_to_estimate(params_to_estimate)
654660
if prior is not None:
655661
pid.set_prior(prior)
656662
if inference_type == 'emcee' and method == 'mcmc':
657-
sampler = pid.run_mcmc(plot_show = plot_show, **kwargs)
663+
sampler = pid.run_mcmc(**kwargs)
658664
if plot_show:
659665
pid.plot_mcmc_results(sampler, **kwargs)
660666
return sampler, pid

bioscrape/inference_setup.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def __init__(self, **kwargs):
3939
self.cost_progress = []
4040
self.cost_params = []
4141
self.hmax = kwargs.get('hmax', None)
42+
self.parallel = kwargs.get('parallel', False)
43+
print("Received parallel as", self.parallel)
4244
if self.exp_data is not None:
4345
self.prepare_inference()
4446
self.setup_cost_function()
@@ -68,7 +70,8 @@ def __getstate__(self):
6870
self.debug,
6971
self.cost_progress,
7072
self.cost_params,
71-
self.hmax
73+
self.hmax,
74+
self.parallel
7275
)
7376

7477
def __setstate__(self, state):
@@ -95,6 +98,7 @@ def __setstate__(self, state):
9598
self.cost_progress = state[19]
9699
self.cost_params = state[20]
97100
self.hmax = state[21]
101+
self.parallel = state[22]
98102
if self.exp_data is not None:
99103
self.prepare_inference()
100104
self.setup_cost_function()
@@ -246,7 +250,7 @@ def set_exp_data(self, exp_data):
246250
self.exp_data = exp_data
247251
else:
248252
raise ValueError('exp_data must be either a Pandas dataframe or a list of dataframes.')
249-
return True
253+
return True
250254

251255
def set_norm_order(self, norm_order: int):
252256
'''
@@ -255,6 +259,13 @@ def set_norm_order(self, norm_order: int):
255259
self.norm_order = norm_order
256260
return True
257261

262+
def set_parallel(self, parallel: bool):
263+
'''
264+
Set the parallel flag to use parallel processing for MCMC
265+
'''
266+
self.parallel = parallel
267+
return True
268+
258269
def get_parameters(self):
259270
'''
260271
Returns the list of parameters to estimate that are set for the inference object
@@ -265,7 +276,7 @@ def run_mcmc(self, **kwargs):
265276
self.prepare_inference(**kwargs)
266277
sampler = self.run_emcee(**kwargs)
267278
return sampler
268-
279+
269280
def prepare_inference(self, **kwargs):
270281
timepoints = kwargs.get('timepoints')
271282
norm_order = kwargs.get('norm_order')
@@ -286,8 +297,9 @@ def prepare_inference(self, **kwargs):
286297
self.prepare_initial_conditions()
287298
self.prepare_parameter_conditions()
288299
self.LL_data = self.extract_data()
289-
290-
def prepare_initial_conditions(self, ):
300+
return
301+
302+
def prepare_initial_conditions(self):
291303
# Create initial conditions as required
292304
N = 1 if type(self.exp_data) is dict else len(self.exp_data)
293305
if type(self.initial_conditions) is dict:
@@ -328,7 +340,7 @@ def prepare_parameter_conditions(self):
328340
def extract_data(self):
329341
exp_data = self.exp_data
330342
# Get timepoints from given experimental data
331-
if isinstance(self.timepoints, (list, np.ndarray)):
343+
if isinstance(self.timepoints, (list, np.ndarray)) and self.debug:
332344
warnings.warn('Timepoints given by user, not using the data to extract the timepoints automatically.')
333345
M = len(self.measurements)# Number of measurements
334346
if type(exp_data) is list:
@@ -416,8 +428,8 @@ def setup_cost_function(self, **kwargs):
416428

417429
def cost_function(self, params):
418430
if self.pid_interface is None:
419-
raise RuntimeError("Must call InferenceSetup.setup_cost_function() before InferenceSetup.cost_function(params) can be used.")
420-
431+
raise RuntimeError("Must call InferenceSetup.setup_cost_function() \
432+
before InferenceSetup.cost_function(params) can be used.")
421433
cost_value = self.pid_interface.get_likelihood_function(params)
422434
self.cost_progress.append(cost_value)
423435
self.cost_params.append(params)
@@ -453,7 +465,6 @@ def seed_parameter_values(self, **kwargs):
453465
elif prior[0] == "log-uniform":
454466
a = np.log(prior[1])
455467
b = np.log(prior[2])
456-
457468
u = np.random.randn(self.nwalkers)*(b - a)+a
458469
p0[:, i] = np.exp(u)
459470
else:
@@ -492,13 +503,11 @@ def seed_parameter_values(self, **kwargs):
492503
def run_emcee(self, **kwargs):
493504
if kwargs.get("reuse_likelihood", False) is False:
494505
self.setup_cost_function(**kwargs)
495-
496506
progress = kwargs.get('progress')
497507
convergence_check = kwargs.get('convergence_check', False)
498508
convergence_diagnostics = kwargs.get('convergence_diagnostics', convergence_check)
499509
skip_initial_state_check = kwargs.get('skip_initial_state_check', False)
500510
progress = kwargs.get('progess', True)
501-
# threads = kwargs.get('threads', 1)
502511
fname_csv = kwargs.get('filename_csv', 'mcmc_results.csv')
503512
if 'results_filename' in kwargs:
504513
warnings.warn('The keyword results_filename is deprecated and'
@@ -513,17 +522,27 @@ def run_emcee(self, **kwargs):
513522
except:
514523
raise ImportError('emcee package not installed.')
515524
ndim = len(self.params_to_estimate)
516-
517525
p0 = self.seed_parameter_values(**kwargs)
518-
519526
assert p0.shape == (self.nwalkers, ndim)
520-
521-
pool = kwargs.get('pool', None)
522-
if printout: print("creating an ensemble sampler with multiprocessing pool=", pool)
523-
524-
sampler = emcee.EnsembleSampler(self.nwalkers, ndim, self.cost_function, pool = pool)
527+
if self.parallel:
528+
try:
529+
import multiprocessing
530+
pool = multiprocessing.Pool()
531+
if printout: print("Using {} cores for parallelization".format(multiprocessing.cpu_count()))
532+
except:
533+
pool = None
534+
raise ImportError('multiprocessing package not found. \
535+
Make sure to set parallel=False')
536+
else:
537+
pool = None
538+
if printout: print("creating an ensemble sampler without multiprocessing "\
539+
"pool. Set parallel=True to use parallel processing.")
540+
sampler = emcee.EnsembleSampler(self.nwalkers, ndim, self.cost_function, pool=pool)
525541
sampler.run_mcmc(p0, self.nsteps, progress=progress,
526542
skip_initial_state_check=skip_initial_state_check)
543+
if self.parallel:
544+
pool.close()
545+
pool.join()
527546
if convergence_check:
528547
self.autocorrelation_time = sampler.get_autocorr_time()
529548
if convergence_diagnostics:
@@ -547,7 +566,7 @@ def run_emcee(self, **kwargs):
547566
f.write(str(self.convergence_diagnostics))
548567
f.close()
549568
if printout: print("Results written to" + fname_csv + " and " + fname_txt)
550-
if printout: print('Successfully completed MCMC parameter identification procedure.'
569+
if printout: print('Successfully completed MCMC parameter identification procedure. '
551570
'Check the MCMC diagnostics to evaluate convergence.')
552571
return sampler
553572

0 commit comments

Comments
 (0)