Skip to content

Commit 830217d

Browse files
committed
Bugfix in passing extra arguments
1 parent 6da5e2e commit 830217d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

dfols/solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,12 +1036,12 @@ def solve(objfun, x0, h=None, lh=None, prox_uh=None, argsf=(), argsh=(), argspro
10361036
x0 = xp.copy()
10371037

10381038
# Enforce lower & upper bounds on x0
1039-
idx = (x0 <= xl)
1039+
idx = (x0 < xl)
10401040
if np.any(idx):
10411041
warnings.warn("x0 below lower bound, adjusting", RuntimeWarning)
10421042
x0[idx] = xl[idx]
10431043

1044-
idx = (x0 >= xu)
1044+
idx = (x0 > xu)
10451045
if np.any(idx):
10461046
warnings.warn("x0 above upper bound, adjusting", RuntimeWarning)
10471047
x0[idx] = xu[idx]

dfols/trust_region.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
# Fall back to Python implementation
8080
USE_FORTRAN = False
8181

82-
from .util import dykstra, pball, pbox, sumsq, model_value
82+
from .util import dykstra, pball, pbox, sumsq, model_value, remove_scaling
8383

8484
__all__ = ['ctrsbox_sfista', 'ctrsbox_pgd', 'ctrsbox_geometry', 'trsbox', 'trsbox_geometry']
8585

@@ -135,7 +135,7 @@ def proj(d0):
135135
return p - xopt
136136

137137
# general step
138-
model_value_best = model_value(g, H, d, xopt, h, *argsh, scaling_changes)
138+
model_value_best = model_value(g, H, d, xopt, h, argsh, scaling_changes)
139139
d_best = d.copy()
140140
for k in range(MAX_LOOP_ITERS):
141141
prev_d = d.copy()
@@ -148,7 +148,7 @@ def proj(d0):
148148
# SOLVED: (previously) make sfista decrease in each iteration (might have d = 0, criticality measure=0)
149149
# if model_value(g, H, d, xopt, h, *argsh) > model_value(g, H, prev_d, xopt, h, *argsh):
150150
# d = prev_d
151-
new_model_value = model_value(g, H, d, xopt, h, *argsh, scaling_changes)
151+
new_model_value = model_value(g, H, d, xopt, h, argsh, scaling_changes)
152152
if new_model_value < model_value_best:
153153
d_best = d.copy()
154154
model_value_best = new_model_value

dfols/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def eval_least_squares_with_regularisation(objfun, x, h=None, argsf=(), argsh=()
6464

6565
# objective = least-squares + regularisation
6666
obj = f
67-
if h != None:
67+
if h is not None:
6868
# Evaluate regularisation term
6969
hvalue = h(x, *argsh)
7070
obj = f + hvalue
@@ -83,7 +83,7 @@ def model_value(g, H, s, xopt=(), h=None,argsh=(), scaling_changes=None):
8383
assert g.shape == s.shape, "g and s have incompatible sizes"
8484
Hs = H.dot(s)
8585
rtn = np.dot(s, g + 0.5*Hs)
86-
if h != None:
86+
if h is not None:
8787
hvalue = h(remove_scaling(xopt+s, scaling_changes), *argsh)
8888
rtn += hvalue
8989
return rtn

0 commit comments

Comments
 (0)