Skip to content

Commit a77f2c8

Browse files
committed
Progress bar and other minor changes
Major: - Added progress bar support. Minor - Added exception for non-finite log prob values - Removed . - Allowed maxcor argument to be None, and dynamically set based on the number of model parameters. - Improved logging to inform users about failed paths and lbfgs initialisation.
1 parent ea802fc commit a77f2c8

File tree

3 files changed

+131
-118
lines changed

3 files changed

+131
-118
lines changed

pymc_experimental/inference/pathfinder/lbfgs.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ def __call__(self, x: NDArray[np.float64]) -> None:
6060

6161

6262
class LBFGSInitFailed(Exception):
63-
pass
63+
DEFAULT_MESSAGE = "LBFGS failed to initialise."
64+
65+
def __init__(self, message=None):
66+
if message is None:
67+
message = self.DEFAULT_MESSAGE
68+
super().__init__(message)
6469

6570

6671
class LBFGSOp(Op):
@@ -77,8 +82,7 @@ def make_node(self, x0):
7782
x0 = pt.as_tensor_variable(x0)
7883
x_history = pt.dmatrix()
7984
g_history = pt.dmatrix()
80-
status = pt.iscalar()
81-
return Apply(self, [x0], [x_history, g_history, status])
85+
return Apply(self, [x0], [x_history, g_history])
8286

8387
def perform(self, node, inputs, outputs):
8488
x0 = inputs[0]
@@ -103,21 +107,18 @@ def perform(self, node, inputs, outputs):
103107
},
104108
)
105109

106-
# TODO: return the status of the lbfgs optimisation to handle the case where the optimisation fails. More details in the _single_pathfinder function.
107-
108110
if result.status == 1:
109111
logger.info("LBFGS maximum number of iterations reached. Consider increasing maxiter.")
110-
elif result.status == 2:
111-
if (result.nit <= 1) or (history_manager.count <= 1):
112+
elif (result.status == 2) or (history_manager.count <= 1):
113+
if result.nit <= 1:
112114
logger.info(
113115
"LBFGS failed to initialise. The model might be degenerate or the jitter might be too large."
114116
)
115-
raise LBFGSInitFailed("LBFGS failed to initialise")
117+
raise LBFGSInitFailed
116118
elif result.fun == np.inf:
117119
logger.info(
118120
"LBFGS diverged to infinity. The model might be degenerate or requires reparameterisation."
119121
)
120122

121123
outputs[0][0] = history_manager.get_history().x
122124
outputs[1][0] = history_manager.get_history().g
123-
outputs[2][0] = result.status

0 commit comments

Comments
 (0)