Commit da338bf

- Rename argument `use_jax_gradients` -> `gradient_backend`
- Rename function `laplace` -> `sample_laplace_posterior`
1 parent bc340c2 commit da338bf
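
For callers, the renames are mechanical. A minimal before/after sketch (the surrounding arguments are illustrative, not part of this commit):

from pymc_experimental.inference.find_map import find_MAP
from pymc_experimental.inference.laplace import sample_laplace_posterior

# Before this commit:
#     find_MAP(..., use_jax_gradients=True, compile_kwargs={"mode": "JAX"})
#     laplace(mu=mu, H_inv=H_inv, model=model)
# After this commit:
#     find_MAP(..., gradient_backend="jax", compile_kwargs={"mode": "JAX"})
#     sample_laplace_posterior(mu=mu, H_inv=H_inv, model=model)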

4 files changed: +57 -28 lines changed


pymc_experimental/inference/find_map.py

Lines changed: 25 additions & 8 deletions
@@ -1,7 +1,7 @@
 import logging

 from collections.abc import Callable
-from typing import cast
+from typing import Literal, cast, get_args

 import jax
 import numpy as np
@@ -17,11 +17,15 @@
 from pymc.pytensorf import join_nonshared_inputs
 from pymc.util import get_default_varnames
 from pytensor.compile import Function
+from pytensor.compile.mode import Mode
 from pytensor.tensor import TensorVariable
 from scipy.optimize import OptimizeResult

 _log = logging.getLogger(__name__)

+GradientBackend = Literal["pytensor", "jax"]
+VALID_BACKENDS = get_args(GradientBackend)
+

 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
@@ -85,7 +89,11 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,

         out.append(untransformed_X)

-    f_untransform = pytensor.function([X], out, mode="JAX")
+    f_untransform = pytensor.function(
+        inputs=[pytensor.In(X, borrow=True)],
+        outputs=pytensor.Out(out, borrow=True),
+        mode=Mode(linker="py", optimizer=None),
+    )
     return f_untransform(posterior_draws)

@@ -209,7 +217,7 @@ def scipy_optimize_funcs_from_loss(
     use_grad: bool,
     use_hess: bool,
     use_hessp: bool,
-    use_jax_gradients: bool = False,
+    gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
 ) -> tuple[Callable, ...]:
     """
@@ -230,8 +238,8 @@
         Whether to compile a function that computes the Hessian of the loss function.
     use_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
-    use_jax_gradients: bool
-        If True, use JAX to compute gradients. This is only possible when ``compile_kwargs["mode"]`` is set to "JAX".
+    gradient_backend: str, one of "jax" or "pytensor"
+        Which backend to use to compute gradients.
     compile_kwargs:
         Additional keyword arguments to pass to the ``pm.compile_pymc`` function.

@@ -252,7 +260,12 @@
             "Cannot compute hessian or hessian-vector product without also computing the gradient"
         )

-    use_jax_gradients = use_jax_gradients and use_grad
+    if gradient_backend not in VALID_BACKENDS:
+        raise ValueError(
+            f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
+        )
+
+    use_jax_gradients = (gradient_backend == "jax") and use_grad

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -307,7 +320,7 @@ def find_MAP(
     jitter_rvs: list[TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
-    use_jax_gradients: bool = False,
+    gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
@@ -342,6 +355,10 @@
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    gradient_backend: str, default "pytensor"
+        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+    compile_kwargs: dict, optional
+        Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.

@@ -380,7 +397,7 @@ def find_MAP(
         use_grad=use_grad,
         use_hess=use_hess,
         use_hessp=use_hessp,
-        use_jax_gradients=use_jax_gradients,
+        gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
     )
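
The new `gradient_backend` keyword replaces the old boolean flag throughout `find_MAP`. A hedged usage sketch mirroring the call pattern in tests/test_jax_find_map.py; the toy model, its data, and the choice of optimizer are illustrative assumptions, only the keyword names come from this commit:

import numpy as np
import pymc as pm

from pymc_experimental.inference.find_map import find_MAP

rng = np.random.default_rng(0)

with pm.Model() as model:
    # Hypothetical toy model, used only to exercise the API.
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=rng.normal(3.0, 1.5, size=100))

    optimized_point = find_MAP(
        method="BFGS",
        use_grad=True,
        gradient_backend="jax",          # formerly use_jax_gradients=True
        compile_kwargs={"mode": "JAX"},  # JAX gradients pair with the JAX compile mode
        progressbar=False,
    )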

pymc_experimental/inference/laplace.py

Lines changed: 14 additions & 9 deletions
@@ -41,6 +41,7 @@
 from scipy import stats

 from pymc_experimental.inference.find_map import (
+    GradientBackend,
     _unconstrained_vector_to_constrained_rvs,
     find_MAP,
     get_nearest_psd,
@@ -235,7 +236,7 @@ def fit_mvn_to_MAP(
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
     transform_samples: bool = False,
-    use_jax_gradients: bool = False,
+    gradient_backend: GradientBackend = "pytensor",
     zero_tol: float = 1e-8,
     diag_jitter: float | None = 1e-8,
     compile_kwargs: dict | None = None,
@@ -256,12 +257,16 @@ def fit_mvn_to_MAP(
         If 'error', an error will be raised.
     transform_samples : bool
         Whether to transform the samples back to the original parameter space. Default is True.
+    gradient_backend: str, default "pytensor"
+        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     zero_tol: float
         Value below which an element of the Hessian matrix is counted as 0.
         This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
     diag_jitter: float | None
         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
         If None, no jitter is added. Default is 1e-8.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to pytensor.function when compiling loss functions

     Returns
     -------
@@ -294,7 +299,7 @@ def fit_mvn_to_MAP(
         use_grad=True,
         use_hess=True,
         use_hessp=False,
-        use_jax_gradients=use_jax_gradients,
+        gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
     )

@@ -323,7 +328,7 @@ def stabilize(x, jitter):
     return mu, H_inv


-def laplace(
+def sample_laplace_posterior(
     mu: RaveledVars,
     H_inv: np.ndarray,
     model: pm.Model | None = None,
@@ -416,7 +421,7 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
-    use_jax_gradients: bool = False,
+    gradient_backend: GradientBackend = "pytensor",
     chains: int = 2,
     draws: int = 500,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -461,8 +466,8 @@ def fit_laplace(
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
-    use_jax_gradients: bool, optional
-        Whether to use JAX for gradient calculations. Defaults to False.
+    gradient_backend: str, default "pytensor"
+        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
         The number of sampling chains running in parallel.
     draws: int, default: 500
@@ -489,7 +494,7 @@ def fit_laplace(

     Examples
     --------
-    >>> from pymc_experimental.inference.laplace import fit_laplace
+    >>> from pymc_experimental.inference.sample_laplace_posterior import fit_laplace
     >>> import numpy as np
     >>> import pymc as pm
     >>> import arviz as az
@@ -526,7 +531,7 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
-        use_jax_gradients=use_jax_gradients,
+        gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         **optimizer_kwargs,
     )
@@ -541,7 +546,7 @@ def fit_laplace(
         compile_kwargs=compile_kwargs,
     )

-    return laplace(
+    return sample_laplace_posterior(
         mu=mu,
         H_inv=H_inv,
         model=model,
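
Downstream, the renamed sampler composes with `fit_mvn_to_MAP` exactly as `laplace` did: fit a Gaussian at the MAP estimate, then draw from it. A hedged sketch continuing from the `find_MAP` example above (`optimized_point` and `model` come from there); passing the optimized point as the first positional argument of `fit_mvn_to_MAP` is an assumption not shown in this diff:

from pymc_experimental.inference.laplace import fit_mvn_to_MAP, sample_laplace_posterior

# Fit a multivariate normal at the MAP estimate.
mu_map, H_inv = fit_mvn_to_MAP(
    optimized_point,                 # assumed first argument; not visible in these hunks
    model=model,
    gradient_backend="jax",          # formerly use_jax_gradients=True
    compile_kwargs={"mode": "JAX"},
    transform_samples=True,
)

# Renamed from `laplace`; keyword names match tests/test_laplace.py.
idata = sample_laplace_posterior(mu=mu_map, H_inv=H_inv, model=model, transform_samples=True)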

tests/test_jax_find_map.py

Lines changed: 7 additions & 6 deletions
@@ -4,6 +4,7 @@
 import pytest

 from pymc_experimental.inference.find_map import (
+    GradientBackend,
     find_MAP,
     scipy_optimize_funcs_from_loss,
 )
@@ -17,8 +18,8 @@ def rng():
     return np.random.default_rng(seed)


-@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"])
-def test_jax_functions_from_graph(use_jax_gradients):
+@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
+def test_jax_functions_from_graph(gradient_backend: GradientBackend):
     x = pt.tensor("x", shape=(2,))

     def compute_z(x):
@@ -34,7 +35,7 @@ def compute_z(x):
         use_grad=True,
         use_hess=True,
         use_hessp=True,
-        use_jax_gradients=use_jax_gradients,
+        gradient_backend=gradient_backend,
         compile_kwargs=dict(mode="JAX"),
     )

@@ -69,8 +70,8 @@ def compute_z(x):
         ("trust-constr", True, True),
     ],
 )
-@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"])
-def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng):
+@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
+def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,7 +89,7 @@ def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng):
         use_grad=use_grad,
         use_hess=use_hess,
         progressbar=False,
-        use_jax_gradients=use_jax_gradients,
+        gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
     )
     mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]

tests/test_laplace.py

Lines changed: 11 additions & 5 deletions
@@ -20,7 +20,11 @@
 import pymc_experimental as pmx

 from pymc_experimental.inference.find_map import find_MAP
-from pymc_experimental.inference.laplace import fit_laplace, fit_mvn_to_MAP, laplace
+from pymc_experimental.inference.laplace import (
+    fit_laplace,
+    fit_mvn_to_MAP,
+    sample_laplace_posterior,
+)


 @pytest.fixture(scope="session")
@@ -86,7 +90,7 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        use_jax_gradients=True,
+        gradient_backend="jax",
         compile_kwargs={"mode": "JAX"},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
@@ -127,7 +131,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         use_hessp=True,
         progressbar=False,
         compile_kwargs=dict(mode=mode),
-        use_jax_gradients=mode == "JAX",
+        gradient_backend="jax" if mode == "JAX" else "pytensor",
     )

     for value in optimized_point.values():
@@ -139,7 +143,9 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         transform_samples=transform_samples,
     )

-    idata = laplace(mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples)
+    idata = sample_laplace_posterior(
+        mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples
+    )

     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5)
     np.testing.assert_allclose(
@@ -182,7 +188,7 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        use_jax_gradients=True,
+        gradient_backend="jax",
         compile_kwargs={"mode": "JAX"},
     )
