from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Literal

import arviz as az
import blackjax
@@ -68,7 +69,7 @@ def add_path_data(self, path_id: int, samples, logP, logQ):


def get_jaxified_logp_of_ravel_inputs(
    model: Model,
-) -> tuple[Callable, DictToArrayBijection]:
+) -> Callable:
    """
    Get jaxified logp function and ravel inputs for a PyMC model.

@@ -198,7 +199,12 @@ def _get_s_xi_z_xi(x, g, update_mask, J):
    return s_xi, z_xi


-def alpha_recover(x, g):
+def alpha_recover(x, g, epsilon: float = 1e-11):
+    """
+    epsilon : float
+        Value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
+    """
+
    def compute_alpha_l(alpha_lm1, s_l, z_l):
        # alpha_lm1: (N,)
        # s_l: (N,)
@@ -227,7 +233,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
    S, Z = _get_delta_x_delta_g(x, g)
    alpha_l_init = pt.ones(N)
    SZ = (S * Z).sum(axis=-1)
-    update_mask = SZ > 1e-11 * pt.linalg.norm(Z, axis=-1)
+
+    # Q: Line 5 of Algorithm 3 in Zhang et al. (2022) sets SZ < 1e-11 * L2(Z), as opposed to the ">" sign used here
+    update_mask = SZ > epsilon * pt.linalg.norm(Z, axis=-1)

    alpha, _ = pytensor.scan(
        fn=scan_body,
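
For intuition, here is a minimal NumPy sketch (not part of the diff) of the same acceptance test: an iteration's curvature pair (s_l, z_l) is kept only when its inner product clears the epsilon-scaled norm of the gradient change. Shapes and values are hypothetical.

import numpy as np

def curvature_update_mask(S, Z, epsilon=1e-11):
    # S: delta_theta per iteration, shape (L, N); Z: delta_grad per iteration, shape (L, N)
    SZ = (S * Z).sum(axis=-1)                         # inner products s_l . z_l, shape (L,)
    return SZ > epsilon * np.linalg.norm(Z, axis=-1)  # keep iteration l only if safely positive

S = np.array([[1.0, 0.0], [1.0, 1.0]])
Z = np.array([[-1.0, 0.0], [2.0, 0.5]])
print(curvature_update_mask(S, Z))  # [False  True]
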
@@ -289,23 +297,8 @@ def inverse_hessian_factors(alpha, x, g, update_mask, J):
    return beta, gamma


-def _batched(x, g, alpha, beta, gamma):
-    var_list = [x, g, alpha, beta, gamma]
-    ndims = np.array([2, 2, 2, 3, 3])
-    var_ndims = np.array([var.ndim for var in var_list])
-
-    if all(var_ndims == ndims):
-        return True
-    elif all(var_ndims == ndims - 1):
-        return False
-    else:
-        raise ValueError(
-            "All variables must have the same number of dimensions, either matching ndims or ndims - 1."
-        )
-
-
def bfgs_sample(
-    num_samples,
+    num_samples: int,
    x,  # position
    g,  # grad
    alpha,
@@ -326,7 +319,19 @@ def bfgs_sample(

    rng = pytensor.shared(np.random.default_rng(seed=random_seed))

-    if not _batched(x, g, alpha, beta, gamma):
+    def batched(x, g, alpha, beta, gamma):
+        var_list = [x, g, alpha, beta, gamma]
+        ndims = np.array([2, 2, 2, 3, 3])
+        var_ndims = np.array([var.ndim for var in var_list])
+
+        if np.all(var_ndims == ndims):
+            return True
+        elif np.all(var_ndims == ndims - 1):
+            return False
+        else:
+            raise ValueError("Incorrect number of dimensions.")
+
+    if not batched(x, g, alpha, beta, gamma):
        x = pt.atleast_2d(x)
        g = pt.atleast_2d(g)
        alpha = pt.atleast_2d(alpha)
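
The relocated `batched` helper also swaps Python's built-in `all` for `np.all`, the idiomatic reduction for boolean arrays. A standalone sketch of the ndim check, with hypothetical shapes, shows how it distinguishes batched from single-path inputs:

import numpy as np

# Hypothetical batched inputs: x, g, alpha are (L, N); beta is (L, N, 2J); gamma is (L, 2J, 2J)
L, N, J = 4, 3, 2
x = np.zeros((L, N)); g = np.zeros((L, N)); alpha = np.zeros((L, N))
beta = np.zeros((L, N, 2 * J)); gamma = np.zeros((L, 2 * J, 2 * J))

ndims = np.array([2, 2, 2, 3, 3])
var_ndims = np.array([v.ndim for v in (x, g, alpha, beta, gamma)])

print(np.all(var_ndims == ndims))      # True: inputs carry the leading batch dimension
print(np.all(var_ndims == ndims - 1))  # False: these are not the unbatched shapes
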
@@ -372,26 +377,23 @@ def bfgs_sample(


def compute_logp(logp_func, arr):
-    """
-    **IMPORTANT**
-    replace nan with -np.inf otherwise np.argmax(elbo) will return you the first index at nan!!!!
-    """
-
    logP = np.apply_along_axis(logp_func, axis=-1, arr=arr)
+    # replace nan with -inf since np.argmax will return the first index at nan
    return np.where(np.isnan(logP), -np.inf, logP)


def single_pathfinder(
    model,
    num_draws: int,
    maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
    num_elbo_draws: int = 10,
-    random_seed: RandomSeed = None,
    jitter: float = 2.0,
+    epsilon: float = 1e-11,
+    random_seed: RandomSeed | None = None,
):
    jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
    logp_func, dlogp_func = get_logp_dlogp_of_ravel_inputs(model)
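
The nan-to-negative-infinity replacement in compute_logp matters because np.argmax propagates nan as the maximum, so a single nan ELBO value would always be selected. A quick demonstration:

import numpy as np

elbo = np.array([np.nan, 1.0, 3.0])
print(np.argmax(elbo))  # 0: argmax lands on the nan, not the true maximum

elbo_safe = np.where(np.isnan(elbo), -np.inf, elbo)
print(np.argmax(elbo_safe))  # 2: the true maximum is recovered
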
@@ -423,7 +425,7 @@ def neg_dlogp_func(x):
        maxls=maxls,
    )

-    alpha, update_mask = alpha_recover(history.x, history.g)
+    alpha, update_mask = alpha_recover(history.x, history.g, epsilon=epsilon)

    beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)
@@ -486,6 +488,10 @@ def make_initial_pathfinder_point(
    DictToArrayBijection
        bijection containing jittered initial point
    """
+
+    # TODO: replace rng.uniform (a pseudo-random sequence) with scipy.stats.qmc.Sobol (a quasi-random sequence).
+    # Sobol is a lower-discrepancy sequence than uniform sampling.
+
    ipfn = make_initial_point_fn(
        model=model,
    )
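
For reference, a sketch of what the TODO suggests, using scipy.stats.qmc.Sobol; the names and sizes here are illustrative, not from the diff:

import numpy as np
from scipy.stats import qmc

N, jitter, seed = 8, 2.0, 42  # illustrative dimensionality, jitter scale, and seed

# current approach: pseudo-random jitter
rng = np.random.default_rng(seed)
u_pseudo = rng.uniform(-jitter, jitter, size=N)

# suggested approach: quasi-random jitter from a scrambled Sobol sequence,
# which fills the hypercube more evenly (lower discrepancy)
sobol = qmc.Sobol(d=N, scramble=True, seed=seed)
u_quasi = qmc.scale(sobol.random(1), -jitter, jitter).squeeze()  # shape (N,)
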
@@ -498,7 +504,7 @@ def make_initial_pathfinder_point(
    return ip_map


-def _run_single_pathfinder(model, path_id, random_seed, **kwargs):
+def _run_single_pathfinder(model, path_id: int, random_seed: RandomSeed, **kwargs):
    """Helper to run single pathfinder instance"""
    try:
        # Handle pickling
@@ -553,13 +559,13 @@ def process_multipath_pathfinder_results(
        processed samples, logP and logQ arrays
    """
    # path[samples]: (I, M, N)
-    num_dims = results.paths[0]["samples"].shape[-1]
+    N = results.paths[0]["samples"].shape[-1]

    paths_array = np.array([results.paths[i] for i in range(results.num_paths)])
    logP = np.concatenate([path["logP"] for path in paths_array])
    logQ = np.concatenate([path["logQ"] for path in paths_array])
    samples = np.concatenate([path["samples"] for path in paths_array])
-    samples = samples.reshape(-1, num_dims, order="F")
+    samples = samples.reshape(-1, N, order="F")

    # adjust log densities
    log_I = np.log(results.num_paths)
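
Assuming the stacked samples form an (I, M, N) array (I paths, M draws per path, N dimensions), the order="F" reshape interleaves draws across paths while keeping each length-N sample intact, as this small check illustrates:

import numpy as np

I, M, N = 2, 2, 2  # illustrative path count, draws per path, and dimensionality
samples = np.arange(I * M * N).reshape(I, M, N)

print(samples.reshape(-1, N, order="F"))
# [[0 1]    path 0, draw 0
#  [4 5]    path 1, draw 0
#  [2 3]    path 0, draw 1
#  [6 7]]   path 1, draw 1
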
@@ -575,12 +581,13 @@ def multipath_pathfinder(
    num_draws: int,
    num_draws_per_path: int,
    maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
    num_elbo_draws: int = 10,
    jitter: float = 2.0,
+    epsilon: float = 1e-11,
    psis_resample: bool = True,
    random_seed: RandomSeed = None,
    **pathfinder_kwargs,
@@ -603,6 +610,7 @@ def multipath_pathfinder(
        "maxls": maxls,
        "num_elbo_draws": num_elbo_draws,
        "jitter": jitter,
+        "epsilon": epsilon,
        **pathfinder_kwargs,
    }
    kwargs_pickled = {k: cloudpickle.dumps(v) for k, v in kwargs.items()}
@@ -645,20 +653,21 @@ def multipath_pathfinder(


def fit_pathfinder(
    model,
-    num_paths=1,
-    num_draws=1000,
-    num_draws_per_path=1000,
-    maxcor=None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
+    num_paths: int = 1,  # I
+    num_draws: int = 1000,  # R
+    num_draws_per_path: int = 1000,  # M
+    maxcor: int | None = None,  # J
+    maxiter: int = 1000,  # L^max
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
    maxls=1000,
-    num_elbo_draws: int = 10,
+    num_elbo_draws: int = 10,  # K
    jitter: float = 2.0,
+    epsilon: float = 1e-11,
    psis_resample: bool = True,
    random_seed: RandomSeed | None = None,
-    postprocessing_backend="cpu",
-    inference_backend="pymc",
+    postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
+    inference_backend: Literal["pymc", "blackjax"] = "pymc",
    **pathfinder_kwargs,
):
    """
@@ -686,11 +695,13 @@ def fit_pathfinder(
    gtol : float, optional
        Tolerance for the norm of the gradient (default is 1e-16).
    maxls : int, optional
-        Maximum number of line search steps (default is 1000).
+        Maximum number of line search steps for the L-BFGS algorithm (default is 1000).
    num_elbo_draws : int, optional
        Number of draws for the Evidence Lower Bound (ELBO) estimation (default is 10).
    jitter : float, optional
        Amount of jitter to apply to initial points (default is 2.0).
+    epsilon : float, optional
+        Value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L (default is 1e-11).
    psis_resample : bool, optional
        Whether to apply Pareto Smoothed Importance Sampling Resampling (default is True). If False, the samples are returned as-is (no resampling is applied), with size num_draws_per_path * num_paths.
    random_seed : RandomSeed, optional
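
Taken together, a hypothetical call with the new signature might look like the following; the import path is illustrative and depends on where the module lives in the package:

import pymc as pm
from pymc_experimental.inference.pathfinder import fit_pathfinder  # illustrative path

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.8])

idata = fit_pathfinder(
    model,
    num_paths=4,      # I: independent pathfinder runs
    num_draws=1000,   # R: total draws after PSIS resampling
    epsilon=1e-11,    # curvature filter introduced in this change
    random_seed=123,
)
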
@@ -733,6 +744,7 @@ def fit_pathfinder(
        maxls=maxls,
        num_elbo_draws=num_elbo_draws,
        jitter=jitter,
+        epsilon=epsilon,
        psis_resample=psis_resample,
        random_seed=random_seed,
        **pathfinder_kwargs,