@@ -39,6 +39,8 @@ def __init__(self, **kwargs):
3939 self .cost_progress = []
4040 self .cost_params = []
4141 self .hmax = kwargs .get ('hmax' , None )
42+ self .parallel = kwargs .get ('parallel' , False )
43+ print ("Received parallel as" , self .parallel )
4244 if self .exp_data is not None :
4345 self .prepare_inference ()
4446 self .setup_cost_function ()
@@ -68,7 +70,8 @@ def __getstate__(self):
6870 self .debug ,
6971 self .cost_progress ,
7072 self .cost_params ,
71- self .hmax
73+ self .hmax ,
74+ self .parallel
7275 )
7376
7477 def __setstate__ (self , state ):
@@ -95,6 +98,7 @@ def __setstate__(self, state):
9598 self .cost_progress = state [19 ]
9699 self .cost_params = state [20 ]
97100 self .hmax = state [21 ]
101+ self .parallel = state [22 ]
98102 if self .exp_data is not None :
99103 self .prepare_inference ()
100104 self .setup_cost_function ()
@@ -246,7 +250,7 @@ def set_exp_data(self, exp_data):
246250 self .exp_data = exp_data
247251 else :
248252 raise ValueError ('exp_data must be either a Pandas dataframe or a list of dataframes.' )
249- return True
253+ return True
250254
251255 def set_norm_order (self , norm_order : int ):
252256 '''
@@ -255,6 +259,13 @@ def set_norm_order(self, norm_order: int):
255259 self .norm_order = norm_order
256260 return True
257261
262+ def set_parallel (self , parallel : bool ):
263+ '''
264+ Set the parallel flag to use parallel processing for MCMC
265+ '''
266+ self .parallel = parallel
267+ return True
268+
258269 def get_parameters (self ):
259270 '''
260271 Returns the list of parameters to estimate that are set for the inference object
@@ -265,7 +276,7 @@ def run_mcmc(self, **kwargs):
265276 self .prepare_inference (** kwargs )
266277 sampler = self .run_emcee (** kwargs )
267278 return sampler
268-
279+
269280 def prepare_inference (self , ** kwargs ):
270281 timepoints = kwargs .get ('timepoints' )
271282 norm_order = kwargs .get ('norm_order' )
@@ -286,8 +297,9 @@ def prepare_inference(self, **kwargs):
286297 self .prepare_initial_conditions ()
287298 self .prepare_parameter_conditions ()
288299 self .LL_data = self .extract_data ()
289-
290- def prepare_initial_conditions (self , ):
300+ return
301+
302+ def prepare_initial_conditions (self ):
291303 # Create initial conditions as required
292304 N = 1 if type (self .exp_data ) is dict else len (self .exp_data )
293305 if type (self .initial_conditions ) is dict :
@@ -328,7 +340,7 @@ def prepare_parameter_conditions(self):
328340 def extract_data (self ):
329341 exp_data = self .exp_data
330342 # Get timepoints from given experimental data
331- if isinstance (self .timepoints , (list , np .ndarray )):
343+ if isinstance (self .timepoints , (list , np .ndarray )) and self . debug :
332344 warnings .warn ('Timepoints given by user, not using the data to extract the timepoints automatically.' )
333345 M = len (self .measurements )# Number of measurements
334346 if type (exp_data ) is list :
@@ -416,8 +428,8 @@ def setup_cost_function(self, **kwargs):
416428
417429 def cost_function (self , params ):
418430 if self .pid_interface is None :
419- raise RuntimeError ("Must call InferenceSetup.setup_cost_function() before InferenceSetup.cost_function(params) can be used." )
420-
431+ raise RuntimeError ("Must call InferenceSetup.setup_cost_function() \
432+ before InferenceSetup.cost_function(params) can be used." )
421433 cost_value = self .pid_interface .get_likelihood_function (params )
422434 self .cost_progress .append (cost_value )
423435 self .cost_params .append (params )
@@ -453,7 +465,6 @@ def seed_parameter_values(self, **kwargs):
453465 elif prior [0 ] == "log-uniform" :
454466 a = np .log (prior [1 ])
455467 b = np .log (prior [2 ])
456-
457468 u = np .random .randn (self .nwalkers )* (b - a )+ a
458469 p0 [:, i ] = np .exp (u )
459470 else :
@@ -492,13 +503,11 @@ def seed_parameter_values(self, **kwargs):
492503 def run_emcee (self , ** kwargs ):
493504 if kwargs .get ("reuse_likelihood" , False ) is False :
494505 self .setup_cost_function (** kwargs )
495-
496506 progress = kwargs .get ('progress' )
497507 convergence_check = kwargs .get ('convergence_check' , False )
498508 convergence_diagnostics = kwargs .get ('convergence_diagnostics' , convergence_check )
499509 skip_initial_state_check = kwargs .get ('skip_initial_state_check' , False )
500510 progress = kwargs .get ('progess' , True )
501- # threads = kwargs.get('threads', 1)
502511 fname_csv = kwargs .get ('filename_csv' , 'mcmc_results.csv' )
503512 if 'results_filename' in kwargs :
504513 warnings .warn ('The keyword results_filename is deprecated and'
@@ -513,17 +522,27 @@ def run_emcee(self, **kwargs):
513522 except :
514523 raise ImportError ('emcee package not installed.' )
515524 ndim = len (self .params_to_estimate )
516-
517525 p0 = self .seed_parameter_values (** kwargs )
518-
519526 assert p0 .shape == (self .nwalkers , ndim )
520-
521- pool = kwargs .get ('pool' , None )
522- if printout : print ("creating an ensemble sampler with multiprocessing pool=" , pool )
523-
524- sampler = emcee .EnsembleSampler (self .nwalkers , ndim , self .cost_function , pool = pool )
527+ if self .parallel :
528+ try :
529+ import multiprocessing
530+ pool = multiprocessing .Pool ()
531+ if printout : print ("Using {} cores for parallelization" .format (multiprocessing .cpu_count ()))
532+ except :
533+ pool = None
534+ raise ImportError ('multiprocessing package not found. \
535+ Make sure to set parallel=False' )
536+ else :
537+ pool = None
538+ if printout : print ("creating an ensemble sampler without multiprocessing " \
539+ "pool. Set parallel=True to use parallel processing." )
540+ sampler = emcee .EnsembleSampler (self .nwalkers , ndim , self .cost_function , pool = pool )
525541 sampler .run_mcmc (p0 , self .nsteps , progress = progress ,
526542 skip_initial_state_check = skip_initial_state_check )
543+ if self .parallel :
544+ pool .close ()
545+ pool .join ()
527546 if convergence_check :
528547 self .autocorrelation_time = sampler .get_autocorr_time ()
529548 if convergence_diagnostics :
@@ -547,7 +566,7 @@ def run_emcee(self, **kwargs):
547566 f .write (str (self .convergence_diagnostics ))
548567 f .close ()
549568 if printout : print ("Results written to" + fname_csv + " and " + fname_txt )
550- if printout : print ('Successfully completed MCMC parameter identification procedure.'
569+ if printout : print ('Successfully completed MCMC parameter identification procedure. '
551570 'Check the MCMC diagnostics to evaluate convergence.' )
552571 return sampler
553572
0 commit comments