|
15 | 15 | import collections
|
16 | 16 | import logging
|
17 | 17 | import time
|
18 |
| -import warnings as _warnings |
19 | 18 |
|
20 | 19 | from collections import Counter
|
21 | 20 | from collections.abc import Callable, Iterator
|
|
40 | 39 | from pymc.model import modelcontext
|
41 | 40 | from pymc.model.core import Point
|
42 | 41 | from pymc.pytensorf import (
|
43 |
| - compile_pymc, |
| 42 | + compile, |
44 | 43 | find_rng_nodes,
|
45 | 44 | reseed_rngs,
|
46 | 45 | )
|
|
76 | 75 | )
|
77 | 76 |
|
78 | 77 | logger = logging.getLogger(__name__)
|
79 |
| -_warnings.filterwarnings( |
80 |
| - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" |
81 |
| -) |
82 | 78 |
|
83 | 79 | REGULARISATION_TERM = 1e-8
|
84 | 80 | DEFAULT_LINKER = "cvm_nogc"
|
@@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs(
|
142 | 138 | [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
|
143 | 139 | model.value_vars,
|
144 | 140 | )
|
145 |
| - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) |
| 141 | + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) |
146 | 142 | logp_dlogp_fn.trust_input = True
|
147 | 143 |
|
148 | 144 | return logp_dlogp_fn
|
@@ -502,7 +498,10 @@ def bfgs_sample_dense(
|
502 | 498 |
|
503 | 499 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
|
504 | 500 |
|
505 |
| - mu = x - pt.batched_dot(H_inv, g) |
| 501 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 502 | + |
| 503 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 504 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
506 | 505 |
|
507 | 506 | phi = pt.matrix_transpose(
|
508 | 507 | # (L, N, 1)
|
@@ -571,24 +570,23 @@ def bfgs_sample_sparse(
|
571 | 570 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
|
572 | 571 | logdet += pt.sum(pt.log(alpha), axis=-1)
|
573 | 572 |
|
| 573 | + # inverse Hessian |
| 574 | + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
| 575 | + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) |
| 576 | + |
574 | 577 | # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
|
575 |
| - mu = x - ( |
576 |
| - # (L, N), (L, N) -> (L, N) |
577 |
| - pt.batched_dot(alpha_diag, g) |
578 |
| - # beta @ gamma @ beta.T |
579 |
| - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
580 |
| - # (L, N, N), (L, N) -> (L, N) |
581 |
| - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) |
582 |
| - ) |
| 578 | + |
| 579 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 580 | + |
| 581 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 582 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
583 | 583 |
|
584 | 584 | phi = pt.matrix_transpose(
|
585 | 585 | # (L, N, 1)
|
586 | 586 | mu[..., None]
|
587 | 587 | # (L, N, N), (L, N, M) -> (L, N, M)
|
588 | 588 | + sqrt_alpha_diag
|
589 | 589 | @ (
|
590 |
| - # (L, N, 2J), (L, 2J, M) -> (L, N, M) |
591 |
| - # intermediate calcs below |
592 | 590 | # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
|
593 | 591 | (Q @ (Lchol - IdN))
|
594 | 592 | # (L, 2J, N), (L, N, M) -> (L, 2J, M)
|
@@ -853,7 +851,7 @@ def make_pathfinder_body(
|
853 | 851 |
|
854 | 852 | # return psi, logP_psi, logQ_psi, elbo_argmax
|
855 | 853 |
|
856 |
| - pathfinder_body_fn = compile_pymc( |
| 854 | + pathfinder_body_fn = compile( |
857 | 855 | [x_full, g_full],
|
858 | 856 | [psi, logP_psi, logQ_psi, elbo_argmax],
|
859 | 857 | **compile_kwargs,
|
@@ -1565,8 +1563,9 @@ def multipath_pathfinder(
|
1565 | 1563 | task,
|
1566 | 1564 | description=desc.format(path_idx=path_idx),
|
1567 | 1565 | completed=path_idx,
|
1568 |
| - refresh=True, |
1569 | 1566 | )
|
| 1567 | + # Ensure the progress bar visually reaches 100% and shows 'Completed' |
| 1568 | + progress.update(task, completed=num_paths, description="Completed") |
1570 | 1569 | except (KeyboardInterrupt, StopIteration) as e:
|
1571 | 1570 | # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
|
1572 | 1571 | if isinstance(e, StopIteration):
|
|
0 commit comments