|
40 | 40 | from pymc.model import modelcontext |
41 | 41 | from pymc.model.core import Point |
42 | 42 | from pymc.pytensorf import ( |
43 | | - compile_pymc, |
| 43 | + compile, |
44 | 44 | find_rng_nodes, |
45 | 45 | reseed_rngs, |
46 | 46 | ) |
|
77 | 77 |
|
78 | 78 | logger = logging.getLogger(__name__) |
79 | 79 | _warnings.filterwarnings( |
80 | | - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" |
| 80 | + "ignore", |
| 81 | + category=UserWarning, |
| 82 | + message="The same einsum subscript is used for a broadcastable and non-broadcastable dimension", |
81 | 83 | ) |
82 | 84 |
|
83 | 85 | REGULARISATION_TERM = 1e-8 |
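
The hunks above and below swap `compile_pymc` for `compile`, which `pymc.pytensorf` now exposes under the shorter name. A minimal sketch of a version-tolerant import, assuming only that older pymc releases still ship `compile_pymc`; the local alias `compile_fn` is illustrative and not part of the diff:

```python
# Sketch only: import the compile helper under either name.
# `compile_fn` is a hypothetical local alias, not part of the diff.
try:
    from pymc.pytensorf import compile as compile_fn  # newer pymc
except ImportError:
    from pymc.pytensorf import compile_pymc as compile_fn  # older pymc
```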
@@ -142,7 +144,7 @@ def get_logp_dlogp_of_ravel_inputs( |
142 | 144 | [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)], |
143 | 145 | model.value_vars, |
144 | 146 | ) |
145 | | - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) |
| 147 | + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) |
146 | 148 | logp_dlogp_fn.trust_input = True |
147 | 149 |
|
148 | 150 | return logp_dlogp_fn |
@@ -502,7 +504,7 @@ def bfgs_sample_dense( |
502 | 504 |
|
503 | 505 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) |
504 | 506 |
|
505 | | - mu = x - pt.batched_dot(H_inv, g) |
| 507 | + mu = x - pt.einsum("ijk,ik->ij", H_inv, g) |
506 | 508 |
|
507 | 509 | phi = pt.matrix_transpose( |
508 | 510 | # (L, N, 1) |
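
In `bfgs_sample_dense` above, `pt.batched_dot(H_inv, g)` becomes an explicit `pt.einsum`, spelling out the batched matrix-vector product `(L, N, N), (L, N) -> (L, N)`. A small self-contained check of that contraction; the shapes and variable names below are illustrative, not taken from the diff:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function

# Illustrative batch of inverse Hessians (L, N, N) and gradients (L, N).
L_, N_ = 3, 4
H_inv_val = np.random.rand(L_, N_, N_)
g_val = np.random.rand(L_, N_)

H_inv = pt.tensor3("H_inv", dtype="float64")
g = pt.matrix("g", dtype="float64")
mv = pt.einsum("ijk,ik->ij", H_inv, g)  # batched H_inv @ g
f = function([H_inv, g], mv)

np.testing.assert_allclose(
    f(H_inv_val, g_val), np.einsum("ijk,ik->ij", H_inv_val, g_val)
)
```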
@@ -571,15 +573,12 @@ def bfgs_sample_sparse( |
571 | 573 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) |
572 | 574 | logdet += pt.sum(pt.log(alpha), axis=-1) |
573 | 575 |
|
| 576 | + # inverse Hessian |
| 577 | + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
| 578 | + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) |
| 579 | + |
574 | 580 | # NOTE: sign changed from "x +" to "x -" to match Stan, which differs from Zhang et al. (2022); same change in the dense version. |
575 | | - mu = x - ( |
576 | | - # (L, N), (L, N) -> (L, N) |
577 | | - pt.batched_dot(alpha_diag, g) |
578 | | - # beta @ gamma @ beta.T |
579 | | - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
580 | | - # (L, N, N), (L, N) -> (L, N) |
581 | | - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) |
582 | | - ) |
| 581 | + mu = x - pt.einsum("ijk,ik->ij", H_inv, g) |
583 | 582 |
|
584 | 583 | phi = pt.matrix_transpose( |
585 | 584 | # (L, N, 1) |
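
In the sparse path, the two `batched_dot` terms are folded into a single product with the explicitly formed low-rank-plus-diagonal inverse Hessian `H_inv = alpha_diag + beta @ gamma @ beta.T`. The equivalence is just distributivity; a NumPy-only sketch with made-up shapes (the `L`, `N`, `J` values below are illustrative):

```python
import numpy as np

L_, N_, J_ = 2, 5, 3
alpha_diag = np.stack([np.diag(np.random.rand(N_)) for _ in range(L_)])  # (L, N, N)
beta = np.random.rand(L_, N_, 2 * J_)                                    # (L, N, 2J)
gamma = np.random.rand(L_, 2 * J_, 2 * J_)                               # (L, 2J, 2J)
g = np.random.rand(L_, N_)                                               # (L, N)

low_rank = beta @ gamma @ np.swapaxes(beta, -1, -2)                      # (L, N, N)

# old: alpha_diag @ g + (beta gamma beta^T) @ g, batched over L
old = np.einsum("ijk,ik->ij", alpha_diag, g) + np.einsum("ijk,ik->ij", low_rank, g)

# new: form H_inv once, then a single batched matrix-vector product
H_inv = alpha_diag + low_rank
new = np.einsum("ijk,ik->ij", H_inv, g)

np.testing.assert_allclose(old, new)
```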
@@ -853,7 +852,7 @@ def make_pathfinder_body( |
853 | 852 |
|
854 | 853 | # return psi, logP_psi, logQ_psi, elbo_argmax |
855 | 854 |
|
856 | | - pathfinder_body_fn = compile_pymc( |
| 855 | + pathfinder_body_fn = compile( |
857 | 856 | [x_full, g_full], |
858 | 857 | [psi, logP_psi, logQ_psi, elbo_argmax], |
859 | 858 | **compile_kwargs, |
|