Skip to content

Commit 1f2f2c1

Browse files
committed
Moving warning into contextmanager around batched_dot
1 parent 18a6777 commit 1f2f2c1

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def bfgs_sample_dense(
485485
shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
486486
"""
487487

488-
_warnings.simplefilter("ignore", category=FutureWarning)
489488
N = x.shape[-1]
490489
IdN = pt.eye(N)[None, ...]
491490

@@ -503,7 +502,9 @@ def bfgs_sample_dense(
503502

504503
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
505504

506-
mu = x - pt.batched_dot(H_inv, g)
505+
with _warnings.catch_warnings():
506+
_warnings.simplefilter("ignore", category=FutureWarning)
507+
mu = x - pt.batched_dot(H_inv, g)
507508

508509
phi = pt.matrix_transpose(
509510
# (L, N, 1)
@@ -561,7 +562,6 @@ def bfgs_sample_sparse(
561562
shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
562563
"""
563564

564-
_warnings.simplefilter("ignore", category=FutureWarning)
565565
# qr_input: (L, N, 2J)
566566
qr_input = inv_sqrt_alpha_diag @ beta
567567
(Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
@@ -574,14 +574,16 @@ def bfgs_sample_sparse(
574574
logdet += pt.sum(pt.log(alpha), axis=-1)
575575

576576
# NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
577-
mu = x - (
578-
# (L, N), (L, N) -> (L, N)
579-
pt.batched_dot(alpha_diag, g)
580-
# beta @ gamma @ beta.T
581-
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
582-
# (L, N, N), (L, N) -> (L, N)
583-
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
584-
)
577+
with _warnings.catch_warnings():
578+
_warnings.simplefilter("ignore", category=FutureWarning)
579+
mu = x - (
580+
# (L, N), (L, N) -> (L, N)
581+
pt.batched_dot(alpha_diag, g)
582+
# beta @ gamma @ beta.T
583+
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
584+
# (L, N, N), (L, N) -> (L, N)
585+
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
586+
)
585587

586588
phi = pt.matrix_transpose(
587589
# (L, N, 1)

0 commit comments

Comments
 (0)