Skip to content

Commit d3d081d

Browse files
authored
Merge pull request scipy#19668 from bjodah/tr-exact-subproblem-maxiter
ENH: optimize.minimize: add `subproblem_maxiter` option for `tr-exact`
2 parents 22eace9 + 38ec2e0 commit d3d081d

File tree

4 files changed

+293
-8
lines changed

4 files changed

+293
-8
lines changed

scipy/optimize/_trustregion.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _minimize_trust_region(fun, x0, args=(), jac=None, hess=None, hessp=None,
120120
max_trust_radius=1000.0, eta=0.15, gtol=1e-4,
121121
maxiter=None, disp=False, return_all=False,
122122
callback=None, inexact=True, workers=None,
123-
**unknown_options):
123+
subproblem_maxiter=None, **unknown_options):
124124
"""
125125
Minimization of scalar function of one or more variables using a
126126
trust-region algorithm.
@@ -150,6 +150,12 @@ def _minimize_trust_region(fun, x0, args=(), jac=None, hess=None, hessp=None,
150150
Only for 'trust-krylov', 'trust-ncg'.
151151
152152
.. versionadded:: 1.16.0
153+
subproblem_maxiter : int, optional
154+
Maximum number of iterations to perform per subproblem. Only affects
155+
trust-exact. Default is 25.
156+
157+
.. versionadded:: 1.17.0
158+
153159
154160
This function is called by the `minimize` function.
155161
It is not supposed to be called directly.
@@ -224,7 +230,12 @@ def hessp(x, p, *args):
224230
x = x0
225231
if return_all:
226232
allvecs = [x]
227-
m = subproblem(x, fun, jac, hess, hessp)
233+
234+
subproblem_init_kw = {}
235+
if hasattr(subproblem, 'MAXITER_DEFAULT'):
236+
subproblem_init_kw['maxiter'] = subproblem_maxiter
237+
238+
m = subproblem(x, fun, jac, hess, hessp, **subproblem_init_kw)
228239
k = 0
229240

230241
# search for the function min
@@ -246,7 +257,7 @@ def hessp(x, p, *args):
246257

247258
# define the local approximation at the proposed point
248259
x_proposed = x + p
249-
m_proposed = subproblem(x_proposed, fun, jac, hess, hessp)
260+
m_proposed = subproblem(x_proposed, fun, jac, hess, hessp, **subproblem_init_kw)
250261

251262
# evaluate the ratio defined in equation (4.4)
252263
actual_reduction = m.fun - m_proposed.fun

scipy/optimize/_trustregion_exact.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,18 @@ class IterativeSubproblem(BaseQuadraticSubproblem):
207207
# As recommended there, its value is fixed at 0.01.
208208
UPDATE_COEFF = 0.01
209209

210+
# The subproblem may iterate infinitely for problematic
211+
# cases (see https://github.com/scipy/scipy/issues/12513).
212+
# When the `maxiter` setting is None, we need to apply a
213+
# default. An ad-hoc number (though tested quite extensively)
214+
# is 25, which is set below. To restore the old behavior (which
215+
# potentially hangs), this parameter may be changed to np.inf:
216+
MAXITER_DEFAULT = 25 # use np.inf for infinite number of iterations
217+
210218
EPS = np.finfo(float).eps
211219

212220
def __init__(self, x, fun, jac, hess, hessp=None,
213-
k_easy=0.1, k_hard=0.2):
221+
k_easy=0.1, k_hard=0.2, maxiter=None):
214222

215223
super().__init__(x, fun, jac, hess)
216224

@@ -232,6 +240,14 @@ def __init__(self, x, fun, jac, hess, hessp=None,
232240
self.k_easy = k_easy
233241
self.k_hard = k_hard
234242

243+
# ``maxiter`` optionally limits the number of iterations
244+
# the solve method may perform. Useful for poorly conditioned
245+
# problems which may otherwise hang.
246+
self.maxiter = self.MAXITER_DEFAULT if maxiter is None else maxiter
247+
if self.maxiter < 0:
248+
raise ValueError("maxiter must not be set to a negative number"
249+
", use np.inf to mean infinite.")
250+
235251
# Get Lapack function for cholesky decomposition.
236252
# The implemented SciPy wrapper does not return
237253
# the incomplete factorization needed by the method.
@@ -290,7 +306,7 @@ def solve(self, tr_radius):
290306
already_factorized = False
291307
self.niter = 0
292308

293-
while True:
309+
while self.niter < self.maxiter:
294310

295311
# Compute Cholesky factorization
296312
if already_factorized:
@@ -371,7 +387,7 @@ def solve(self, tr_radius):
371387

372388
# Update damping factor
373389
lambda_current = max(
374-
np.sqrt(lambda_lb * lambda_ub),
390+
np.sqrt(np.abs(lambda_lb * lambda_ub)),
375391
lambda_lb + self.UPDATE_COEFF*(lambda_ub-lambda_lb)
376392
)
377393

@@ -399,10 +415,10 @@ def solve(self, tr_radius):
399415
s_min, z_min = estimate_smallest_singular_value(U)
400416
step_len = tr_radius
401417

418+
p = step_len * z_min
402419
# Check stop criteria
403420
if (step_len**2 * s_min**2
404421
<= self.k_hard * lambda_current * tr_radius**2):
405-
p = step_len * z_min
406422
break
407423

408424
# Update uncertainty bounds
@@ -426,7 +442,7 @@ def solve(self, tr_radius):
426442

427443
# Update damping factor
428444
lambda_current = max(
429-
np.sqrt(lambda_lb * lambda_ub),
445+
np.sqrt(np.abs(lambda_lb * lambda_ub)),
430446
lambda_lb + self.UPDATE_COEFF*(lambda_ub-lambda_lb)
431447
)
432448

scipy/optimize/tests/test_optimize.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,6 +3257,137 @@ def f(x):
32573257
assert_allclose(res.fun, ref.fun)
32583258
assert_allclose(res.x, ref.x)
32593259

3260+
def test_gh12513_trustregion_exact_infinite_loop():
3261+
# gh-12513 reported that optimize.minimize might hang when
3262+
# method='trust-exact'. By using the option ``subproblem_maxiter``,
3263+
# this can be avoided.
3264+
H = np.array(
3265+
[[3.67335930e01, -2.52334820e02, 1.15477558e01, -1.19933725e-03,
3266+
-2.06408851e03, -2.05821411e00, -2.52334820e02, -6.52076924e02,
3267+
-2.71362566e-01, -1.98885126e00, 1.22085415e00, 2.30220713e00,
3268+
-9.71278532e-02, -5.11210123e-01, -1.00399562e00, 1.43319679e-01,
3269+
6.03815471e00, -6.38719934e-02, 1.65623929e-01],
3270+
[-2.52334820e02, 1.76757312e03, -9.92814996e01, 1.06533600e-02,
3271+
1.44442941e04, 1.43811694e01, 1.76757312e03, 4.56694461e03,
3272+
2.22263363e00, 1.62977318e01, -7.81539315e00, -1.24938012e01,
3273+
6.74029088e-01, 3.22802671e00, 5.14978971e00, -9.58561209e-01,
3274+
-3.92199895e01, 4.47201278e-01, -1.17866744e00],
3275+
[1.15477558e01, -9.92814996e01, 3.63872363e03, -4.40007197e-01,
3276+
-9.55435081e02, -1.13985105e00, -9.92814996e01, -2.58307255e02,
3277+
-5.21335218e01, -3.77485107e02, -6.75338369e01, -1.89457169e02,
3278+
5.67828623e00, 5.82402681e00, 1.72734354e01, -4.29114840e00,
3279+
-7.84885258e01, 3.17594634e00, 2.45242852e00],
3280+
[-1.19933725e-03, 1.06533600e-02, -4.40007197e-01, 5.73576663e-05,
3281+
1.01563710e-01, 1.18838745e-04, 1.06533600e-02, 2.76535767e-02,
3282+
6.25788669e-03, 4.50699620e-02, 8.64152333e-03, 2.27772377e-02,
3283+
-8.51026855e-04, 1.65316383e-04, 1.38977551e-03, 5.51629259e-04,
3284+
1.38447755e-02, -5.17956723e-04, -1.29260347e-04],
3285+
[-2.06408851e03, 1.44442941e04, -9.55435081e02, 1.01563710e-01,
3286+
1.23101825e05, 1.26467259e02, 1.44442941e04, 3.74590279e04,
3287+
2.18498571e01, 1.60254460e02, -7.52977260e01, -1.17989623e02,
3288+
6.58253160e00, 3.14949206e01, 4.98527190e01, -9.33338661e00,
3289+
-3.80465752e02, 4.33872213e00, -1.14768816e01],
3290+
[-2.05821411e00, 1.43811694e01, -1.13985105e00, 1.18838745e-04,
3291+
1.26467259e02, 1.46226198e-01, 1.43811694e01, 3.74509252e01,
3292+
2.76928748e-02, 2.03023837e-01, -8.84279903e-02, -1.29523344e-01,
3293+
8.06424434e-03, 3.83330661e-02, 5.81579023e-02, -1.12874980e-02,
3294+
-4.48118297e-01, 5.15022284e-03, -1.41501894e-02],
3295+
[-2.52334820e02, 1.76757312e03, -9.92814996e01, 1.06533600e-02,
3296+
1.44442941e04, 1.43811694e01, 1.76757312e03, 4.56694461e03,
3297+
2.22263363e00, 1.62977318e01, -7.81539315e00, -1.24938012e01,
3298+
6.74029088e-01, 3.22802671e00, 5.14978971e00, -9.58561209e-01,
3299+
-3.92199895e01, 4.47201278e-01, -1.17866744e00],
3300+
[-6.52076924e02, 4.56694461e03, -2.58307255e02, 2.76535767e-02,
3301+
3.74590279e04, 3.74509252e01, 4.56694461e03, 1.18278398e04,
3302+
5.82242837e00, 4.26867612e01, -2.03167952e01, -3.22894255e01,
3303+
1.75705078e00, 8.37153730e00, 1.32246076e01, -2.49238529e00,
3304+
-1.01316422e02, 1.16165466e00, -3.09390862e00],
3305+
[-2.71362566e-01, 2.22263363e00, -5.21335218e01, 6.25788669e-03,
3306+
2.18498571e01, 2.76928748e-02, 2.22263363e00, 5.82242837e00,
3307+
4.36278066e01, 3.14836583e02, -2.04747938e01, -3.05535101e01,
3308+
-1.24881456e-01, 1.15775394e01, 4.06907410e01, -1.39317748e00,
3309+
-3.90902798e01, -9.71716488e-02, 1.06851340e-01],
3310+
[-1.98885126e00, 1.62977318e01, -3.77485107e02, 4.50699620e-02,
3311+
1.60254460e02, 2.03023837e-01, 1.62977318e01, 4.26867612e01,
3312+
3.14836583e02, 2.27255216e03, -1.47029712e02, -2.19649109e02,
3313+
-8.83963155e-01, 8.28571708e01, 2.91399776e02, -9.97382920e00,
3314+
-2.81069124e02, -6.94946614e-01, 7.38151960e-01],
3315+
[1.22085415e00, -7.81539315e00, -6.75338369e01, 8.64152333e-03,
3316+
-7.52977260e01, -8.84279903e-02, -7.81539315e00, -2.03167952e01,
3317+
-2.04747938e01, -1.47029712e02, 7.83372613e01, 1.64416651e02,
3318+
-4.30243758e00, -2.59579610e01, -6.25644064e01, 6.69974667e00,
3319+
2.31011701e02, -2.68540084e00, 5.44531151e00],
3320+
[2.30220713e00, -1.24938012e01, -1.89457169e02, 2.27772377e-02,
3321+
-1.17989623e02, -1.29523344e-01, -1.24938012e01, -3.22894255e01,
3322+
-3.05535101e01, -2.19649109e02, 1.64416651e02, 3.75893031e02,
3323+
-7.42084715e00, -4.56437599e01, -1.11071032e02, 1.18761368e01,
3324+
4.78724142e02, -5.06804139e00, 8.81448081e00],
3325+
[-9.71278532e-02, 6.74029088e-01, 5.67828623e00, -8.51026855e-04,
3326+
6.58253160e00, 8.06424434e-03, 6.74029088e-01, 1.75705078e00,
3327+
-1.24881456e-01, -8.83963155e-01, -4.30243758e00, -7.42084715e00,
3328+
9.62009425e-01, 1.53836355e00, 2.23939458e00, -8.01872920e-01,
3329+
-1.92191084e01, 3.77713908e-01, -8.32946970e-01],
3330+
[-5.11210123e-01, 3.22802671e00, 5.82402681e00, 1.65316383e-04,
3331+
3.14949206e01, 3.83330661e-02, 3.22802671e00, 8.37153730e00,
3332+
1.15775394e01, 8.28571708e01, -2.59579610e01, -4.56437599e01,
3333+
1.53836355e00, 2.63851056e01, 7.34859767e01, -4.39975402e00,
3334+
-1.12015747e02, 5.11542219e-01, -2.64962727e00],
3335+
[-1.00399562e00, 5.14978971e00, 1.72734354e01, 1.38977551e-03,
3336+
4.98527190e01, 5.81579023e-02, 5.14978971e00, 1.32246076e01,
3337+
4.06907410e01, 2.91399776e02, -6.25644064e01, -1.11071032e02,
3338+
2.23939458e00, 7.34859767e01, 2.36535458e02, -1.09636675e01,
3339+
-2.72152068e02, 6.65888059e-01, -6.29295273e00],
3340+
[1.43319679e-01, -9.58561209e-01, -4.29114840e00, 5.51629259e-04,
3341+
-9.33338661e00, -1.12874980e-02, -9.58561209e-01, -2.49238529e00,
3342+
-1.39317748e00, -9.97382920e00, 6.69974667e00, 1.18761368e01,
3343+
-8.01872920e-01, -4.39975402e00, -1.09636675e01, 1.16820748e00,
3344+
3.00817252e01, -4.51359819e-01, 9.82625204e-01],
3345+
[6.03815471e00, -3.92199895e01, -7.84885258e01, 1.38447755e-02,
3346+
-3.80465752e02, -4.48118297e-01, -3.92199895e01, -1.01316422e02,
3347+
-3.90902798e01, -2.81069124e02, 2.31011701e02, 4.78724142e02,
3348+
-1.92191084e01, -1.12015747e02, -2.72152068e02, 3.00817252e01,
3349+
1.13232557e03, -1.33695932e01, 2.22934659e01],
3350+
[-6.38719934e-02, 4.47201278e-01, 3.17594634e00, -5.17956723e-04,
3351+
4.33872213e00, 5.15022284e-03, 4.47201278e-01, 1.16165466e00,
3352+
-9.71716488e-02, -6.94946614e-01, -2.68540084e00, -5.06804139e00,
3353+
3.77713908e-01, 5.11542219e-01, 6.65888059e-01, -4.51359819e-01,
3354+
-1.33695932e01, 4.27994168e-01, -5.09020820e-01],
3355+
[1.65623929e-01, -1.17866744e00, 2.45242852e00, -1.29260347e-04,
3356+
-1.14768816e01, -1.41501894e-02, -1.17866744e00, -3.09390862e00,
3357+
1.06851340e-01, 7.38151960e-01, 5.44531151e00, 8.81448081e00,
3358+
-8.32946970e-01, -2.64962727e00, -6.29295273e00, 9.82625204e-01,
3359+
2.22934659e01, -5.09020820e-01, 4.09964606e00]]
3360+
)
3361+
J = np.array([
3362+
-2.53298102e-07, 1.76392040e-06, 1.74776130e-06, -4.19479903e-10,
3363+
1.44167498e-05, 1.41703911e-08, 1.76392030e-06, 4.96030153e-06,
3364+
-2.35771675e-07, -1.68844985e-06, 4.29218258e-07, 6.65445159e-07,
3365+
-3.87045830e-08, -3.17236594e-07, -1.21120169e-06, 4.59717313e-08,
3366+
1.67123246e-06, 1.46624675e-08, 4.22723383e-08
3367+
])
3368+
3369+
def fun(x):
3370+
return np.dot(np.dot(x, H), x) / 2 + np.dot(x, J)
3371+
3372+
def jac(x):
3373+
return np.dot(x, H) + J
3374+
3375+
def hess(x):
3376+
return H
3377+
3378+
x0 = np.zeros(19)
3379+
3380+
res = optimize.minimize(
3381+
fun,
3382+
x0,
3383+
jac=jac,
3384+
hess=hess,
3385+
method="trust-exact",
3386+
options={"gtol": 1e-6, "subproblem_maxiter": 10},
3387+
)
3388+
assert res.success
3389+
assert abs(fun(res.x)) < 1e-5
3390+
32603391

32613392
@pytest.mark.parametrize('method', ['Newton-CG', 'trust-constr'])
32623393
@pytest.mark.parametrize('sparse_type', [coo_matrix, csc_matrix, csr_matrix,

0 commit comments

Comments
 (0)