Skip to content

Commit 48e93f3

Browse files
authored
Merge branch 'pymc-devs:main' into grassia2geo-dist
2 parents 71bd632 + 4d65ea0 commit 48e93f3

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import collections
1616
import logging
1717
import time
18-
import warnings as _warnings
1918

2019
from collections import Counter
2120
from collections.abc import Callable, Iterator
@@ -40,7 +39,7 @@
4039
from pymc.model import modelcontext
4140
from pymc.model.core import Point
4241
from pymc.pytensorf import (
43-
compile_pymc,
42+
compile,
4443
find_rng_nodes,
4544
reseed_rngs,
4645
)
@@ -76,9 +75,6 @@
7675
)
7776

7877
logger = logging.getLogger(__name__)
79-
_warnings.filterwarnings(
80-
"ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
81-
)
8278

8379
REGULARISATION_TERM = 1e-8
8480
DEFAULT_LINKER = "cvm_nogc"
@@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs(
142138
[model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
143139
model.value_vars,
144140
)
145-
logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
141+
logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
146142
logp_dlogp_fn.trust_input = True
147143

148144
return logp_dlogp_fn
@@ -502,7 +498,10 @@ def bfgs_sample_dense(
502498

503499
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
504500

505-
mu = x - pt.batched_dot(H_inv, g)
501+
# mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
502+
503+
batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
504+
mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))
506505

507506
phi = pt.matrix_transpose(
508507
# (L, N, 1)
@@ -571,24 +570,23 @@ def bfgs_sample_sparse(
571570
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
572571
logdet += pt.sum(pt.log(alpha), axis=-1)
573572

573+
# inverse Hessian
574+
# (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
575+
H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
576+
574577
# NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
575-
mu = x - (
576-
# (L, N), (L, N) -> (L, N)
577-
pt.batched_dot(alpha_diag, g)
578-
# beta @ gamma @ beta.T
579-
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
580-
# (L, N, N), (L, N) -> (L, N)
581-
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
582-
)
578+
579+
# mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
580+
581+
batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
582+
mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))
583583

584584
phi = pt.matrix_transpose(
585585
# (L, N, 1)
586586
mu[..., None]
587587
# (L, N, N), (L, N, M) -> (L, N, M)
588588
+ sqrt_alpha_diag
589589
@ (
590-
# (L, N, 2J), (L, 2J, M) -> (L, N, M)
591-
# intermediate calcs below
592590
# (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
593591
(Q @ (Lchol - IdN))
594592
# (L, 2J, N), (L, N, M) -> (L, 2J, M)
@@ -853,7 +851,7 @@ def make_pathfinder_body(
853851

854852
# return psi, logP_psi, logQ_psi, elbo_argmax
855853

856-
pathfinder_body_fn = compile_pymc(
854+
pathfinder_body_fn = compile(
857855
[x_full, g_full],
858856
[psi, logP_psi, logQ_psi, elbo_argmax],
859857
**compile_kwargs,
@@ -1565,8 +1563,9 @@ def multipath_pathfinder(
15651563
task,
15661564
description=desc.format(path_idx=path_idx),
15671565
completed=path_idx,
1568-
refresh=True,
15691566
)
1567+
# Ensure the progress bar visually reaches 100% and shows 'Completed'
1568+
progress.update(task, completed=num_paths, description="Completed")
15701569
except (KeyboardInterrupt, StopIteration) as e:
15711570
# if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
15721571
if isinstance(e, StopIteration):

tests/test_pathfinder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import pymc as pm
1919
import pytest
2020

21-
pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
22-
2321
import pymc_extras as pmx
2422

2523

@@ -53,6 +51,7 @@ def reference_idata():
5351

5452

5553
@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
54+
@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
5655
def test_pathfinder(inference_backend, reference_idata):
5756
if inference_backend == "blackjax" and sys.platform == "win32":
5857
pytest.skip("JAX not supported on windows")

0 commit comments

Comments
 (0)