 
 import pymc_extras as pmx
 
-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
     fit_mvn_at_MAP,
@@ -37,7 +37,11 @@ def rng():
3737 "ignore:hessian will stop negating the output in a future version of PyMC.\n "
3838 + "To suppress this warning set `negate_output=False`:FutureWarning" ,
3939)
40- def test_laplace ():
40+ @pytest .mark .parametrize (
41+ "mode, gradient_backend" ,
42+ [(None , "pytensor" ), ("NUMBA" , "pytensor" ), ("JAX" , "jax" ), ("JAX" , "pytensor" )],
43+ )
44+ def test_laplace (mode , gradient_backend : GradientBackend ):
     # Example originates from Bayesian Data Analysis, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
         vars = [mu, logsigma]
 
         idata = pmx.fit(
-            method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+            method="laplace",
+            optimize_method="trust-ncg",
+            draws=draws,
+            random_seed=173300,
+            chains=1,
+            compile_kwargs={"mode": mode},
+            gradient_backend=gradient_backend,
         )
 
     assert idata.posterior["mu"].shape == (1, draws)
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
 
 
-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analysis, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
             method="laplace",
             optimize_method="BFGS",
             progressbar=True,
-            gradient_backend="jax",
-            compile_kwargs={"mode": "JAX"},
+            gradient_backend=gradient_backend,
+            compile_kwargs={"mode": mode},
             optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
             random_seed=173300,
         )
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,7 +148,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
             use_hessp=True,
             progressbar=False,
             compile_kwargs=dict(mode=mode),
-            gradient_backend="jax" if mode == "JAX" else "pytensor",
+            gradient_backend=gradient_backend,
         )
 
         for value in optimized_point.values():
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]
 
 
-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )
 
     assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )
 
     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
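
Taken together, these hunks parametrize every Laplace test over the same four (mode, gradient_backend) pairs instead of hard-coding the JAX backend. A rough standalone sketch of what each parametrized case exercises, using only the pmx.fit keywords visible in this diff (the toy model below is illustrative rather than the BDA example the tests use, and the NUMBA/JAX modes assume numba and jax are installed):

import numpy as np
import pymc as pm
import pymc_extras as pmx

# The four backend pairs shared by every parametrized test above:
# `mode` selects the PyTensor compilation mode (None = default C backend),
# `gradient_backend` selects how gradients are computed.
backend_pairs = [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")]

with pm.Model():
    mu = pm.Normal("mu", mu=3, sigma=0.5)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=np.random.default_rng(0).normal(3.0, 1.5, 100))

    for mode, gradient_backend in backend_pairs:
        # Each combination should produce equivalent posteriors; only the
        # compilation and gradient machinery differs.
        idata = pmx.fit(
            method="laplace",
            optimize_method="BFGS",
            draws=500,
            chains=1,
            random_seed=173300,
            compile_kwargs={"mode": mode},
            gradient_backend=gradient_backend,
        )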