Commit 2efb511

Added type hints and epsilon parameter to fit_pathfinder

1 parent cb4436c

2 files changed: +61 -74 lines

pymc_experimental/inference/pathfinder/pathfinder.py
61 additions & 49 deletions
@@ -20,6 +20,7 @@
 from collections.abc import Callable
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Literal

 import arviz as az
 import blackjax
@@ -68,7 +69,7 @@ def add_path_data(self, path_id: int, samples, logP, logQ):

 def get_jaxified_logp_of_ravel_inputs(
     model: Model,
-) -> tuple[Callable, DictToArrayBijection]:
+) -> Callable:
     """
     Get jaxified logp function and ravel inputs for a PyMC model.
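With the return annotation narrowed from a 2-tuple to a bare Callable, call sites now unpack a single function; a one-line sketch (x_flat is an illustrative name for a raveled parameter vector):

logp_func = get_jaxified_logp_of_ravel_inputs(model)  # previously: logp_func, bijection = ...
logP = logp_func(x_flat)  # evaluate the jaxified model logp at a flat point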
@@ -198,7 +199,12 @@ def _get_s_xi_z_xi(x, g, update_mask, J):
     return s_xi, z_xi


-def alpha_recover(x, g):
+def alpha_recover(x, g, epsilon: float = 1e-11):
+    """
+    epsilon: float
+        threshold used to filter out large changes in the direction of the update gradient at each iteration l in L. An iteration l is accepted only if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]).
+    """
+
     def compute_alpha_l(alpha_lm1, s_l, z_l):
         # alpha_lm1: (N,)
         # s_l: (N,)
@@ -227,7 +233,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
     S, Z = _get_delta_x_delta_g(x, g)
     alpha_l_init = pt.ones(N)
     SZ = (S * Z).sum(axis=-1)
-    update_mask = SZ > 1e-11 * pt.linalg.norm(Z, axis=-1)
+
+    # Q: Line 5 of Algorithm 3 in Zhang et al. (2022) sets SZ < 1e-11 * L2(Z), as opposed to the ">" sign used here
+    update_mask = SZ > epsilon * pt.linalg.norm(Z, axis=-1)

     alpha, _ = pytensor.scan(
         fn=scan_body,
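A NumPy sketch of what the epsilon filter computes, assuming _get_delta_x_delta_g takes first differences along the iteration axis (an assumption from the names; that helper is not shown in this diff):

import numpy as np

def update_mask_numpy(x, g, epsilon=1e-11):
    # x, g: (L+1, N) positions and gradients across L-BFGS iterations
    S = np.diff(x, axis=0)  # s_l = x_l - x_{l-1}, shape (L, N)
    Z = np.diff(g, axis=0)  # z_l = g_l - g_{l-1}, shape (L, N)
    SZ = (S * Z).sum(axis=-1)  # per-iteration inner products, shape (L,)
    # accept iteration l only when dot(s_l, z_l) > epsilon * ||z_l||_2
    return SZ > epsilon * np.linalg.norm(Z, axis=-1)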
@@ -289,23 +297,8 @@ def inverse_hessian_factors(alpha, x, g, update_mask, J):
     return beta, gamma


-def _batched(x, g, alpha, beta, gamma):
-    var_list = [x, g, alpha, beta, gamma]
-    ndims = np.array([2, 2, 2, 3, 3])
-    var_ndims = np.array([var.ndim for var in var_list])
-
-    if all(var_ndims == ndims):
-        return True
-    elif all(var_ndims == ndims - 1):
-        return False
-    else:
-        raise ValueError(
-            "All variables must have the same number of dimensions, either matching ndims or ndims - 1."
-        )
-
-
 def bfgs_sample(
-    num_samples,
+    num_samples: int,
     x,  # position
     g,  # grad
     alpha,
@@ -326,7 +319,19 @@

     rng = pytensor.shared(np.random.default_rng(seed=random_seed))

-    if not _batched(x, g, alpha, beta, gamma):
+    def batched(x, g, alpha, beta, gamma):
+        var_list = [x, g, alpha, beta, gamma]
+        ndims = np.array([2, 2, 2, 3, 3])
+        var_ndims = np.array([var.ndim for var in var_list])
+
+        if np.all(var_ndims == ndims):
+            return True
+        elif np.all(var_ndims == ndims - 1):
+            return False
+        else:
+            raise ValueError("Incorrect number of dimensions.")
+
+    if not batched(x, g, alpha, beta, gamma):
         x = pt.atleast_2d(x)
         g = pt.atleast_2d(g)
         alpha = pt.atleast_2d(alpha)
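As a reading aid, the ndims check separates two calling conventions. The concrete shapes below are assumptions for illustration; the diff itself only fixes the per-variable ndim counts:

import numpy as np

L, N, J = 6, 4, 2  # iterations, parameters, history size (all illustrative)
# batched inputs match ndims (2, 2, 2, 3, 3), so batched(...) returns True
x = np.zeros((L, N))
g = np.zeros((L, N))
alpha = np.ones((L, N))
beta = np.zeros((L, N, 2 * J))
gamma = np.zeros((L, 2 * J, 2 * J))
# dropping the leading axis everywhere matches ndims - 1, so it returns False,
# and bfgs_sample then promotes x, g and alpha with pt.atleast_2d (as in the hunk above)
x1, beta1 = x[0], beta[0]  # shapes (N,) and (N, 2J)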
@@ -372,26 +377,23 @@


 def compute_logp(logp_func, arr):
-    """
-    **IMPORTANT**
-    replace nan with -np.inf otherwise np.argmax(elbo) will return you the first index at nan!!!!
-    """
-
     logP = np.apply_along_axis(logp_func, axis=-1, arr=arr)
+    # replace nan with -inf since np.argmax will return the first index at nan
     return np.where(np.isnan(logP), -np.inf, logP)


 def single_pathfinder(
     model,
     num_draws: int,
     maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,
-    random_seed: RandomSeed = None,
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
+    random_seed: RandomSeed | None = None,
 ):
     jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
     logp_func, dlogp_func = get_logp_dlogp_of_ravel_inputs(model)
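The NaN guard in compute_logp above is easy to verify: np.argmax treats NaN as the maximum, so an unmasked NaN would win the ELBO argmax. A quick check:

import numpy as np

elbo = np.array([1.2, np.nan, 3.4])
np.argmax(elbo)  # -> 1: argmax lands on the NaN entry
np.argmax(np.where(np.isnan(elbo), -np.inf, elbo))  # -> 2, as intended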
@@ -423,7 +425,7 @@ def neg_dlogp_func(x):
         maxls=maxls,
     )

-    alpha, update_mask = alpha_recover(history.x, history.g)
+    alpha, update_mask = alpha_recover(history.x, history.g, epsilon=epsilon)

     beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)

@@ -486,6 +488,10 @@ def make_initial_pathfinder_point(
     DictToArrayBijection
         bijection containing jittered initial point
     """
+
+    # TODO: replace rng.uniform (pseudo-random) with scipy.stats.qmc.Sobol (quasi-random),
+    # since Sobol sequences have lower discrepancy than uniform pseudo-random draws.
+
     ipfn = make_initial_point_fn(
         model=model,
     )
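For the TODO above, scipy already ships a Sobol generator in scipy.stats.qmc. A hedged sketch of how the jitter could be drawn from it; the dimensions, bounds, and scramble/seed choices are assumptions, not committed behavior:

import numpy as np
from scipy.stats import qmc

N = 4  # number of raveled model parameters (illustrative)
sampler = qmc.Sobol(d=N, scramble=True, seed=42)
u = sampler.random(n=8)  # 8 low-discrepancy points in [0, 1)^N; n should be a power of 2
jitter = 2.0
initial_points = qmc.scale(u, -jitter, jitter)  # map to [-jitter, jitter)^N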
@@ -498,7 +504,7 @@
     return ip_map


-def _run_single_pathfinder(model, path_id, random_seed, **kwargs):
+def _run_single_pathfinder(model, path_id: int, random_seed: RandomSeed, **kwargs):
     """Helper to run single pathfinder instance"""
     try:
         # Handle pickling
@@ -553,13 +559,13 @@ def process_multipath_pathfinder_results(
         processed samples, logP and logQ arrays
     """
     # path[samples]: (I, M, N)
-    num_dims = results.paths[0]["samples"].shape[-1]
+    N = results.paths[0]["samples"].shape[-1]

     paths_array = np.array([results.paths[i] for i in range(results.num_paths)])
     logP = np.concatenate([path["logP"] for path in paths_array])
     logQ = np.concatenate([path["logQ"] for path in paths_array])
     samples = np.concatenate([path["samples"] for path in paths_array])
-    samples = samples.reshape(-1, num_dims, order="F")
+    samples = samples.reshape(-1, N, order="F")

     # adjust log densities
     log_I = np.log(results.num_paths)
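The order="F" reshape is worth a close read: on an (I, M, N) stack (the layout the comment above asserts; whether the concatenated array is actually viewed this way is an assumption), a Fortran-order reshape to (-1, N) interleaves draws across paths instead of keeping path blocks contiguous. A small sketch under that assumption:

import numpy as np

I, M, N = 2, 3, 1  # paths, draws per path, parameter dim (illustrative)
samples = np.array([[[0.0], [1.0], [2.0]],      # path 0 draws
                    [[10.0], [11.0], [12.0]]])  # path 1 draws
flat = samples.reshape(-1, N, order="F")
print(flat.ravel())  # [ 0. 10.  1. 11.  2. 12.] -- paths interleaved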
@@ -575,12 +581,13 @@ def multipath_pathfinder(
     num_draws: int,
     num_draws_per_path: int,
     maxcor: int | None = None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
-    maxls=1000,
+    maxiter: int = 1000,
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
+    maxls: int = 1000,
     num_elbo_draws: int = 10,
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
     psis_resample: bool = True,
     random_seed: RandomSeed = None,
     **pathfinder_kwargs,
@@ -603,6 +610,7 @@
         "maxls": maxls,
         "num_elbo_draws": num_elbo_draws,
         "jitter": jitter,
+        "epsilon": epsilon,
         **pathfinder_kwargs,
     }
     kwargs_pickled = {k: cloudpickle.dumps(v) for k, v in kwargs.items()}
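Since ProcessPoolExecutor workers receive these kwargs across process boundaries, each value is serialized with cloudpickle. A minimal round-trip sketch; the worker-side loads call is an assumption about code outside this hunk:

import cloudpickle

kwargs = {"maxiter": 1000, "jitter": 2.0, "epsilon": 1e-11}
kwargs_pickled = {k: cloudpickle.dumps(v) for k, v in kwargs.items()}
# on the receiving side, each value is restored before use
restored = {k: cloudpickle.loads(v) for k, v in kwargs_pickled.items()}
assert restored == kwargs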
@@ -645,20 +653,21 @@

 def fit_pathfinder(
     model,
-    num_paths=1,
-    num_draws=1000,
-    num_draws_per_path=1000,
-    maxcor=None,
-    maxiter=1000,
-    ftol=1e-10,
-    gtol=1e-16,
+    num_paths: int = 1,  # I
+    num_draws: int = 1000,  # R
+    num_draws_per_path: int = 1000,  # M
+    maxcor: int | None = None,  # J
+    maxiter: int = 1000,  # L^max
+    ftol: float = 1e-10,
+    gtol: float = 1e-16,
     maxls=1000,
-    num_elbo_draws: int = 10,
+    num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
+    epsilon: float = 1e-11,
     psis_resample: bool = True,
     random_seed: RandomSeed | None = None,
-    postprocessing_backend="cpu",
-    inference_backend="pymc",
+    postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
+    inference_backend: Literal["pymc", "blackjax"] = "pymc",
     **pathfinder_kwargs,
 ):
     """
@@ -686,11 +695,13 @@
     gtol : float, optional
         Tolerance for the norm of the gradient (default is 1e-16).
     maxls : int, optional
-        Maximum number of line search steps (default is 1000).
+        Maximum number of line search steps for the L-BFGS algorithm (default is 1000).
     num_elbo_draws : int, optional
         Number of draws for the Evidence Lower Bound (ELBO) estimation (default is 10).
     jitter : float, optional
         Amount of jitter to apply to initial points (default is 2.0).
+    epsilon : float
+        Threshold used to filter out large changes in the direction of the update gradient at each iteration l in L. An iteration l is accepted only if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) (default is 1e-11).
     psis_resample : bool, optional
         Whether to apply Pareto Smoothed Importance Sampling Resampling (default is True). If False, the samples are returned as-is (no resampling is applied), with size num_draws_per_path * num_paths.
     random_seed : RandomSeed, optional
@@ -733,6 +744,7 @@
         maxls=maxls,
         num_elbo_draws=num_elbo_draws,
         jitter=jitter,
+        epsilon=epsilon,
         psis_resample=psis_resample,
         random_seed=random_seed,
         **pathfinder_kwargs,
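Taken together, the updated entry point can be exercised like this; a minimal call sketch with a toy model (names and values purely illustrative):

import pymc as pm
from pymc_experimental.inference.pathfinder.pathfinder import fit_pathfinder

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.2])

idata = fit_pathfinder(
    model,
    num_paths=4,
    num_draws=1000,
    epsilon=1e-11,  # the new threshold, forwarded down to alpha_recover
    random_seed=123,
)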

tests/test_pathfinder.py
0 additions & 25 deletions
@@ -136,28 +136,3 @@ def test_process_multipath_results():
     assert samples.shape == (num_paths * num_draws, num_dims)
     assert logP.shape == (num_paths * num_draws,)
     assert logQ.shape == (num_paths * num_draws,)
-
-
-def test_pathfinder_results():
-    """Test PathfinderResults class"""
-    from pymc_experimental.inference.pathfinder.pathfinder import PathfinderResults
-
-    num_paths = 3
-    num_draws = 100
-    num_dims = 2
-
-    results = PathfinderResults(num_paths, num_draws, num_dims)
-
-    # Test initialization
-    assert len(results.paths) == num_paths
-    assert results.paths[0]["samples"].shape == (num_draws, num_dims)
-
-    # Test adding data
-    samples = np.random.randn(num_draws, num_dims)
-    logP = np.random.randn(num_draws)
-    logQ = np.random.randn(num_draws)
-
-    results.add_path_data(0, samples, logP, logQ)
-    np.testing.assert_array_equal(results.paths[0]["samples"], samples)
-    np.testing.assert_array_equal(results.paths[0]["logP"], logP)
-    np.testing.assert_array_equal(results.paths[0]["logQ"], logQ)
