4141# Ensure compatibility with Python 2
4242from __future__ import absolute_import , division , print_function , unicode_literals
4343
44- import logging
4544from math import sqrt
4645import numpy as np
46+ try :
47+ import trustregion
48+ USE_FORTRAN = True
49+ except ImportError :
50+ # Fall back to Python implementation
51+ USE_FORTRAN = False
4752
4853
4954from .util import sumsq , model_value
5055
5156
5257__all__ = ['trsbox' , 'trsbox_geometry' ]
5358
54- # ZERO_THRESH = 1e-14
59+ ZERO_THRESH = 1e-14
5560
5661
57- def trsbox (xopt , g , H , sl , su , delta ):
62+ def trsbox (xopt , g , H , sl , su , delta , use_fortran = USE_FORTRAN ):
63+ if use_fortran :
64+ return trustregion .solve (g , H , delta ,
65+ sl = np .minimum (sl - xopt , - ZERO_THRESH ),
66+ su = np .maximum (su - xopt , ZERO_THRESH ),
67+ verbose_output = True )
68+
5869 n = xopt .size
5970 assert xopt .shape == (n ,), "xopt has wrong shape (should be vector)"
6071 assert g .shape == (n ,), "g and xopt have incompatible sizes"
@@ -368,7 +379,7 @@ def d_within_bounds(d, xopt, sl, su, xbdi):
368379 return d
369380
370381
371- def trsbox_geometry (xbase , c , g , H , lower , upper , Delta ):
382+ def trsbox_geometry (xbase , c , g , H , lower , upper , Delta , use_fortran = USE_FORTRAN ):
372383 # Given a Lagrange polynomial defined by: L(x) = c + g' * (x - xbase) + 0.5*(x-xbase)*H*(x-xbase)
373384 # Maximise |L(x)| in a box + trust region - that is, solve:
374385 # max_x abs(c + g' * (x - xbase) + 0.5*(x-xbase)*H*(x-xbase))
@@ -378,8 +389,8 @@ def trsbox_geometry(xbase, c, g, H, lower, upper, Delta):
378389 # max_s abs(c + g' * s + 0.5*s*H*s)
379390 # s.t. lower <= xbase + s <= upper
380391 # ||s|| <= Delta
381- smin , gmin , crvmin = trsbox (xbase , g , H , lower , upper , Delta ) # minimise L(x)
382- smax , gmax , crvmax = trsbox (xbase , - g , - H , lower , upper , Delta ) # maximise L(x)
392+ smin , gmin , crvmin = trsbox (xbase , g , H , lower , upper , Delta , use_fortran = use_fortran ) # minimise L(x)
393+ smax , gmax , crvmax = trsbox (xbase , - g , - H , lower , upper , Delta , use_fortran = use_fortran ) # maximise L(x)
383394 if abs (c + model_value (g , H , smin )) >= abs (c + model_value (g , H , smax )): # take largest abs value
384395 return xbase + smin
385396 else :
0 commit comments