Skip to content

Commit 32fa413

Browse files
committed
Updated LBFGS status handling and alpha_recover function
- Corrected the condition for LOW_UPDATE_PCT in LBFGS status handling.
- Removed update_mask references in alpha_recover and inverse_hessian_factors.
- Adjusted test cases to reflect the changes in status messages and function signatures.
1 parent 90bba5d commit 32fa413

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

pymc_extras/inference/pathfinder/lbfgs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
218218
elif result.status == 2:
219219
# precision loss resulting in inf or nan
220220
lbfgs_status = LBFGSStatus.NON_FINITE
221-
elif history.count < low_update_threshold * result.nit:
221+
elif history.count * low_update_threshold < result.nit:
222222
lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
223223
else:
224224
lbfgs_status = LBFGSStatus.CONVERGED

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def alpha_recover(
262262
shapes: L=batch_size, N=num_params
263263
"""
264264

265-
def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
265+
def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
266266
# alpha_lm1: (N,)
267267
# s_l: (N,)
268268
# z_l: (N,)
@@ -290,7 +290,7 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
290290
)
291291

292292
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
293-
# alpha: (L, N), update_mask: (L, N)
293+
# alpha: (L, N)
294294
return alpha, s, z
295295

296296

@@ -368,8 +368,8 @@ def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
368368
L, N = alpha.shape
369369

370370
# changed to get_chi_matrix_2 after removing update_mask
371-
S = get_chi_matrix_1(s, J)
372-
Z = get_chi_matrix_1(z, J)
371+
S = get_chi_matrix_2(s, J)
372+
Z = get_chi_matrix_2(z, J)
373373

374374
# E: (L, J, J)
375375
Ij = pt.eye(J)[None, ...]

tests/test_pathfinder.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
106106
)
107107
out, err = capsys.readouterr()
108108
status_pattern = [
109-
r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
110-
r"LOW_UPDATE_MASK_RATIO\s+\d+",
109+
r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
110+
r"LOW_UPDATE_PCT\s+\d+",
111111
r"LBFGS_FAILED\s+\d+",
112112
r"SUCCESS\s+\d+",
113113
]
@@ -126,8 +126,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
126126
out, err = capsys.readouterr()
127127

128128
status_pattern = [
129-
r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
130-
r"LOW_UPDATE_MASK_RATIO\s+2",
129+
r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
130+
r"LOW_UPDATE_PCT\s+2",
131131
r"LBFGS_FAILED\s+4",
132132
]
133133
for pattern in status_pattern:
@@ -232,12 +232,11 @@ def test_bfgs_sample():
232232
# get factors
233233
x_full = pt.as_tensor(x_data, dtype="float64")
234234
g_full = pt.as_tensor(g_data, dtype="float64")
235-
epsilon = 1e-11
236235

237236
x = x_full[1:]
238237
g = g_full[1:]
239-
alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
240-
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
238+
alpha, s, z = alpha_recover(x_full, g_full)
239+
beta, gamma = inverse_hessian_factors(alpha, s, z, J)
241240

242241
# sample
243242
phi, logq = bfgs_sample(
@@ -252,8 +251,8 @@ def test_bfgs_sample():
252251
# check shapes
253252
assert beta.eval().shape == (L, N, 2 * J)
254253
assert gamma.eval().shape == (L, 2 * J, 2 * J)
255-
assert phi.eval().shape == (L, num_samples, N)
256-
assert logq.eval().shape == (L, num_samples)
254+
assert all(phi.shape.eval() == (L, num_samples, N))
255+
assert all(logq.shape.eval() == (L, num_samples))
257256

258257

259258
@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])

0 commit comments

Comments
 (0)