
Commit 3cf6525

add adamax; add --xtol and --ftol

1 parent ba38d1b

File tree

    smcpp/analysis.py
    smcpp/commands/command.py
    smcpp/model.py
    smcpp/optimizer.py

4 files changed: +84 -63 lines changed

smcpp/analysis.py

Lines changed: 15 additions & 16 deletions

@@ -37,8 +37,6 @@ def __init__(self, files, args):
             args.polarization_error = 0.5
         if args.polarization_error > 0.:
             logger.info("Polarization error p=%f", args.polarization_error)
-        if args.factr:
-            args.solver_args['factr'] = args.factr
         # Data-related stuff
         self._load_data(files)
         self._validate_data()
@@ -109,7 +107,7 @@ def _perform_thinning(self, thinning):
         elif np.any(ns > 0):
             logger.warn("Not thinning yet undistinguished lineages are present")
 
-    def _normalize_data(self, length_cutoff, no_filter):
+    def _normalize_data(self, length_cutoff, filter):
         ## break up long spans
         self._contigs, attrs = estimation_tools.break_long_spans(self._contigs, length_cutoff)
         if not attrs:
@@ -128,7 +126,7 @@ def _normalize_data(self, length_cutoff, no_filter):
         var = np.average((het - avg) ** 2, weights=w) * (n / (n - 1.))
         sd = np.sqrt(var)
         logger.debug("Average/sd het:%f(%f)", avg, sd)
-        if not no_filter:
+        if filter:
             logger.debug("Keeping contigs within +-3 s.d. of mean")
         logger.debug("Average heterozygosity (derived / total bases) by data set (* = dropped)")
         ci = 0
@@ -139,7 +137,7 @@ def _normalize_data(self, length_cutoff, no_filter):
             for attr in attrs[key]:
                 het = attr[-1]
                 mytpl = tpl
-                if no_filter or abs(het - avg) <= 3 * sd:
+                if not filter or abs(het - avg) <= 3 * sd:
                     new_contigs.append(self._contigs[ci])
                 else:
                     mytpl += " *"
@@ -282,7 +280,7 @@ def __init__(self, files, args):
             self.rescale(args.tK),
             knot_spans, args.offset))
         # Perform initial filtering for weird contigs
-        self._normalize_data(args.length_cutoff, args.no_filter)
+        self._normalize_data(args.length_cutoff, args.filter)
 
         # Initialize members
         self._init_parameters(args.theta, args.rho)
@@ -295,7 +293,8 @@ def __init__(self, files, args):
         self._init_optimizer(args, files, args.outdir,
                              1,  # set block-size to knots
                              "L-BFGS-B",  # TNC tends to overfit for initial pass
-                             args.tolerance, learn_rho=False)
+                             args.xtol, args.ftol,
+                             learn_rho=False)
         self._optimizer.run(1)
 
         # Thin the data
@@ -305,7 +304,7 @@ def __init__(self, files, args):
         self._init_hidden_states(args.prior_model, args.M)
         self._init_inference_manager(args.polarization_error)
         self._init_optimizer(args, files, args.outdir, args.blocks,
-                             args.algorithm, args.tolerance, learn_rho=True)
+                             args.algorithm, args.xtol, args.ftol, learn_rho=True)
 
     def _init_parameters(self, theta=None, rho=None):
         ## Set theta and rho to their default parameters
@@ -363,16 +362,16 @@ def _init_model(self, pieces, N0, t1, tK, spline_class):
         mods[-1][-1] = y0
         self._model = SMCTwoPopulationModel(mods[0], mods[1], split)
 
-    def _init_optimizer(self, args, files, outdir, blocks, algorithm, tolerance, learn_rho):
+    def _init_optimizer(self, args, files, outdir, blocks, algorithm, xtol, ftol, learn_rho):
         if self.npop == 1:
             self._optimizer = optimizer.SMCPPOptimizer(
-                self, algorithm, tolerance, blocks, args.solver_args)
+                self, algorithm, xtol, ftol, blocks, args.solver_args)
             # Also optimize knots in 1 pop case. Not yet implemented
             # for two pop case.
             # self._optimizer.register(optimizer.KnotOptimizer())
         elif self.npop == 2:
             self._optimizer = optimizer.TwoPopulationOptimizer(
-                self, algorithm, tolerance, blocks, args.solver_args)
+                self, algorithm, xtol, ftol, blocks, args.solver_args)
             smax = np.sum(self._model.distinguished_model.s)
             self._optimizer.register(
                 optimizer.ParameterOptimizer("split", (0., smax), "model"))
@@ -397,26 +396,26 @@ def __init__(self, files, args):
 
         self._hidden_states = np.array([0., np.inf])
         self._init_inference_manager(False)
-        self._init_optimizer(args, files, args.outdir, args.algorithm, args.tolerance, args.blocks, False)
+        self._init_optimizer(args, files, args.outdir, args.algorithm, args.xtol, args.ftol, args.blocks, False)
         # Hack to only estimate split time.
         self._optimizer.run(1)
 
         # After inferring initial split time, thin
         self._perform_thinning(args.thinning)
-        self._normalize_data(args.length_cutoff, args.no_filter)
+        self._normalize_data(args.length_cutoff, args.filter)
 
         self._init_hidden_states(args.pop1, args.M)
         self._init_inference_manager(False)
-        self._init_optimizer(args, files, args.outdir, args.algorithm, args.tolerance, args.blocks)
+        self._init_optimizer(args, files, args.outdir, args.algorithm, args.xtol, args.ftol, args.blocks)
 
     def _validate_data(self):
         BaseAnalysis._validate_data(self)
         if not any(c.npop == 2 for c in self._contigs):
             logger.error("Data contains no joint frequency spectrum information. Split estimation is impossible.")
             sys.exit(1)
 
-    def _init_optimizer(self, args, files, outdir, algorithm, tolerance, blocks, save=True):
-        self._optimizer = optimizer.TwoPopulationOptimizer(self, algorithm, tolerance, blocks, args.solver_args)
+    def _init_optimizer(self, args, files, outdir, algorithm, xtol, ftol, blocks, save=True):
+        self._optimizer = optimizer.TwoPopulationOptimizer(self, algorithm, xtol, ftol, blocks, args.solver_args)
         smax = np.sum(self._model.distinguished_model.s)
         self._optimizer.register(optimizer.ParameterOptimizer("split", (0., smax), "model"))
         if save:
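
Note on the filtering change above: the old --no-filter flag is replaced by a hidden --filter flag defaulting to False (see smcpp/commands/command.py below), so outlier removal is now opt-in. When enabled, _normalize_data keeps a contig only if its heterozygosity lies within three standard deviations of the weighted mean. A minimal self-contained sketch of that rule; the arrays and values are illustrative, not from the codebase:

    import numpy as np

    het = np.array([0.0010, 0.0012, 0.0011, 0.0090])  # per-contig heterozygosity
    w = np.array([1e6, 2e6, 1.5e6, 5e4])              # weights, e.g. contig spans
    n = len(het)

    avg = np.average(het, weights=w)
    var = np.average((het - avg) ** 2, weights=w) * (n / (n - 1.))
    sd = np.sqrt(var)

    keep = np.abs(het - avg) <= 3 * sd
    print(keep)  # [ True  True  True False] -- the extreme contig is dropped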

smcpp/commands/command.py

Lines changed: 9 additions & 8 deletions

@@ -47,9 +47,7 @@ def add_common_estimation_args(parser):
                       help="omit sequences < cutoff. default: 10000", default=10000, type=int)
     data.add_argument('--thinning', help="only emit full SFS every <k>th site. default: 500 * n.",
                       default=None, type=int, metavar="k")
-    data.add_argument('--no-filter', help="do not drop contigs with extreme heterozygosity. "
-                      "(not recommended unless data set is small)",
-                      action="store_true", default=False)
+    data.add_argument('--filter', help=argparse.SUPPRESS, action="store_true", default=False)
 
     optimizer = parser.add_argument_group("Optimization parameters")
     optimizer.add_argument(
@@ -60,13 +58,16 @@ def add_common_estimation_args(parser):
         default="L-BFGS-B", help=argparse.SUPPRESS)
     optimizer.add_argument('--blocks', type=int,
                            help="number of coordinate ascent blocks. default: min(4, K)")
-    optimizer.add_argument('--factr', type=float,
-                           default=1e-9, help=argparse.SUPPRESS)
+    optimizer.add_argument("--ftol", type=float, default=1e-3,
+                           help="stopping criterion for relative improvement in loglik "
+                                "in EM algorithm. algorithm will terminate when "
+                                "|loglik' - loglik| / loglik < ftol")
+    optimizer.add_argument('--xtol', type=float,
+                           default=.001,
+                           help=r"x tolerance for optimizer. "
+                                "optimizer will stop when |x' - x|_\infty < xtol")
     optimizer.add_argument('--regularization-penalty',
                            type=float, help="regularization penalty", default=1.)
-    optimizer.add_argument("--tolerance", type=float, default=1e-4,
-                           help="stopping criterion for relative improvement in loglik "
-                                "in EM algorithm")
     optimizer.add_argument('--Nmin', type=float,
                            help="Lower bound on effective population size (in units of N0)",
                            default=.01)
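
Net effect on the CLI: --factr and --tolerance are gone, replaced by --ftol (relative loglik improvement across EM iterations) and --xtol (infinity-norm change in the parameter vector within the inner optimizer). A hypothetical invocation; the subcommand, mutation rate, and paths are illustrative only:

    smc++ estimate --ftol 1e-3 --xtol 1e-3 -o out/ 1.25e-8 data.smc.gz

Both flags default to 1e-3, so the explicit values here are redundant and shown only for illustration.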

smcpp/model.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def K(self):
         return len(self.knots)
 
     def randomize(self):
-        self[:] += np.random.normal(0., .01, size=len(self[:]))
+        self[:] += np.random.normal(0., .0001, size=len(self[:]))
 
     @property
     def knots(self):

smcpp/optimizer.py

Lines changed: 59 additions & 38 deletions

@@ -21,6 +21,32 @@
 
 logger = logging.getLogger(__name__)
 
+def AdaMax(f, x0, args, jac, bounds, alpha=0.1, b1=0.9, b2=0.999, eps=1e-3, **kwargs):
+    assert jac
+    bounds = np.array(bounds)
+    def _f(x0):
+        return tuple(q(x0, *args) for q in (f, jac))
+    obj, grad = _f(x0)
+    theta = x0.copy()
+    t = 0
+    mt = 0
+    ut = 0
+    while True:
+        t += 1
+        ft, gt = _f(theta)
+        mt = b1 * mt + (1. - b1) * gt
+        ut = np.maximum(b2 * ut, abs(gt))
+        delta = -(alpha / (1. - b1 ** t)) * mt / ut
+        if np.linalg.norm(delta) < eps:
+            break
+        theta = box_constrain(theta + delta, bounds)
+        if 'callback' in kwargs:
+            kwargs['callback'](theta)
+    return scipy.optimize.OptimizeResult({'x': theta, 'fun': ft})
+
+class ConvergedException(Exception):
+    "Thrown when optimizer reaches stopping criterion."
+    pass
 
 class EMTerminationException(Exception):
     "Thrown when EM algorithm reaches stopping criterion."
@@ -31,11 +57,12 @@ class AbstractOptimizer(Observable):
     '''
     Abstract representation of the execution flow of the optimizer.
     '''
-    def __init__(self, analysis, algorithm, tolerance, blocks, solver_args={}):
+    def __init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args={}):
         Observable.__init__(self)
         self._analysis = analysis
         self._algorithm = algorithm
-        self._tolerance = tolerance
+        self._xtol = xtol
+        self._ftol = ftol
         self._blocks = blocks
         self._solver_args = solver_args

@@ -66,6 +93,7 @@ def _f(self, x, analysis, coords, k=None):
         return ret
 
     def _minimize(self, x0, coords, bounds):
+        self._xk = None
         if os.environ.get("SMCPP_GRADIENT_CHECK", False):
             print("\n\ngradient check")
             y, dy = self._f(x0, self._analysis, coords)
@@ -74,12 +102,23 @@ def _minimize(self, x0, coords, bounds):
                 y1, _ = self._f(x0, self._analysis, coords)
                 print("***grad", i, y1, (y1 - y) * 1e8, dy[i])
                 x0[i] -= 1e-8
-        return minimize_proxy(self._f, x0,
-                              jac=True,
-                              args=(self._analysis, coords),
-                              bounds=bounds,
-                              options=self._solver_args,
-                              method=self._algorithm)
+        try:
+            if self._algorithm == "AdaMax":
+                alg = AdaMax
+            else:
+                alg = self._algorithm
+            res = scipy.optimize.minimize(self._f, x0,
+                                          jac=True,
+                                          args=(self._analysis, coords),
+                                          bounds=bounds,
+                                          options=self._solver_args,
+                                          callback=self._callback,
+                                          method=alg)
+            return res
+        except ConvergedException:
+            logger.debug("Converged: |xk - xk_1| < %g", self._xtol)
+            return scipy.optimize.OptimizeResult(
+                {'x': self._xk, 'fun': self._f(self._xk, self._analysis, coords)[0]})
 
     def run(self, niter):
         self.update_observers('begin')
@@ -108,6 +147,15 @@ def run(self, niter):
         # Conclude the optimization and perform any necessary callbacks.
         self.update_observers('optimization finished')
 
+    def _callback(self, xk):
+        if self._xk is None:
+            self._xk = xk
+            return
+        delta = max(abs(xk - self._xk))
+        self._xk = xk
+        if delta < self._xtol:
+            raise ConvergedException()
+
     def update_observers(self, *args, **kwargs):
         kwargs.update({
             'optimizer': self,
@@ -158,7 +206,7 @@ def update(self, message, *args, **kwargs):
             improvement = (self._old_loglik - ll) / self._old_loglik
             logger.info("New loglik: %f\t(old: %f [%f%%])",
                         ll, self._old_loglik, 100. * improvement)
-            tol = kwargs['optimizer']._tolerance
+            tol = kwargs['optimizer']._ftol
             if improvement < 0:
                 logger.warn("Loglik decreased")
             elif improvement < tol:
@@ -362,8 +410,8 @@ def write(x):
 class SMCPPOptimizer(AbstractOptimizer):
     'Model fitting for one population.'
 
-    def __init__(self, analysis, algorithm, tolerance, blocks, solver_args):
-        AbstractOptimizer.__init__(self, analysis, algorithm, tolerance, blocks, solver_args)
+    def __init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args):
+        AbstractOptimizer.__init__(self, analysis, algorithm, xtol, ftol, blocks, solver_args)
         observers = [
             HiddenStateOccupancyPrinter(),
             ProgressPrinter(),
@@ -405,32 +453,5 @@ def _coordinates(self):
     def _bounds(self, coords):
         return SMCPPOptimizer._bounds(self, coords[1])
 
-AdaMaxResult = namedtuple('AdaMaxResult', 'x fun')
-
 def box_constrain(x, bounds):
     return np.maximum(np.minimum(x, bounds[:, 1]), bounds[:, 0])
-
-def AdaMax(f, x0, jac, args, bounds, alpha=0.0002, b1=0.9, b2=0.999, eps=1e-3, **kwargs):
-    assert jac == True
-    bounds = np.array(bounds)
-    obj, grad = f(x0, *args)
-    m0 = 0
-    u0 = 0
-    theta = x0.copy()
-    t = 0
-    mt = 0
-    while True:
-        t += 1
-        ft, gt = f(theta, *args)
-        mt = b1 * mt + (1. - b1) * gt
-        ut = np.maximum(b2, np.abs(gt))
-        delta = -(alpha / (1. - b1 ** t)) * mt / ut
-        if np.linalg.norm(delta) < eps:
-            break
-        theta = box_constrain(theta + delta, bounds)
-    return AdaMaxResult(x=theta, fun=ft)
-
-def minimize_proxy(f, x0, *args, **kwargs):
-    if kwargs['method'] == "AdaMax":
-        return AdaMax(f, x0, *args, **kwargs)
-    return scipy.optimize.minimize(f, x0, *args, **kwargs)
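
Reviewer notes on the two mechanisms added in this file. First, the --xtol stopping rule: _callback tracks the previous iterate and raises ConvergedException once the largest coordinate-wise change falls below xtol, and _minimize catches it and synthesizes an OptimizeResult. A self-contained sketch of that pattern against a toy objective; the names and constants here are illustrative, not part of the commit:

    import numpy as np
    import scipy.optimize

    class Converged(Exception):
        pass

    xtol = 1e-3
    last_x = None

    def callback(xk):
        # Stop once the largest coordinate-wise change drops below xtol.
        global last_x
        if last_x is not None and np.max(np.abs(xk - last_x)) < xtol:
            raise Converged()
        last_x = xk.copy()

    def f(x):
        # Toy objective returning (value, gradient), matching jac=True.
        return np.sum((x - 3.) ** 2), 2. * (x - 3.)

    try:
        res = scipy.optimize.minimize(f, np.zeros(2), jac=True,
                                      method="L-BFGS-B", callback=callback)
    except Converged:
        res = scipy.optimize.OptimizeResult({'x': last_x, 'fun': f(last_x)[0]})
    print(res.x)  # ~ [3., 3.]

Second, the rewritten AdaMax (Kingma & Ba 2015, Algorithm 2) now keeps a proper infinity-norm accumulator, ut = max(b2 * ut, |gt|), where the removed version computed np.maximum(b2, np.abs(gt)). The update rule in isolation, again on an illustrative problem:

    import numpy as np

    def adamax(grad, x0, alpha=0.1, b1=0.9, b2=0.999, eps=1e-3):
        # m: exponential moving average of the gradient;
        # u: infinity-norm accumulator; step is bias-corrected by 1 - b1**t.
        theta, m, u, t = x0.astype(float), 0., 0., 0
        while True:
            t += 1
            g = grad(theta)
            m = b1 * m + (1. - b1) * g
            u = np.maximum(b2 * u, np.abs(g))
            delta = -(alpha / (1. - b1 ** t)) * m / u
            if np.linalg.norm(delta) < eps:
                return theta
            theta = theta + delta

    # Minimize |x - target|^2, whose gradient is 2 * (x - target).
    target = np.array([1., -2.])
    print(adamax(lambda x: 2. * (x - target), np.zeros(2)))  # ~ [1., -2.]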
