1
+ import logging
2
+
1
3
from collections .abc import Callable
2
4
from dataclasses import dataclass , field
3
5
8
10
from pytensor .graph import Apply , Op
9
11
from scipy .optimize import minimize
10
12
13
+ logger = logging .getLogger (__name__ )
14
+
11
15
12
16
@dataclass (slots = True )
13
17
class LBFGSHistory :
@@ -21,6 +25,7 @@ def __post_init__(self):
21
25
22
26
@dataclass (slots = True )
23
27
class LBFGSHistoryManager :
28
+ fn : Callable [[NDArray [np .float64 ]], np .float64 ]
24
29
grad_fn : Callable [[NDArray [np .float64 ]], NDArray [np .float64 ]]
25
30
x0 : NDArray [np .float64 ]
26
31
maxiter : int
@@ -32,8 +37,9 @@ def __post_init__(self) -> None:
32
37
self .x_history = np .empty ((self .maxiter + 1 , self .x0 .shape [0 ]), dtype = np .float64 )
33
38
self .g_history = np .empty ((self .maxiter + 1 , self .x0 .shape [0 ]), dtype = np .float64 )
34
39
40
+ value = self .fn (self .x0 )
35
41
grad = self .grad_fn (self .x0 )
36
- if not np .all (np .isfinite (grad )):
42
+ if np .all (np .isfinite (grad )) and np . isfinite ( value ):
37
43
self .x_history [0 ] = self .x0
38
44
self .g_history [0 ] = grad
39
45
self .count = 1
@@ -47,11 +53,16 @@ def get_history(self) -> LBFGSHistory:
47
53
return LBFGSHistory (x = self .x_history [: self .count ], g = self .g_history [: self .count ])
48
54
49
55
def __call__(self, x: NDArray[np.float64]) -> None:
    """Per-iteration scipy callback: record (x, grad) into the history.

    The point is stored only when the objective value and every gradient
    component are finite, and the preallocated history buffers still have
    room (at most ``maxiter + 1`` entries).
    """
    value = self.fn(x)
    grad = self.grad_fn(x)
    # Guard clauses: drop the sample rather than store non-finite data
    # or overrun the preallocated history arrays.
    if self.count >= self.maxiter + 1:
        return
    if not (np.isfinite(value) and np.all(np.isfinite(grad))):
        return
    self.add_entry(x, grad)
53
60
54
61
62
class LBFGSInitFailed(Exception):
    """Raised when the L-BFGS optimisation fails to initialise.

    Signals that the optimiser made no usable progress (e.g. the model is
    degenerate or the initial jitter is too large), so callers can retry
    or abort the pathfinder run.
    """
55
66
class LBFGSOp (Op ):
56
67
def __init__ (self , fn , grad_fn , maxcor , maxiter = 1000 , ftol = 1e-5 , gtol = 1e-8 , maxls = 1000 ):
57
68
self .fn = fn
@@ -66,15 +77,18 @@ def make_node(self, x0):
66
77
x0 = pt .as_tensor_variable (x0 )
67
78
x_history = pt .dmatrix ()
68
79
g_history = pt .dmatrix ()
69
- return Apply (self , [x0 ], [x_history , g_history ])
80
+ status = pt .iscalar ()
81
+ return Apply (self , [x0 ], [x_history , g_history , status ])
70
82
71
83
def perform (self , node , inputs , outputs ):
72
84
x0 = inputs [0 ]
73
85
x0 = np .array (x0 , dtype = np .float64 )
74
86
75
- history_manager = LBFGSHistoryManager (grad_fn = self .grad_fn , x0 = x0 , maxiter = self .maxiter )
87
+ history_manager = LBFGSHistoryManager (
88
+ fn = self .fn , grad_fn = self .grad_fn , x0 = x0 , maxiter = self .maxiter
89
+ )
76
90
77
- minimize (
91
+ result = minimize (
78
92
self .fn ,
79
93
x0 ,
80
94
method = "L-BFGS-B" ,
@@ -91,5 +105,19 @@ def perform(self, node, inputs, outputs):
91
105
92
106
# NOTE: the L-BFGS status code is returned as the third output so that
# _single_pathfinder can detect and handle failed optimisations.
93
107
108
+ if result .status == 1 :
109
+ logger .info ("LBFGS maximum number of iterations reached. Consider increasing maxiter." )
110
+ elif result .status == 2 :
111
+ if (result .nit <= 1 ) or (history_manager .count <= 1 ):
112
+ logger .info (
113
+ "LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
114
+ )
115
+ raise LBFGSInitFailed ("LBFGS failed to initialise" )
116
+ elif result .fun == np .inf :
117
+ logger .info (
118
+ "LBFGS diverged to infinity. The model might be degenerate or requires reparameterisation."
119
+ )
120
+
94
121
outputs [0 ][0 ] = history_manager .get_history ().x
95
122
outputs [1 ][0 ] = history_manager .get_history ().g
123
+ outputs [2 ][0 ] = result .status
0 commit comments