1+ import logging
2+
13from collections .abc import Callable
24from dataclasses import dataclass , field
35
810from pytensor .graph import Apply , Op
911from scipy .optimize import minimize
1012
13+ logger = logging .getLogger (__name__ )
14+
1115
1216@dataclass (slots = True )
1317class LBFGSHistory :
@@ -21,6 +25,7 @@ def __post_init__(self):
2125
2226@dataclass (slots = True )
2327class LBFGSHistoryManager :
28+ fn : Callable [[NDArray [np .float64 ]], np .float64 ]
2429 grad_fn : Callable [[NDArray [np .float64 ]], NDArray [np .float64 ]]
2530 x0 : NDArray [np .float64 ]
2631 maxiter : int
@@ -32,8 +37,9 @@ def __post_init__(self) -> None:
3237 self .x_history = np .empty ((self .maxiter + 1 , self .x0 .shape [0 ]), dtype = np .float64 )
3338 self .g_history = np .empty ((self .maxiter + 1 , self .x0 .shape [0 ]), dtype = np .float64 )
3439
40+ value = self .fn (self .x0 )
3541 grad = self .grad_fn (self .x0 )
36- if not np .all (np .isfinite (grad )):
42+ if np .all (np .isfinite (grad )) and np . isfinite ( value ):
3743 self .x_history [0 ] = self .x0
3844 self .g_history [0 ] = grad
3945 self .count = 1
@@ -47,11 +53,16 @@ def get_history(self) -> LBFGSHistory:
4753 return LBFGSHistory (x = self .x_history [: self .count ], g = self .g_history [: self .count ])
4854
4955 def __call__ (self , x : NDArray [np .float64 ]) -> None :
56+ value = self .fn (x )
5057 grad = self .grad_fn (x )
51- if np .all (np .isfinite (grad )) and self .count < self .maxiter + 1 :
58+ if np .all (np .isfinite (grad )) and np . isfinite ( value ) and self .count < self .maxiter + 1 :
5259 self .add_entry (x , grad )
5360
5461
class LBFGSInitFailed(Exception):
    """Raised when the L-BFGS optimisation fails to initialise."""
64+
65+
5566class LBFGSOp (Op ):
5667 def __init__ (self , fn , grad_fn , maxcor , maxiter = 1000 , ftol = 1e-5 , gtol = 1e-8 , maxls = 1000 ):
5768 self .fn = fn
@@ -66,15 +77,18 @@ def make_node(self, x0):
6677 x0 = pt .as_tensor_variable (x0 )
6778 x_history = pt .dmatrix ()
6879 g_history = pt .dmatrix ()
69- return Apply (self , [x0 ], [x_history , g_history ])
80+ status = pt .iscalar ()
81+ return Apply (self , [x0 ], [x_history , g_history , status ])
7082
7183 def perform (self , node , inputs , outputs ):
7284 x0 = inputs [0 ]
7385 x0 = np .array (x0 , dtype = np .float64 )
7486
75- history_manager = LBFGSHistoryManager (grad_fn = self .grad_fn , x0 = x0 , maxiter = self .maxiter )
87+ history_manager = LBFGSHistoryManager (
88+ fn = self .fn , grad_fn = self .grad_fn , x0 = x0 , maxiter = self .maxiter
89+ )
7690
77- minimize (
91+ result = minimize (
7892 self .fn ,
7993 x0 ,
8094 method = "L-BFGS-B" ,
@@ -91,5 +105,19 @@ def perform(self, node, inputs, outputs):
91105
92106 # TODO: return the status of the lbfgs optimisation to handle the case where the optimisation fails. More details in the _single_pathfinder function.
93107
108+ if result .status == 1 :
109+ logger .info ("LBFGS maximum number of iterations reached. Consider increasing maxiter." )
110+ elif result .status == 2 :
111+ if (result .nit <= 1 ) or (history_manager .count <= 1 ):
112+ logger .info (
113+ "LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
114+ )
115+ raise LBFGSInitFailed ("LBFGS failed to initialise" )
116+ elif result .fun == np .inf :
117+ logger .info (
118+ "LBFGS diverged to infinity. The model might be degenerate or requires reparameterisation."
119+ )
120+
94121 outputs [0 ][0 ] = history_manager .get_history ().x
95122 outputs [1 ][0 ] = history_manager .get_history ().g
123+ outputs [2 ][0 ] = result .status
0 commit comments