Skip to content

Commit ea802fc

Browse files
committed
minor: improve error handling in Pathfinder VI
- Add LBFGSInitFailed exception for failed LBFGS initialisation - Skip failed paths in multipath_pathfinder and track the number of failures - Handle NaN values from Cholesky decomposition in bfgs_sample - Add checks for numerical stability in matrix operations Slight performance improvements: - Set allow_gc=False in scan ops - Use FAST_RUN mode consistently
1 parent f1a54c6 commit ea802fc

File tree

3 files changed

+95
-77
lines changed

3 files changed

+95
-77
lines changed

pymc_experimental/inference/pathfinder/importance_sampling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,14 @@ def psir(
8282
logger.warning(
8383
f"Pareto k value ({pareto_k:.2f}) is between 0.5 and 0.7 which indicates an imperfect approximation however still useful."
8484
)
85-
logger.info("Consider increasing ftol, gtol or maxcor.")
85+
logger.info("Consider increasing ftol, gtol, maxcor or num_paths.")
8686
elif pareto_k >= 0.7:
8787
logger.warning(
8888
f"Pareto k value ({pareto_k:.2f}) exceeds 0.7 which indicates a bad approximation."
8989
)
90-
logger.info("Consider increasing ftol, gtol, maxcor or reparametrising the model.")
90+
logger.info(
91+
"Consider increasing ftol, gtol, maxcor, num_paths or reparametrising the model."
92+
)
9193
else:
9294
logger.warning(
9395
f"Received an invalid Pareto k value of {pareto_k:.2f} which indicates the model is seriously flawed."

pymc_experimental/inference/pathfinder/lbfgs.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from collections.abc import Callable
24
from dataclasses import dataclass, field
35

@@ -8,6 +10,8 @@
810
from pytensor.graph import Apply, Op
911
from scipy.optimize import minimize
1012

13+
logger = logging.getLogger(__name__)
14+
1115

1216
@dataclass(slots=True)
1317
class LBFGSHistory:
@@ -21,6 +25,7 @@ def __post_init__(self):
2125

2226
@dataclass(slots=True)
2327
class LBFGSHistoryManager:
28+
fn: Callable[[NDArray[np.float64]], np.float64]
2429
grad_fn: Callable[[NDArray[np.float64]], NDArray[np.float64]]
2530
x0: NDArray[np.float64]
2631
maxiter: int
@@ -32,8 +37,9 @@ def __post_init__(self) -> None:
3237
self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
3338
self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
3439

40+
value = self.fn(self.x0)
3541
grad = self.grad_fn(self.x0)
36-
if not np.all(np.isfinite(grad)):
42+
if np.all(np.isfinite(grad)) and np.isfinite(value):
3743
self.x_history[0] = self.x0
3844
self.g_history[0] = grad
3945
self.count = 1
@@ -47,11 +53,16 @@ def get_history(self) -> LBFGSHistory:
4753
return LBFGSHistory(x=self.x_history[: self.count], g=self.g_history[: self.count])
4854

4955
def __call__(self, x: NDArray[np.float64]) -> None:
56+
value = self.fn(x)
5057
grad = self.grad_fn(x)
51-
if np.all(np.isfinite(grad)) and self.count < self.maxiter + 1:
58+
if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
5259
self.add_entry(x, grad)
5360

5461

62+
class LBFGSInitFailed(Exception):
63+
pass
64+
65+
5566
class LBFGSOp(Op):
5667
def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
5768
self.fn = fn
@@ -66,15 +77,18 @@ def make_node(self, x0):
6677
x0 = pt.as_tensor_variable(x0)
6778
x_history = pt.dmatrix()
6879
g_history = pt.dmatrix()
69-
return Apply(self, [x0], [x_history, g_history])
80+
status = pt.iscalar()
81+
return Apply(self, [x0], [x_history, g_history, status])
7082

7183
def perform(self, node, inputs, outputs):
7284
x0 = inputs[0]
7385
x0 = np.array(x0, dtype=np.float64)
7486

75-
history_manager = LBFGSHistoryManager(grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter)
87+
history_manager = LBFGSHistoryManager(
88+
fn=self.fn, grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter
89+
)
7690

77-
minimize(
91+
result = minimize(
7892
self.fn,
7993
x0,
8094
method="L-BFGS-B",
@@ -91,5 +105,19 @@ def perform(self, node, inputs, outputs):
91105

92106
# TODO: return the status of the lbfgs optimisation to handle the case where the optimisation fails. More details in the _single_pathfinder function.
93107

108+
if result.status == 1:
109+
logger.info("LBFGS maximum number of iterations reached. Consider increasing maxiter.")
110+
elif result.status == 2:
111+
if (result.nit <= 1) or (history_manager.count <= 1):
112+
logger.info(
113+
"LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
114+
)
115+
raise LBFGSInitFailed("LBFGS failed to initialise")
116+
elif result.fun == np.inf:
117+
logger.info(
118+
"LBFGS diverged to infinity. The model might be degenerate or requires reparameterisation."
119+
)
120+
94121
outputs[0][0] = history_manager.get_history().x
95122
outputs[1][0] = history_manager.get_history().g
123+
outputs[2][0] = result.status

0 commit comments

Comments
 (0)