 
 from collections.abc import Callable
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Literal
 
 import arviz as az
 import blackjax
@@ -68,7 +69,7 @@ def add_path_data(self, path_id: int, samples, logP, logQ):
 
 def get_jaxified_logp_of_ravel_inputs(
     model: Model,
-) -> tuple[Callable, DictToArrayBijection]:
+) -> Callable:
     """
-    Get jaxified logp function and ravel inputs for a PyMC model.
+    Get the jaxified logp function of raveled inputs for a PyMC model.
 
@@ -198,7 +199,12 @@ def _get_s_xi_z_xi(x, g, update_mask, J):
     return s_xi, z_xi
 
 
-def alpha_recover(x, g):
+def alpha_recover(x, g, epsilon: float = 1e-11):
+    """
+    epsilon: float
+        value used to filter out large changes in the direction of the update gradient at each iteration l in L. An iteration l is accepted only if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]).
+    """
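+    # x, g: histories of the L-BFGS positions and gradients, one row per iteration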
+
     def compute_alpha_l(alpha_lm1, s_l, z_l):
         # alpha_lm1: (N,)
         # s_l: (N,)
@@ -227,7 +233,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
     S, Z = _get_delta_x_delta_g(x, g)
     alpha_l_init = pt.ones(N)
     SZ = (S * Z).sum(axis=-1)
-    update_mask = SZ > 1e-11 * pt.linalg.norm(Z, axis=-1)
+
+    # Q: line 5 of Algorithm 3 in Zhang et al. (2022) uses SZ < 1e-11 * L2(Z), as opposed to the ">" used here
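+    # SZ[l] is the inner product s_l @ z_l; the mask keeps iterations whose curvature estimate is sufficiently positive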
+    update_mask = SZ > epsilon * pt.linalg.norm(Z, axis=-1)
 
     alpha, _ = pytensor.scan(
         fn=scan_body,
@@ -289,23 +297,8 @@ def inverse_hessian_factors(alpha, x, g, update_mask, J):
     return beta, gamma
 
 
-def _batched(x, g, alpha, beta, gamma):
-    var_list = [x, g, alpha, beta, gamma]
-    ndims = np.array([2, 2, 2, 3, 3])
-    var_ndims = np.array([var.ndim for var in var_list])
-
-    if all(var_ndims == ndims):
-        return True
-    elif all(var_ndims == ndims - 1):
-        return False
-    else:
-        raise ValueError(
-            "All variables must have the same number of dimensions, either matching ndims or ndims - 1."
-        )
-
-
 def bfgs_sample(
-    num_samples,
+    num_samples: int,
     x,  # position
     g,  # grad
     alpha,
@@ -326,7 +319,19 @@ def bfgs_sample(
 
     rng = pytensor.shared(np.random.default_rng(seed=random_seed))
 
-    if not _batched(x, g, alpha, beta, gamma):
+    def batched(x, g, alpha, beta, gamma):
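+        # a batched input set has one extra leading dimension per variable (var_ndims == ndims); single draws have var_ndims == ndims - 1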
+        var_list = [x, g, alpha, beta, gamma]
+        ndims = np.array([2, 2, 2, 3, 3])
+        var_ndims = np.array([var.ndim for var in var_list])
+
+        if np.all(var_ndims == ndims):
+            return True
+        elif np.all(var_ndims == ndims - 1):
+            return False
+        else:
+            raise ValueError(f"Incorrect number of dimensions: expected {ndims} or {ndims - 1}, got {var_ndims}.")
+
+    if not batched(x, g, alpha, beta, gamma):
         x = pt.atleast_2d(x)
         g = pt.atleast_2d(g)
         alpha = pt.atleast_2d(alpha)
@@ -372,26 +377,23 @@ def bfgs_sample(
 
 
 def compute_logp(logp_func, arr):
-    """
-    **IMPORTANT**
-    replace nan with -np.inf otherwise np.argmax(elbo) will return you the first index at nan!!!!
-    """
-
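+    # logp_func maps a 1-d parameter vector to a scalar log density; it is applied along the last axis of arr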
     logP = np.apply_along_axis(logp_func, axis=-1, arr=arr)
+    # replace nan with -inf since np.argmax will return the first index at nan
     return np.where(np.isnan(logP), -np.inf, logP)
 
 
 def single_pathfinder(
     model,
     num_draws: int,
     maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,
-    random_seed: RandomSeed = None,
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
+    random_seed: RandomSeed | None = None,
 ):
     jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
     logp_func, dlogp_func = get_logp_dlogp_of_ravel_inputs(model)
@@ -423,7 +425,7 @@ def neg_dlogp_func(x):
         maxls=maxls,
     )
 
-    alpha, update_mask = alpha_recover(history.x, history.g)
+    alpha, update_mask = alpha_recover(history.x, history.g, epsilon=epsilon)
 
     beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)
 
@@ -486,6 +488,10 @@ def make_initial_pathfinder_point(
     DictToArrayBijection
         bijection containing jittered initial point
     """
+
+    # TODO: replace rng.uniform (a pseudo-random sequence) with scipy.stats.qmc.Sobol (a quasi-random sequence);
+    # Sobol is a low-discrepancy sequence, so it covers the sampling region more evenly than uniform draws.
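+    # a possible sketch (names hypothetical): u = scipy.stats.qmc.Sobol(d=n_params, seed=seed).random(1),
+    # then map u in [0, 1)^n_params onto [-jitter, jitter] via jitter * (2 * u - 1)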
+
     ipfn = make_initial_point_fn(
         model=model,
     )
@@ -498,7 +504,7 @@ def make_initial_pathfinder_point(
     return ip_map
 
 
-def _run_single_pathfinder(model, path_id, random_seed, **kwargs):
+def _run_single_pathfinder(model, path_id: int, random_seed: RandomSeed, **kwargs):
     """Helper to run single pathfinder instance"""
     try:
         # Handle pickling
@@ -553,13 +559,13 @@ def process_multipath_pathfinder_results(
         processed samples, logP and logQ arrays
     """
     # path[samples]: (I, M, N)
-    num_dims = results.paths[0]["samples"].shape[-1]
+    N = results.paths[0]["samples"].shape[-1]
 
     paths_array = np.array([results.paths[i] for i in range(results.num_paths)])
     logP = np.concatenate([path["logP"] for path in paths_array])
     logQ = np.concatenate([path["logQ"] for path in paths_array])
     samples = np.concatenate([path["samples"] for path in paths_array])
-    samples = samples.reshape(-1, num_dims, order="F")
+    samples = samples.reshape(-1, N, order="F")
 
     # adjust log densities
     log_I = np.log(results.num_paths)
@@ -575,12 +581,13 @@ def multipath_pathfinder(
     num_draws: int,
     num_draws_per_path: int,
     maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
     psis_resample: bool = True,
     random_seed: RandomSeed = None,
     **pathfinder_kwargs,
@@ -603,6 +610,7 @@ def multipath_pathfinder(
         "maxls": maxls,
         "num_elbo_draws": num_elbo_draws,
         "jitter": jitter,
+        "epsilon": epsilon,
         **pathfinder_kwargs,
     }
     kwargs_pickled = {k: cloudpickle.dumps(v) for k, v in kwargs.items()}
@@ -645,20 +653,21 @@ def multipath_pathfinder(
 
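+# NOTE: the single-letter comments in the signature below map arguments to the notation of Zhang et al. (2022):
+# I paths, R total draws, M draws per path, J history size, L^max max iterations, K ELBO draws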
 def fit_pathfinder(
     model,
-    num_paths=1,
-    num_draws=1000,
-    num_draws_per_path=1000,
-    maxcor=None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
-    num_elbo_draws: int = 10,
+    num_paths: int = 1,  # I
+    num_draws: int = 1000,  # R
+    num_draws_per_path: int = 1000,  # M
+    maxcor: int | None = None,  # J
+    maxiter: int = 1000,  # L^max
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
+    num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
     psis_resample: bool = True,
     random_seed: RandomSeed | None = None,
-    postprocessing_backend="cpu",
-    inference_backend="pymc",
+    postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
+    inference_backend: Literal["pymc", "blackjax"] = "pymc",
     **pathfinder_kwargs,
 ):
     """
@@ -686,11 +695,13 @@ def fit_pathfinder(
     gtol : float, optional
         Tolerance for the norm of the gradient (default is 1e-16).
     maxls : int, optional
-        Maximum number of line search steps (default is 1000).
+        Maximum number of line search steps for the L-BFGS algorithm (default is 1000).
     num_elbo_draws : int, optional
         Number of draws for the Evidence Lower Bound (ELBO) estimation (default is 10).
     jitter : float, optional
         Amount of jitter to apply to initial points (default is 2.0).
+    epsilon : float, optional
+        Value used to filter out large changes in the direction of the update gradient at each iteration l in L. An iteration l is accepted only if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) (default is 1e-11).
     psis_resample : bool, optional
         Whether to apply Pareto Smoothed Importance Sampling Resampling (default is True). If False, the samples are returned as is (i.e. no resampling is applied), with size num_draws_per_path * num_paths.
     random_seed : RandomSeed, optional
@@ -733,6 +744,7 @@ def fit_pathfinder(
         maxls=maxls,
         num_elbo_draws=num_elbo_draws,
         jitter=jitter,
+        epsilon=epsilon,
         psis_resample=psis_resample,
         random_seed=random_seed,
         **pathfinder_kwargs,