Commit 65dce60 (parent 808f6e9)

Extend the BFGS example to test the larger logistic regression loss.

Also:
- Tag the loss function noinline, which speeds up compilation ~2x,
- Use %time rather than %bench in the example, and
- Fix whitespace.

2 files changed: +42 -25 lines

benchmarks/bfgs.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@ def multiclass_logreg_jaxopt(X, y):
   fun = jaxopt.objective.multiclass_logreg
   init = jnp.zeros((X.shape[1], FLAGS.n_classes))
   bfgs = jaxopt.BFGS(
-      fun=fun, 
+      fun=fun,
       linesearch='zoom',
       maxiter=FLAGS.maxiter,
       maxls=FLAGS.maxls,
@@ -59,8 +59,8 @@ def main(argv):
 
   start_time = time.time()
   dex_value = dex_bfgs(
-      jnp.array(X), 
-      jnp.array(y), 
+      jnp.array(X),
+      jnp.array(y),
       FLAGS.n_classes,
       FLAGS.maxiter,
       FLAGS.maxls,
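
For context, the jaxopt reference that this benchmark compares against can be run standalone roughly as follows. This is a sketch rather than the benchmark script itself: the sizes stand in for the FLAGS defaults, and it assumes jaxopt's `multiclass_logreg` objective has the signature `fun(params, data)` with `data = (X, y)`.

# Standalone sketch of the jaxopt side of the comparison (not the
# benchmark script; sizes below are stand-ins for the FLAGS values).
import jax.numpy as jnp
import jaxopt
import numpy as np

n_samples, n_features, n_classes = 100, 20, 5

rng = np.random.RandomState(0)
X = jnp.array(rng.randn(n_samples, n_features))
y = jnp.array(rng.randint(n_classes, size=n_samples))

fun = jaxopt.objective.multiclass_logreg  # assumed: fun(params, data), data = (X, y)
init = jnp.zeros((X.shape[1], n_classes))
bfgs = jaxopt.BFGS(
    fun=fun,
    linesearch='zoom',
    maxiter=30,
    maxls=15,
    tol=0.001,
)
params, state = bfgs.run(init, (X, y))
print(state.value)  # final loss, comparable to the Dex result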

examples/bfgs.dx

Lines changed: 39 additions & 22 deletions
@@ -1,22 +1,22 @@
 '# BFGS optimizer
-The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear 
-optimization problems. A BFGS iteration entails computing a line search 
+The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
+optimization problems. A BFGS iteration entails computing a line search
 direction based on the gradient and Hessian approximation, finding a new point
 in the line search direction that satisfies the Wolfe conditions, and updating
 the Hessian approximation at the new point. This implementation is based on
-BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical 
+BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
 Optimization (2nd ed).
 
-'This example demonstrates Dex's ability to do fast, stateful loops with a 
-statically unknown number of iterations. See `benchmarks/bfgs.py` for a 
+'This example demonstrates Dex's ability to do fast, stateful loops with a
+statically unknown number of iterations. See `benchmarks/bfgs.py` for a
 comparison with Jaxopt BFGS on a multiclass logistic regression problem.
 
 def outer_product(x:n=>Float, y:m=>Float) -> (n=>m=>Float) given (n|Ix, m|Ix) =
   for i:n. for j:m. x[i]* y[j]
 
 def zoom(
     f_line: (Float)->Float,
-    a_lo_init:Float, 
+    a_lo_init:Float,
     a_hi_init:Float,
     c1:Float,
     c2:Float
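
To make the iteration described in the prose above concrete: a minimal NumPy sketch of one BFGS step (Algorithm 6.1 in Nocedal and Wright), assuming a line search `line_search` that maps a one-dimensional restriction of the objective to a step size. It mirrors the Dex implementation below without replacing it.

import numpy as np

def bfgs_step(f, grad_f, x, H, g, line_search):
    # Search direction from the gradient and inverse Hessian approximation.
    p = -H @ g
    # Step size along p, ideally satisfying the Wolfe conditions.
    a = line_search(lambda s: f(x + s * p))
    x_next = x + a * p
    g_next = grad_f(x_next)
    # Curvature pair and the BFGS inverse Hessian update (eq. 6.17).
    s, y = x_next - x, g_next - g
    rho = 1.0 / (y @ s)
    V = np.eye(len(x)) - rho * np.outer(s, y)
    H_next = V @ H @ V.T + rho * np.outer(s, s)
    return x_next, H_next, g_next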
@@ -40,7 +40,7 @@ def zoom(
       else
         f_ai = f_line a_i
         if f_ai > (f0 + c1 * a_i * g0) || f_ai >= f_line a_lo
-          then 
+          then
             a_hi_ref := a_i
             Continue
           else
@@ -81,7 +81,7 @@ def zoom_line_search(
       else
         if g_i >= 0.
           then Done (zoom f a_i a_last c1 c2)
-          else 
+          else
             a_ref_last := a_i
             a_ref := 0.5 * (a_i + a_max)
             Continue
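
The `zoom` and `zoom_line_search` code above searches for a step satisfying the strong Wolfe conditions. As a reminder, with phi(a) = f(x + a*p) and 0 < c1 < c2 < 1, the two conditions are, in Python-style pseudocode (names assumed for illustration):

def sufficient_decrease(phi, phi0, dphi0, a, c1):
    # Armijo condition: the step must decrease f enough.
    return phi(a) <= phi0 + c1 * a * dphi0

def strong_curvature(dphi, dphi0, a, c2):
    # The slope at the new point must have shrunk in magnitude.
    return abs(dphi(a)) <= c2 * abs(dphi0)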
@@ -91,7 +91,7 @@ def backtracking_line_search(
     f: (Float)->Float
   ) -> Float =
   -- Algorithm 3.1 in Nocedal and Wright (2006).
-  a_init = 1. 
+  a_init = 1.
   f_0 = f 0.
   g_0 = grad f 0.
   rho = 0.5
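
Algorithm 3.1 referenced here is plain backtracking: start from a unit step and shrink it by rho until sufficient decrease holds. A minimal Python sketch, assuming `f_line` maps a step size to the objective value and `g0` is its slope at zero (the constant c1 is an assumption; it is not visible in this hunk):

def backtracking(f_line, g0, maxls, c1=1e-4, rho=0.5):
    # Shrink the step until the Armijo condition is satisfied,
    # giving up after maxls trials.
    f0 = f_line(0.0)
    a = 1.0
    for _ in range(maxls):
        if f_line(a) <= f0 + c1 * a * g0:
            break
        a *= rho
    return a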
@@ -106,23 +106,23 @@ def backtracking_line_search(
       a_ref := a_i * rho
       Continue
 
-struct BFGSresults(n|Ix) = 
+struct BFGSresults(n|Ix) =
   fval : Float
   x_opt: (n=>Float)
-  error: Float 
+  error: Float
   num_iter: Nat
- 
+
 def bfgs_minimize(
     f: (n=>Float)->Float, --Objective function.
     x0: n=>Float, --Initial point.
     H0: n=>n=>Float, --Initial inverse Hessian approximation.
     linesearch: ((Float)->Float)->Float, --Line search that returns a step size.
     tol: Float, --Convergence tolerance (of the gradient L2 norm).
     maxiter: Nat --Maximum number of BFGS iterations.
-  ) -> BFGSresults n given (n|Ix) = 
+  ) -> BFGSresults n given (n|Ix) =
   -- Algorithm 6.1 in Nocedal and Wright (2006).
 
-  xref <- with_state x0 
+  xref <- with_state x0
   Href <- with_state H0
   gref <- with_state (grad f x0)
 
@@ -137,7 +137,7 @@ def bfgs_minimize(
     H = get Href
     search_direction = -H**.g
     f_line = \s:Float. f (x + s .* search_direction)
-    step_size = linesearch f_line 
+    step_size = linesearch f_line
     x_diff = step_size .* search_direction
     x_next = x + x_diff
     g_next = grad f x_next
@@ -150,7 +150,7 @@ def bfgs_minimize(
     rho = 1. / rho_inv
     y = (eye - rho .* outer_product x_diff grad_diff)
     Href := y ** H ** (transpose y) + rho .* outer_product x_diff x_diff
- 
+
     xref := x_next
     gref := g_next
     Continue
@@ -162,14 +162,14 @@ def rosenbrock(coord:(Fin 2)=>Float) -> Float =
   y = coord[1@_]
   pow (1 - x) 2 + 100 * pow (y - x * x) 2
 
-%bench "rosenbrock"
+%time
 bfgs_minimize rosenbrock [10., 10.] eye (\f. backtracking_line_search 15 f) 0.001 100
 > BFGSresults(8.668621e-13, [0.9999993, 0.9999985], 2.538457e-05, 41)
 >
-> rosenbrock
-> Compile time: 618.962 ms
-> Run time: 57.489 us (based on 1 run)
+> Compile time: 675.707 ms
+> Run time: 220.998 us
 
+@noinline
 def multiclass_logistic_loss(xs: n=>d=>Float, ys: n=>m, w: (d, m)=>Float) -> Float given (n|Ix, d|Ix, m|Ix) =
   w_arr = for i:d. for j:m. w[(i, j)]
   logits = xs ** w_arr
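
The `@noinline` tag added above (the compile-time win mentioned in the commit message) applies to this softmax cross-entropy loss. For reference, a rough JAX equivalent of the same loss; the mean reduction is an assumption, since the hunk cuts off before the Dex function's reduction:

import jax.numpy as jnp
from jax.nn import log_softmax

def multiclass_logistic_loss(xs, ys, w):
    # xs: (n, d) features, ys: (n,) integer labels, w: (d, m) weights.
    logits = xs @ w
    log_probs = log_softmax(logits, axis=-1)
    # Assumed reduction: mean negative log-likelihood of the true class.
    return -jnp.mean(log_probs[jnp.arange(xs.shape[0]), ys])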
@@ -185,10 +185,10 @@ def multiclass_logreg(
     tol:Float) -> Float given (n|Ix, d|Ix, m|Ix)=
   ob_fun = \v. multiclass_logistic_loss xs ys v
   w0 = zero
-  res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter 
+  res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
   res.fval
 
--- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python 
+-- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
 -- (see benchmarks/bfgs.py).
 def multiclass_logreg_int(
     xs:(Fin n)=>(Fin d)=>Float,
@@ -199,3 +199,20 @@ def multiclass_logreg_int(
     tol:Float) -> Float given (n, d) =
   y_ind = Fin (i32_to_n num_classes)
   multiclass_logreg xs (for i. i32_to_n ys[i] @ y_ind) (i32_to_n maxiter) (i32_to_n maxls) tol
+
+n_samples = 100
+n_features = 20
+n_classes = 5
+maxiter = 30
+maxls = 15
+tol = 0.001
+
+xs = rand_mat n_samples n_features randn (new_key 0)
+ys : (Fin n_samples) => (Fin n_classes) = rand_vec n_samples rand_idx (new_key 1)
+
+%time
+multiclass_logreg xs ys maxiter maxls tol
+> 1.609437
+>
+> Compile time: 3.473 s
+> Run time: 195.542 us
