from typing import Callable, NamedTuple

import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from scipy.optimize import minimize
class LBFGSHistory(NamedTuple):
    """Trimmed history of L-BFGS iterates.

    Attributes
    ----------
    x : np.ndarray
        Accepted positions, shape ``(n_entries, dim)``.
    g : np.ndarray
        Gradients at those positions, shape ``(n_entries, dim)``.
    """

    x: np.ndarray
    g: np.ndarray
class LBFGSHistoryManager :
16
- def __init__ (self , fn : Callable , grad_fn : Callable , x0 : np .ndarray , maxiter : int ):
17
+ def __init__ (self , grad_fn : Callable , x0 : np .ndarray , maxiter : int ):
17
18
dim = x0 .shape [0 ]
18
19
maxiter_add_one = maxiter + 1
19
20
# Pre-allocate arrays to save memory and improve speed
20
21
self .x_history = np .empty ((maxiter_add_one , dim ), dtype = np .float64 )
21
- self .f_history = np .empty (maxiter_add_one , dtype = np .float64 )
22
22
self .g_history = np .empty ((maxiter_add_one , dim ), dtype = np .float64 )
23
23
self .count = 0
24
- self .fn = fn
25
24
self .grad_fn = grad_fn
26
- self .add_entry (x0 , fn ( x0 ), grad_fn (x0 ))
25
+ self .add_entry (x0 , grad_fn (x0 ))
27
26
28
- def add_entry (self , x , f , g = None ):
27
+ def add_entry (self , x , g ):
29
28
self .x_history [self .count ] = x
30
- self .f_history [self .count ] = f
31
- if self .g_history is not None and g is not None :
32
- self .g_history [self .count ] = g
29
+ self .g_history [self .count ] = g
33
30
self .count += 1
34
31
35
32
def get_history (self ):
36
- # Return trimmed arrays up to the number of entries actually used
33
+ # Return trimmed arrays up to L << L^max
37
34
x = self .x_history [: self .count ]
38
- f = self .f_history [: self .count ]
39
- g = self .g_history [: self .count ] if self .g_history is not None else None
35
+ g = self .g_history [: self .count ]
40
36
return LBFGSHistory (
41
37
x = x ,
42
- f = f ,
43
38
g = g ,
44
39
)
45
40
46
41
def __call__ (self , x ):
47
- self .add_entry (x , self .fn (x ), self .grad_fn (x ))
42
+ grad = self .grad_fn (x )
43
+ if np .all (np .isfinite (grad )):
44
+ self .add_entry (x , grad )
48
45
49
46
50
47
def lbfgs (
@@ -62,7 +59,6 @@ def callback(xk):
62
59
lbfgs_history_manager (xk )
63
60
64
61
lbfgs_history_manager = LBFGSHistoryManager (
65
- fn = fn ,
66
62
grad_fn = grad_fn ,
67
63
x0 = x0 ,
68
64
maxiter = maxiter ,
@@ -89,4 +85,58 @@ def callback(xk):
89
85
callback = callback ,
90
86
** lbfgs_kwargs ,
91
87
)
92
- return lbfgs_history_manager .get_history ()
88
+ lbfgs_history = lbfgs_history_manager .get_history ()
89
+ return lbfgs_history .x , lbfgs_history .g
90
+
91
+
92
class LBFGSOp(Op):
    """PyTensor Op wrapping scipy's L-BFGS-B minimizer.

    Given a starting point ``x0``, minimizes ``fn`` (with gradient
    ``grad_fn``) and outputs the history of accepted positions and their
    gradients as two float64 matrices of shape ``(n_entries, dim)``.
    """

    def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
        self.fn = fn  # objective: (dim,) array -> scalar
        self.grad_fn = grad_fn  # gradient: (dim,) array -> (dim,) array
        self.maxcor = maxcor  # number of correction pairs (L-BFGS memory)
        self.maxiter = maxiter
        self.ftol = ftol
        self.gtol = gtol
        self.maxls = maxls  # max line-search steps per iteration

    def make_node(self, x0):
        """Build the graph node: one vector input, two matrix outputs."""
        x0 = pt.as_tensor_variable(x0)
        x_history = pt.dmatrix()
        g_history = pt.dmatrix()
        return Apply(self, [x0], [x_history, g_history])

    def perform(self, node, inputs, outputs):
        """Run L-BFGS-B and store the iterate/gradient history in outputs."""
        x0 = np.asarray(inputs[0], dtype=np.float64)

        history_manager = LBFGSHistoryManager(grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter)

        minimize(
            self.fn,
            x0,
            method="L-BFGS-B",
            jac=self.grad_fn,
            callback=history_manager,
            options={
                "maxcor": self.maxcor,
                "maxiter": self.maxiter,
                "ftol": self.ftol,
                "gtol": self.gtol,
                "maxls": self.maxls,
            },
        )

        # Build the trimmed history once instead of twice.
        history = history_manager.get_history()
        outputs[0][0] = history.x
        outputs[1][0] = history.g
0 commit comments