Skip to content

Commit 3ea1c96

Browse files
committed
[core] refactor & code cleanups
1 parent b792cf3 commit 3ea1c96

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

smcpp/analysis.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class BaseAnalysis:
2727
"Base class for analysis of population genetic data."
2828
def __init__(self, files, args):
2929
# Misc. parameter initialiations
30+
self._args = args
3031
self._N0 = .5e-4 / args.mu # .0001 = args.mu * 2 * N0
3132
self._penalty = 0.
3233
self._niter = args.em_iterations
@@ -45,6 +46,21 @@ def __init__(self, files, args):
4546
self._validate_data()
4647
self._recode_nonseg(args.nonseg_cutoff)
4748

49+
def _init_penalty(self):
50+
# Continue initializing
51+
# These will be updated after first pass
52+
self._hidden_states = {k: np.array([0., np.inf]) for k in self._populations}
53+
self._init_inference_manager(self._args.polarization_error)
54+
self._init_optimizer(None, self._args.algorithm, self._args.xtol, self._args.ftol)
55+
self.E_step()
56+
self._penalty = abs(self.Q()) * (10 ** -self._args.regularization_penalty)
57+
logger.debug("Auto-assigning regularization penalty lambda=%g", self._penalty)
58+
59+
def _init_optimizer(self, outdir, algorithm, xtol, ftol):
60+
self._optimizer = self._OPTIMIZER_CLS(self, algorithm, xtol, ftol)
61+
if outdir:
62+
self._optimizer.register(analysis_saver.AnalysisSaver(outdir))
63+
4864
def rescale(self, x):
4965
return x / (2. * self._N0)
5066

@@ -282,24 +298,16 @@ def __init__(self, files, args):
282298
hs = np.r_[[0.], self._knots, [np.inf]]
283299
self._init_model(self._N0, args.spline)
284300

285-
# Continue initializing
286-
# These will be updated after first pass
287-
self._hidden_states = {k: np.array([0., np.inf]) for k in self._populations}
288-
self._init_inference_manager(args.polarization_error)
289-
self._init_optimizer(args, None, args.blocks,
290-
args.algorithm, args.xtol, args.ftol,
291-
learn_rho=False)
292-
self.E_step()
293-
self._penalty = abs(self.Q()) * (10 ** -args.regularization_penalty)
294-
logger.debug("Auto-assigning regularization penalty lambda=%g", self._penalty)
301+
# Auto-assign regularization penalty using a heuristic
302+
self._init_penalty()
295303

296304
logger.debug("hidden states: %s", hs)
297305
self._hidden_states = {k: hs for k in self._populations}
298306
x = self.model[:]
299307
self._init_model(self._N0, args.spline)
300308
self.model[:] = x
301309
self._init_inference_manager(args.polarization_error)
302-
self._init_optimizer(args, args.outdir, args.blocks,
310+
self._init_optimizer(args.outdir,
303311
args.algorithm, args.xtol, args.ftol,
304312
learn_rho=args.r is None)
305313

@@ -347,12 +355,10 @@ def _init_model(self, N0, spline_class):
347355
mods[-1][:] = y0
348356
self._model = SMCTwoPopulationModel(mods[0], mods[1], split)
349357

350-
def _init_optimizer(self, args, outdir, blocks,
351-
algorithm, xtol, ftol, learn_rho):
352-
self._optimizer = SMCPPOptimizer(
353-
self, algorithm, xtol, ftol, blocks, args.solver_args)
354-
if outdir:
355-
self._optimizer.register(analysis_saver.AnalysisSaver(outdir))
358+
_OPTIMIZER_CLS = SMCPPOptimizer
359+
360+
def _init_optimizer(self, outdir, algorithm, xtol, ftol, learn_rho=False):
361+
super()._init_optimizer(outdir, algorithm, xtol, ftol)
356362
if learn_rho:
357363
rho_bounds = 2. * self._N0 * np.array([1e-10, 1e-5])
358364
self._optimizer.register(
@@ -370,10 +376,9 @@ def __init__(self, files, args):
370376
self._perform_thinning(args.thinning)
371377
self._normalize_data(args.length_cutoff, not args.no_filter)
372378
# Further initialization
373-
# keep separate hidden states for each distinguished type
379+
self._init_penalty()
374380
self._init_inference_manager(args.polarization_error)
375-
self._init_optimizer(args, args.outdir, args.blocks,
376-
args.algorithm, args.xtol, args.ftol, True)
381+
self._init_optimizer(args.outdir, args.algorithm, args.xtol, args.ftol)
377382

378383
def _validate_data(self):
379384
BaseAnalysis._validate_data(self)
@@ -382,16 +387,13 @@ def _validate_data(self):
382387
"information. Split estimation is impossible.")
383388
sys.exit(1)
384389

385-
def _init_optimizer(self, args, outdir, blocks, algorithm,
386-
xtol, ftol, save=True):
387-
self._optimizer = TwoPopulationOptimizer(
388-
self, algorithm, xtol, ftol, blocks, args.solver_args)
389-
self._optimizer.register(
390-
parameter_optimizer.ParameterOptimizer("split",
390+
_OPTIMIZER_CLS = TwoPopulationOptimizer
391+
392+
def _init_optimizer(self, outdir, algorithm, xtol, ftol):
393+
super()._init_optimizer(outdir, algorithm, xtol, ftol)
394+
self._optimizer.register(parameter_optimizer.ParameterOptimizer("split",
391395
(0., self._max_split),
392396
"model"))
393-
if save:
394-
self._optimizer.register(analysis_saver.AnalysisSaver(outdir))
395397

396398
def _init_model(self, pop1, pop2):
397399
d = json.load(open(pop1, "rt"))

smcpp/optimize/optimizers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ class AbstractOptimizer(Observable):
1919
'''
2020
Abstract representation of the execution flow of the optimizer.
2121
'''
22-
def __init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args={}):
22+
def __init__(self, analysis, algorithm, xtol, ftol):
2323
Observable.__init__(self)
2424
self._analysis = analysis
2525
self._algorithm = algorithm
2626
self._ftol = ftol
2727
self._xtol = xtol
28-
self._blocks = blocks
29-
self._solver_args = solver_args
3028

3129
@abstractmethod
3230
def _coordinates(self, i):
@@ -187,8 +185,8 @@ def update_observers(self, *args, **kwargs):
187185
class SMCPPOptimizer(AbstractOptimizer):
188186
'Model fitting for one population.'
189187

190-
def __init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args):
191-
AbstractOptimizer.__init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args)
188+
def __init__(self, analysis, algorithm, xtol, ftol):
189+
AbstractOptimizer.__init__(self, analysis, algorithm, xtol, ftol)
192190
for cls in OptimizerPlugin.__subclasses__():
193191
try:
194192
if not cls.DISABLED:

0 commit comments

Comments
 (0)