Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions bioscrape/inference.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ cdef class StochasticStatesLikelihood(ModelLikelihood):
def py_inference(Model = None, params_to_estimate = None, exp_data = None, initial_conditions = None,
parameter_conditions = None, measurements = None, time_column = None, nwalkers = None,
nsteps = None, init_seed = None, prior = None, sim_type = None, inference_type = 'emcee',
method = 'mcmc', plot_show = True, **kwargs):
method = 'mcmc', plot_show = True, parallel = None, **kwargs):
"""
User level interface for running bioscrape inference module.
Args:
Expand Down Expand Up @@ -619,6 +619,10 @@ def py_inference(Model = None, params_to_estimate = None, exp_data = None, initi
https://lmfit.github.io/lmfit-py/fitting.html#choosing-different-fitting-methods
plot_show (bool): If set to `True`, bioscrape will try to display the generated plots from the inference run.
If set to `False`, no plots will be shown.
parallel (bool): If set to `True`, bioscrape will create a multiprocessing.Pool object,
                 which will be passed to emcee.EnsembleSampler to enable parallel
                 processing. If set to `False`, multiprocessing will not be used.
kwargs: Additional keyword arguments that are passed into the inference setup.
Returns:
for inference_type = "emcee":
sampler, pid: A tuple consisting of the emcee.EnsembleSampler and the bioscrape pid object
Expand Down Expand Up @@ -649,12 +653,14 @@ def py_inference(Model = None, params_to_estimate = None, exp_data = None, initi
pid.set_nsteps(nsteps)
if sim_type is not None:
pid.set_sim_type(sim_type)
if parallel is not None:
pid.set_parallel(parallel)
if params_to_estimate is not None:
pid.set_params_to_estimate(params_to_estimate)
if prior is not None:
pid.set_prior(prior)
if inference_type == 'emcee' and method == 'mcmc':
sampler = pid.run_mcmc(plot_show = plot_show, **kwargs)
sampler = pid.run_mcmc(**kwargs)
if plot_show:
pid.plot_mcmc_results(sampler, **kwargs)
return sampler, pid
Expand Down
57 changes: 38 additions & 19 deletions bioscrape/inference_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, **kwargs):
self.cost_progress = []
self.cost_params = []
self.hmax = kwargs.get('hmax', None)
self.parallel = kwargs.get('parallel', False)
print("Received parallel as", self.parallel)
if self.exp_data is not None:
self.prepare_inference()
self.setup_cost_function()
Expand Down Expand Up @@ -68,7 +70,8 @@ def __getstate__(self):
self.debug,
self.cost_progress,
self.cost_params,
self.hmax
self.hmax,
self.parallel
)

def __setstate__(self, state):
Expand All @@ -95,6 +98,7 @@ def __setstate__(self, state):
self.cost_progress = state[19]
self.cost_params = state[20]
self.hmax = state[21]
self.parallel = state[22]
if self.exp_data is not None:
self.prepare_inference()
self.setup_cost_function()
Expand Down Expand Up @@ -246,7 +250,7 @@ def set_exp_data(self, exp_data):
self.exp_data = exp_data
else:
raise ValueError('exp_data must be either a Pandas dataframe or a list of dataframes.')
return True
return True

def set_norm_order(self, norm_order: int):
'''
Expand All @@ -255,6 +259,13 @@ def set_norm_order(self, norm_order: int):
self.norm_order = norm_order
return True

def set_parallel(self, parallel: bool):
    '''
    Turn multiprocessing-based parallel MCMC sampling on or off.
    When `parallel` is True, run_emcee will create a multiprocessing.Pool
    and hand it to emcee.EnsembleSampler; when False, sampling is serial.
    Returns True on success (setter convention used throughout this class).
    '''
    self.parallel = parallel
    return True

def get_parameters(self):
'''
Returns the list of parameters to estimate that are set for the inference object
Expand All @@ -265,7 +276,7 @@ def run_mcmc(self, **kwargs):
self.prepare_inference(**kwargs)
sampler = self.run_emcee(**kwargs)
return sampler

def prepare_inference(self, **kwargs):
timepoints = kwargs.get('timepoints')
norm_order = kwargs.get('norm_order')
Expand All @@ -286,8 +297,9 @@ def prepare_inference(self, **kwargs):
self.prepare_initial_conditions()
self.prepare_parameter_conditions()
self.LL_data = self.extract_data()

def prepare_initial_conditions(self, ):
return

def prepare_initial_conditions(self):
# Create initial conditions as required
N = 1 if type(self.exp_data) is dict else len(self.exp_data)
if type(self.initial_conditions) is dict:
Expand Down Expand Up @@ -328,7 +340,7 @@ def prepare_parameter_conditions(self):
def extract_data(self):
exp_data = self.exp_data
# Get timepoints from given experimental data
if isinstance(self.timepoints, (list, np.ndarray)):
if isinstance(self.timepoints, (list, np.ndarray)) and self.debug:
warnings.warn('Timepoints given by user, not using the data to extract the timepoints automatically.')
M = len(self.measurements)# Number of measurements
if type(exp_data) is list:
Expand Down Expand Up @@ -416,8 +428,8 @@ def setup_cost_function(self, **kwargs):

def cost_function(self, params):
if self.pid_interface is None:
raise RuntimeError("Must call InferenceSetup.setup_cost_function() before InferenceSetup.cost_function(params) can be used.")

raise RuntimeError("Must call InferenceSetup.setup_cost_function() \
before InferenceSetup.cost_function(params) can be used.")
cost_value = self.pid_interface.get_likelihood_function(params)
self.cost_progress.append(cost_value)
self.cost_params.append(params)
Expand Down Expand Up @@ -453,7 +465,6 @@ def seed_parameter_values(self, **kwargs):
elif prior[0] == "log-uniform":
a = np.log(prior[1])
b = np.log(prior[2])

u = np.random.randn(self.nwalkers)*(b - a)+a
p0[:, i] = np.exp(u)
else:
Expand Down Expand Up @@ -492,13 +503,11 @@ def seed_parameter_values(self, **kwargs):
def run_emcee(self, **kwargs):
if kwargs.get("reuse_likelihood", False) is False:
self.setup_cost_function(**kwargs)

progress = kwargs.get('progress')
convergence_check = kwargs.get('convergence_check', False)
convergence_diagnostics = kwargs.get('convergence_diagnostics', convergence_check)
skip_initial_state_check = kwargs.get('skip_initial_state_check', False)
progress = kwargs.get('progess', True)
# threads = kwargs.get('threads', 1)
fname_csv = kwargs.get('filename_csv', 'mcmc_results.csv')
if 'results_filename' in kwargs:
warnings.warn('The keyword results_filename is deprecated and'
Expand All @@ -513,17 +522,27 @@ def run_emcee(self, **kwargs):
except:
raise ImportError('emcee package not installed.')
ndim = len(self.params_to_estimate)

p0 = self.seed_parameter_values(**kwargs)

assert p0.shape == (self.nwalkers, ndim)

pool = kwargs.get('pool', None)
if printout: print("creating an ensemble sampler with multiprocessing pool=", pool)

sampler = emcee.EnsembleSampler(self.nwalkers, ndim, self.cost_function, pool = pool)
if self.parallel:
try:
import multiprocessing
pool = multiprocessing.Pool()
if printout: print("Using {} cores for parallelization".format(multiprocessing.cpu_count()))
except:
pool = None
raise ImportError('multiprocessing package not found. \
Make sure to set parallel=False')
else:
pool = None
if printout: print("creating an ensemble sampler without multiprocessing "\
"pool. Set parallel=True to use parallel processing.")
sampler = emcee.EnsembleSampler(self.nwalkers, ndim, self.cost_function, pool=pool)
sampler.run_mcmc(p0, self.nsteps, progress=progress,
skip_initial_state_check=skip_initial_state_check)
if self.parallel:
pool.close()
pool.join()
if convergence_check:
self.autocorrelation_time = sampler.get_autocorr_time()
if convergence_diagnostics:
Expand All @@ -547,7 +566,7 @@ def run_emcee(self, **kwargs):
f.write(str(self.convergence_diagnostics))
f.close()
if printout: print("Results written to" + fname_csv + " and " + fname_txt)
if printout: print('Successfully completed MCMC parameter identification procedure.'
if printout: print('Successfully completed MCMC parameter identification procedure. '
'Check the MCMC diagnostics to evaluate convergence.')
return sampler

Expand Down
Loading
Loading