'# BFGS optimizer
The BFGS method is a quasi-Newton algorithm for solving unconstrained nonlinear
optimization problems. A BFGS iteration entails computing a line search
direction based on the gradient and Hessian approximation, finding a new point
in the line search direction that satisfies the Wolfe conditions, and updating
the Hessian approximation at the new point. This implementation is based on
BFGS as described in Nocedal, Jorge; Wright, Stephen J. (2006), Numerical
Optimization (2nd ed).
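Concretely, writing s_k = x_{k+1} - x_k, y_k = grad f(x_{k+1}) - grad f(x_k), and
rho_k = 1 / (y_k^T s_k), the inverse Hessian approximation H_k is updated as
H_{k+1} = (I - rho_k s_k y_k^T) H_k (I - rho_k y_k s_k^T) + rho_k s_k s_k^T,
which is the update applied to `Href` in `bfgs_minimize` below.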
'This example demonstrates Dex's ability to do fast, stateful loops with a
statically unknown number of iterations. See `benchmarks/bfgs.py` for a
comparison with Jaxopt BFGS on a multiclass logistic regression problem.

def outer_product(x:n=>Float, y:m=>Float) -> (n=>m=>Float) given (n|Ix, m|Ix) =
  for i:n. for j:m. x[i]* y[j]

def zoom(
    f_line: (Float)->Float,
    a_lo_init:Float,
    a_hi_init:Float,
    c1:Float,
    c2:Float
@@ -40,7 +40,7 @@ def zoom(
      else
        f_ai = f_line a_i
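        -- If a_i fails the sufficient-decrease (Armijo) test, or is no better than
        -- the lower end of the bracket, the step is too long: shrink the bracket
        -- from above.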
        if f_ai > (f0 + c1 * a_i * g0) || f_ai >= f_line a_lo
          then
            a_hi_ref := a_i
            Continue
          else
@@ -81,7 +81,7 @@ def zoom_line_search(
      else
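        -- If the slope has become nonnegative, a step satisfying the Wolfe
        -- conditions is bracketed between a_last and a_i, so zoom in on that
        -- interval; otherwise keep growing the step towards a_max.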
        if g_i >= 0.
          then Done (zoom f a_i a_last c1 c2)
          else
            a_ref_last := a_i
            a_ref := 0.5 * (a_i + a_max)
            Continue
@@ -91,7 +91,7 @@ def backtracking_line_search(
    f: (Float)->Float
    ) -> Float =
  -- Algorithm 3.1 in Nocedal and Wright (2006).
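  -- Backtracking: start from a unit step and repeatedly shrink it by rho until
  -- a sufficient-decrease (Armijo) condition based on f_0 and g_0 is met.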
  a_init = 1.
  f_0 = f 0.
  g_0 = grad f 0.
  rho = 0.5
@@ -106,23 +106,23 @@ def backtracking_line_search(
      a_ref := a_i * rho
      Continue

struct BFGSresults(n|Ix) =
  fval : Float
  x_opt: (n=>Float)
  error: Float
  num_iter: Nat

def bfgs_minimize(
    f: (n=>Float)->Float, --Objective function.
    x0: n=>Float, --Initial point.
    H0: n=>n=>Float, --Initial inverse Hessian approximation.
    linesearch: ((Float)->Float)->Float, --Line search that returns a step size.
    tol: Float, --Convergence tolerance (of the gradient L2 norm).
    maxiter: Nat --Maximum number of BFGS iterations.
    ) -> BFGSresults n given (n|Ix) =
  -- Algorithm 6.1 in Nocedal and Wright (2006).

  xref <- with_state x0
  Href <- with_state H0
  gref <- with_state (grad f x0)
@@ -137,7 +137,7 @@ def bfgs_minimize(
    H = get Href
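    -- Quasi-Newton step: precondition the negative gradient with the current
    -- inverse Hessian approximation to get the search direction.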
    search_direction = -H**.g
    f_line = \s:Float. f (x + s .* search_direction)
    step_size = linesearch f_line
    x_diff = step_size .* search_direction
    x_next = x + x_diff
    g_next = grad f x_next
@@ -150,7 +150,7 @@ def bfgs_minimize(
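    -- BFGS inverse Hessian update: H <- (I - rho s y^T) H (I - rho y s^T) + rho s s^T,
    -- with s = x_diff, y = grad_diff, and rho = 1 / (y^T s).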
    rho = 1. / rho_inv
    y = (eye - rho .* outer_product x_diff grad_diff)
    Href := y ** H ** (transpose y) + rho .* outer_product x_diff x_diff

    xref := x_next
    gref := g_next
    Continue
@@ -162,14 +162,14 @@ def rosenbrock(coord:(Fin 2)=>Float) -> Float =
  y = coord[1@_]
  pow (1 - x) 2 + 100 * pow (y - x * x) 2

%time
bfgs_minimize rosenbrock [10., 10.] eye (\f. backtracking_line_search 15 f) 0.001 100
> BFGSresults(8.668621e-13, [0.9999993, 0.9999985], 2.538457e-05, 41)
>
> Compile time: 675.707 ms
> Run time: 220.998 us

@noinline
def multiclass_logistic_loss(xs: n=>d=>Float, ys: n=>m, w: (d, m)=>Float) -> Float given (n|Ix, d|Ix, m|Ix) =
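  -- Reshape the parameter table indexed by (feature, class) pairs into a d-by-m
  -- matrix; logits[i, j] is then the unnormalized score of class j for example i.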
  w_arr = for i:d. for j:m. w[(i, j)]
  logits = xs ** w_arr
@@ -185,10 +185,10 @@ def multiclass_logreg(
    tol:Float) -> Float given (n|Ix, d|Ix, m|Ix)=
  ob_fun = \v. multiclass_logistic_loss xs ys v
  w0 = zero
  res = bfgs_minimize ob_fun w0 eye (\f. zoom_line_search maxls f) tol maxiter
  res.fval

-- Define a version of `multiclass_logreg` with Int instead of Nat labels, callable from Python
-- (see benchmarks/bfgs.py).
def multiclass_logreg_int(
    xs:(Fin n)=>(Fin d)=>Float,
@@ -199,3 +199,20 @@ def multiclass_logreg_int(
    tol:Float) -> Float given (n, d) =
  y_ind = Fin (i32_to_n num_classes)
  multiclass_logreg xs (for i. i32_to_n ys[i] @ y_ind) (i32_to_n maxiter) (i32_to_n maxls) tol

n_samples = 100
n_features = 20
n_classes = 5
maxiter = 30
maxls = 15
tol = 0.001
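-- Synthetic benchmark data: standard-normal features and randomly drawn
-- class labels.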
xs = rand_mat n_samples n_features randn (new_key 0)
ys : (Fin n_samples) => (Fin n_classes) = rand_vec n_samples rand_idx (new_key 1)

%time
multiclass_logreg xs ys maxiter maxls tol
> 1.609437
>
> Compile time: 3.473 s
> Run time: 195.542 us