Skip to content

Commit a2a5e22

Browse files
committed
Handle NaNs in objective values explicitly
1 parent 66ed896 commit a2a5e22

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

dfols/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def solve_geom_system(self, rhs):
306306
return col_scale(LA.lstsq(W, col_scale(rhs * left_scaling))[0], right_scaling)
307307

308308
def interpolate_mini_models_svd(self, verbose=False, make_full_rank=False, min_sing_val=1e-6, sing_val_frac=1.0, max_jac_cond=1e8,
309-
get_chg_J=False):
309+
get_chg_J=False, throw_error_on_nans=False):
310310
W, left_scaling, right_scaling = self.interpolation_matrix()
311311
self.factorise_geom_system()
312312
ls_interp_cond_num = np.linalg.cond(W) if verbose else 0.0 # scipy.linalg does not have condition number!
@@ -327,12 +327,18 @@ def interpolate_mini_models_svd(self, verbose=False, make_full_rank=False, min_s
327327
self.model_jac = np.dot(self.model_jac, np.dot(Qhat, Qhat.T))
328328

329329
rhs = self.fval_v[fval_row_idx, :] # size npt * m
330+
if np.any(np.isnan(rhs)) and throw_error_on_nans:
331+
if self.do_logging:
332+
logging.warning("model.interpolate_mini_models_svd: NaNs encountered in objective evaluations, raising error")
333+
raise np.linalg.LinAlgError("NaN encountered in objective evaluations")
330334
try:
331335
dg = self.solve_geom_system(rhs) # size (n+1)*m
332336
except LA.LinAlgError:
333337
return False, None, None, None, None # flag error
334338
except ValueError:
335339
return False, None, None, None, None # flag error (e.g. inf or NaN encountered)
340+
if not np.all(np.isfinite(dg)): # another check for inf or NaN
341+
return False, None, None, None, None
336342
J_old = self.model_jac.copy()
337343
self.model_jac = dg[1:,:].T
338344
self.model_const = dg[0,:] - np.dot(self.model_jac, xopt) # shift base to xbase

dfols/params.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, n, npt, maxfun, objfun_has_noise=False):
4444
self.params["init.random_directions_make_orthogonal"] = True # although random > orthogonal, avoid for init
4545
# Interpolation
4646
self.params["interpolation.precondition"] = True
47+
self.params["interpolation.throw_error_on_nans"] = False # throw numpy.linalg.LinAlgError if interpolating to nan data?
4748
# Logging
4849
self.params["logging.n_to_print_whole_x_vector"] = 6
4950
self.params["logging.save_diagnostic_info"] = False
@@ -142,6 +143,8 @@ def param_type(self, key, npt):
142143
type_str, nonetype_ok, lower, upper = 'bool', False, None, None
143144
elif key == "interpolation.precondition":
144145
type_str, nonetype_ok, lower, upper = 'bool', False, None, None
146+
elif key == "interpolation.throw_error_on_nans":
147+
type_str, nonetype_ok, lower, upper = 'bool', False, None, None
145148
elif key == "logging.n_to_print_whole_x_vector":
146149
type_str, nonetype_ok, lower, upper = 'int', False, 0, None
147150
elif key == "logging.save_diagnostic_info":

dfols/solver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_f
247247
min_sing_val=params("growing.full_rank.min_sing_val"),
248248
sing_val_frac=params("growing.full_rank.svd_scale_factor"),
249249
max_jac_cond=params("growing.full_rank.svd_max_jac_cond"),
250-
get_chg_J=params("restarts.use_restarts") and params("restarts.auto_detect"))
250+
get_chg_J=params("restarts.use_restarts") and params("restarts.auto_detect"),
251+
throw_error_on_nans=params("interpolation.throw_error_on_nans"))
251252
if not interp_ok:
252253
if params("restarts.use_restarts") and params("restarts.use_soft_restarts"):
253254
number_of_samples = max(nsamples(control.delta, control.rho, current_iter, nruns_so_far), 1)

dfols/tests/test_solver.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def rosenbrock_jacobian(x):
4141
return np.array([[-20.0*x[0], 10.0], [-1.0, 0.0]])
4242

4343

44+
class TestNans(unittest.TestCase):
45+
# Generic objective that only returns NaNs (like optclim code)
46+
# Verify get a sensible termination
47+
def runTest(self):
48+
x0 = np.array([-1.2, 1.0])
49+
r_error = lambda x: np.array([np.nan, np.nan, np.nan])
50+
# First attempt: exit gracefully
51+
soln = dfols.solve(r_error, x0)
52+
self.assertEqual(soln.flag, soln.EXIT_LINALG_ERROR, "Wrong error message")
53+
# Second attempt: throw error when trying to interpolate
54+
with self.assertRaises(np.linalg.LinAlgError):
55+
soln = dfols.solve(r_error, x0, user_params={"interpolation.throw_error_on_nans": True})
56+
57+
4458
class TestRosenbrockGeneric(unittest.TestCase):
4559
# Minimise the (2d) Rosenbrock function
4660
def runTest(self):

0 commit comments

Comments
 (0)