Skip to content

Commit 55b23e5

Browse files
committed
Added optional 'args' which can be passed to objfun
1 parent e164730 commit 55b23e5

File tree

4 files changed

+15
-13
lines changed

4 files changed

+15
-13
lines changed

dfols/controller.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def able_to_do_restart(self):
9292

9393

9494
class Controller(object):
95-
def __init__(self, objfun, x0, r0, r0_nsamples, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes):
95+
def __init__(self, objfun, args, x0, r0, r0_nsamples, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes):
9696
self.objfun = objfun
97+
self.args = args
9798
self.maxfun = maxfun
9899
self.model = Model(npt, x0, r0, xl, xu, r0_nsamples, precondition=params("interpolation.precondition"),
99100
abs_tol = params("model.abs_tol"), rel_tol = params("model.rel_tol"))
@@ -405,7 +406,7 @@ def evaluate_objective(self, x, number_of_samples, params):
405406
if not incremented_nx:
406407
self.nx += 1
407408
incremented_nx = True
408-
rvec_list[i, :], f_list[i] = eval_least_squares_objective(self.objfun, remove_scaling(x, self.scaling_changes), eval_num=self.nf, pt_num=self.nx,
409+
rvec_list[i, :], f_list[i] = eval_least_squares_objective(self.objfun, remove_scaling(x, self.scaling_changes), args=self.args, eval_num=self.nf, pt_num=self.nx,
409410
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
410411
check_for_overflow=params("general.check_objfun_for_overflow"))
411412
num_samples_run += 1

dfols/solver.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __str__(self):
9191
return output
9292

9393

94-
def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf_so_far, nx_so_far, nsamples, params,
94+
def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf_so_far, nx_so_far, nsamples, params,
9595
diagnostic_info, scaling_changes, r0_avg_old=None, r0_nsamples_old=None):
9696
# Evaluate at x0 (keep nf, nx correct and check for f < 1e-12)
9797
# The hard bit is determining what m = len(r0) should be, and allocating memory appropriately
@@ -100,7 +100,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
100100
# Evaluate the first time...
101101
nf = nf_so_far + 1
102102
nx = nx_so_far + 1
103-
r0, f0 = eval_least_squares_objective(objfun, remove_scaling(x0, scaling_changes), eval_num=nf, pt_num=nx,
103+
r0, f0 = eval_least_squares_objective(objfun, remove_scaling(x0, scaling_changes), args=args, eval_num=nf, pt_num=nx,
104104
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
105105
check_for_overflow=params("general.check_objfun_for_overflow"))
106106
m = len(r0)
@@ -121,7 +121,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
121121

122122
nf += 1
123123
# Don't increment nx for x0 - we did this earlier
124-
rvec_list[i, :], f_list[i] = eval_least_squares_objective(objfun, remove_scaling(x0, scaling_changes), eval_num=nf, pt_num=nx,
124+
rvec_list[i, :], f_list[i] = eval_least_squares_objective(objfun, remove_scaling(x0, scaling_changes), args=args, eval_num=nf, pt_num=nx,
125125
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
126126
check_for_overflow=params("general.check_objfun_for_overflow"))
127127
num_samples_run += 1
@@ -142,7 +142,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
142142
nx = nx_so_far
143143

144144
# Initialise controller
145-
control = Controller(objfun, x0, r0_avg, num_samples_run, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes)
145+
control = Controller(objfun, args, x0, r0_avg, num_samples_run, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes)
146146

147147
# Initialise interpolation set
148148
number_of_samples = max(nsamples(control.delta, control.rho, 0, nruns_so_far), 1)
@@ -800,7 +800,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
800800
return x, rvec, f, jacmin, nsamples, control.nf, control.nx, nruns_so_far, exit_info, diagnostic_info
801801

802802

803-
def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=None, nsamples=None, user_params=None,
803+
def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=None, nsamples=None, user_params=None,
804804
objfun_has_noise=False, scaling_within_bounds=False):
805805
n = len(x0)
806806

@@ -943,7 +943,7 @@ def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=No
943943
nf = 0
944944
nx = 0
945945
xmin, rmin, fmin, jacmin, nsamples_min, nf, nx, nruns, exit_info, diagnostic_info = \
946-
solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
946+
solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
947947
diagnostic_info, scaling_changes)
948948

949949
# Hard restarts loop
@@ -960,11 +960,11 @@ def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=No
960960
% (fmin, nf, rhobeg, rhoend))
961961
if params("restarts.hard.use_old_rk"):
962962
xmin2, rmin2, fmin2, jacmin2, nsamples2, nf, nx, nruns, exit_info, diagnostic_info = \
963-
solve_main(objfun, xmin, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
963+
solve_main(objfun, xmin, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
964964
diagnostic_info, scaling_changes, r0_avg_old=rmin, r0_nsamples_old=nsamples_min)
965965
else:
966966
xmin2, rmin2, fmin2, jacmin2, nsamples2, nf, nx, nruns, exit_info, diagnostic_info = \
967-
solve_main(objfun, xmin, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
967+
solve_main(objfun, xmin, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
968968
diagnostic_info, scaling_changes)
969969

970970
if fmin2 < fmin or np.isnan(fmin):

dfols/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def sumsq(x):
4444
return np.dot(x, x)
4545

4646

47-
def eval_least_squares_objective(objfun, x, verbose=True, eval_num=0, pt_num=0, full_x_thresh=6, check_for_overflow=True):
47+
def eval_least_squares_objective(objfun, x, args=(), verbose=True, eval_num=0, pt_num=0, full_x_thresh=6, check_for_overflow=True):
4848
# Evaluate least squares function
49-
fvec = objfun(x)
49+
fvec = objfun(x, *args)
5050

5151
if check_for_overflow:
5252
try:

docs/userguide.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,14 @@ The :code:`solve` function has several optional arguments which the user may pro
6565

6666
.. code-block:: python
6767
68-
dfols.solve(objfun, x0, bounds=None, npt=None, rhobeg=None,
68+
dfols.solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None,
6969
rhoend=1e-8, maxfun=None, nsamples=None,
7070
user_params=None, objfun_has_noise=False,
7171
scaling_within_bounds=False)
7272
7373
These arguments are:
7474

75+
* :code:`args` - a tuple of extra arguments passed to the objective function. This feature is new, and not yet avaiable in the PyPI version of DFO-LS; instead, use Python's built-in function :code:`lambda`.
7576
* :code:`bounds` - a tuple :code:`(lower, upper)` with the vectors :math:`a` and :math:`b` of lower and upper bounds on :math:`x` (default is :math:`a_i=-10^{20}` and :math:`b_i=10^{20}`). To set bounds for either :code:`lower` or :code:`upper`, but not both, pass a tuple :code:`(lower, None)` or :code:`(None, upper)`.
7677
* :code:`npt` - the number of interpolation points to use (default is :code:`len(x0)+1`). If using restarts, this is the number of points to use in the first run of the solver, before any restarts (and may be optionally increased via settings in :code:`user_params`).
7778
* :code:`rhobeg` - the initial value of the trust region radius (default is :math:`0.1\max(\|x_0\|_{\infty}, 1)`).

0 commit comments

Comments
 (0)