@@ -485,7 +485,6 @@ def bfgs_sample_dense(
485
485
shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
486
486
"""
487
487
488
- _warnings .simplefilter ("ignore" , category = FutureWarning )
489
488
N = x .shape [- 1 ]
490
489
IdN = pt .eye (N )[None , ...]
491
490
@@ -503,7 +502,9 @@ def bfgs_sample_dense(
503
502
504
503
logdet = 2.0 * pt .sum (pt .log (pt .abs (pt .diagonal (Lchol , axis1 = - 2 , axis2 = - 1 ))), axis = - 1 )
505
504
506
- mu = x - pt .batched_dot (H_inv , g )
505
+ with _warnings .catch_warnings ():
506
+ _warnings .simplefilter ("ignore" , category = FutureWarning )
507
+ mu = x - pt .batched_dot (H_inv , g )
507
508
508
509
phi = pt .matrix_transpose (
509
510
# (L, N, 1)
@@ -561,7 +562,6 @@ def bfgs_sample_sparse(
561
562
shapes: L=batch_size, N=num_params, J=history_size, M=num_samples
562
563
"""
563
564
564
- _warnings .simplefilter ("ignore" , category = FutureWarning )
565
565
# qr_input: (L, N, 2J)
566
566
qr_input = inv_sqrt_alpha_diag @ beta
567
567
(Q , R ), _ = pytensor .scan (fn = pt .nlinalg .qr , sequences = [qr_input ], allow_gc = False )
@@ -574,14 +574,16 @@ def bfgs_sample_sparse(
574
574
logdet += pt .sum (pt .log (alpha ), axis = - 1 )
575
575
576
576
# NOTE: the sign of the expression was changed from "x +" to "x -" to match Stan, which differs from Zhang et al. (2022). The same applies to the dense version.
577
- mu = x - (
578
- # (L, N), (L, N) -> (L, N)
579
- pt .batched_dot (alpha_diag , g )
580
- # beta @ gamma @ beta.T
581
- # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
582
- # (L, N, N), (L, N) -> (L, N)
583
- + pt .batched_dot ((beta @ gamma @ pt .matrix_transpose (beta )), g )
584
- )
577
+ with _warnings .catch_warnings ():
578
+ _warnings .simplefilter ("ignore" , category = FutureWarning )
579
+ mu = x - (
580
+ # (L, N), (L, N) -> (L, N)
581
+ pt .batched_dot (alpha_diag , g )
582
+ # beta @ gamma @ beta.T
583
+ # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
584
+ # (L, N, N), (L, N) -> (L, N)
585
+ + pt .batched_dot ((beta @ gamma @ pt .matrix_transpose (beta )), g )
586
+ )
585
587
586
588
phi = pt .matrix_transpose (
587
589
# (L, N, 1)
0 commit comments