Commit b876114

Add more backend tests to fit_MAP/laplace
Parent: 1cef2e7

2 files changed (+57, -18 lines)

2 files changed

+57
-18
lines changed

pymc_extras/inference/find_map.py

Lines changed: 17 additions & 5 deletions
@@ -32,14 +32,26 @@ def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
 
     if use_hess and use_hessp:
         _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True. scipy.optimize.minimize never uses both at the '
-            'same time. Setting "use_hess" to False.'
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
         )
         use_hess = False
 
     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
 
     return use_grad, use_hess, use_hessp
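The new branching replaces the old independent fallbacks, which could leave both "use_hess" and "use_hessp" enabled for a method that supports both. A minimal sketch of the resulting truth table, runnable in isolation (resolve_hess_defaults and the inline method_info dicts are hypothetical stand-ins for illustration, not the library's API):

def resolve_hess_defaults(use_hess, use_hessp, method_info):
    # If the user set exactly one flag, default the other to its complement,
    # so at most one Hessian callable gets compiled.
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        # Neither flag set: fall back to the optimizer's capabilities,
        # preferring hessp when the method could use either.
        use_hessp = method_info["uses_hessp"]
        use_hess = method_info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False
    return use_hess, use_hessp

# A method like trust-ncg accepts both, so hessp wins by default:
assert resolve_hess_defaults(None, None, {"uses_hess": True, "uses_hessp": True}) == (False, True)
# An explicit use_hess=True now implies use_hessp=False:
assert resolve_hess_defaults(True, None, {"uses_hess": True, "uses_hessp": True}) == (True, False)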

@@ -63,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0
 
     return eigvec @ np.diag(eigval) @ eigvec.T
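The switch to np.linalg.eigh is a correctness as well as a performance fix: np.linalg.eig treats C as a general matrix and can return complex eigenpairs from round-off even though C = (A + A.T) / 2 is symmetric by construction, while eigh exploits the symmetry and guarantees real eigenvalues and orthonormal eigenvectors. A self-contained sketch of the same clip-negative-eigenvalues projection (nearest_psd here is an illustrative copy, not the module's function):

import numpy as np

def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, then zero out negative eigenvalues.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # real eigenvalues, orthonormal eigenvectors
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [2.0, 1.0]])  # indefinite: eigenvalues are 3 and -1
B = nearest_psd(A)
assert np.linalg.eigvalsh(B).min() > -1e-12  # projected matrix is PSD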

tests/test_laplace.py

Lines changed: 40 additions & 13 deletions
@@ -19,7 +19,7 @@
 
 import pymc_extras as pmx
 
-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
     fit_mvn_at_MAP,

@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.

@@ -55,7 +59,13 @@ def test_laplace():
         vars = [mu, logsigma]
 
         idata = pmx.fit(
-            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+            method="laplace",
+            optimize_method="trust-ncg",
+            draws=draws,
+            random_seed=173300,
+            chains=1,
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )
 
     assert idata.posterior["mu"].shape == (1, draws)

@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
 
 
-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.

@@ -90,8 +104,8 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
     )

@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])

@@ -131,7 +148,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         use_hessp=True,
         progressbar=False,
         compile_kwargs=dict(mode=mode),
-        gradient_backend="jax" if mode == "JAX" else "pytensor",
+        gradient_backend=gradient_backend,
     )
 
     for value in optimized_point.values():

@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]
 
 
-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])

@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )
 
     assert idata["posterior"].beta.shape[-2:] == (3, 2)

@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)

@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )
 
     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
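Taken together, the new parametrization runs every test under four compile/gradient pairings: the default backend with PyTensor gradients (None, "pytensor"), Numba compilation with PyTensor gradients ("NUMBA", "pytensor"), and JAX compilation with either JAX or PyTensor gradients. Outside the test suite, the JAX pairing would be invoked roughly as follows; the toy model and draw count are illustrative rather than taken from the diff, the keyword arguments mirror the pmx.fit calls above, and a working JAX installation is assumed:

import numpy as np
import pymc as pm
import pymc_extras as pmx

rng = np.random.default_rng(0)
with pm.Model():
    mu = pm.Normal("mu", mu=0, sigma=1)
    pm.Normal("obs", mu=mu, sigma=1, observed=rng.normal(loc=3, size=50))

    # Compile the logp through JAX and differentiate with JAX as well;
    # (None, "pytensor") would instead use the default backend end to end.
    idata = pmx.fit(
        method="laplace",
        optimize_method="trust-ncg",
        draws=500,
        chains=1,
        random_seed=173300,
        compile_kwargs={"mode": "JAX"},
        gradient_backend="jax",
    )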

0 commit comments